Skip to content

Commit 2271be8

Browse files
refactor(sdk): merge dependency methods and add depth parameter
- Make dependencies a proxy property with optional depth parameter - Default to all dependencies (depth=None) - Support direct dependencies via depth=1 - Maintain backward compatibility - Update move_symbol_to_file to use direct dependencies Co-Authored-By: [email protected] <[email protected]>
1 parent 41ef6b6 commit 2271be8

File tree

3 files changed

+209
-9
lines changed

3 files changed

+209
-9
lines changed

src/codegen/sdk/core/interfaces/importable.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import TYPE_CHECKING, Generic, Self, TypeVar, Union
2+
from typing import TYPE_CHECKING, Generic, Optional, Self, TypeVar, Union
33

44
from tree_sitter import Node as TSNode
55

@@ -45,20 +45,41 @@ def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", pa
4545
def dependencies(self) -> list[Union["Symbol", "Import"]]:
4646
"""Returns a list of symbols that this symbol depends on.
4747
48-
Returns a list of symbols (including imports) that this symbol directly depends on.
49-
The returned list is sorted by file location for consistent ordering.
48+
This property returns all dependencies (direct and indirect).
49+
For finer control over dependency depth, use get_dependencies_to_depth().
5050
5151
Returns:
52-
list[Union[Symbol, Import]]: A list of symbols and imports that this symbol directly depends on,
52+
list[Union[Symbol, Import]]: A list of symbols and imports that this symbol depends on,
5353
sorted by file location.
5454
"""
55-
return self.get_dependencies(UsageType.DIRECT)
55+
return self.get_dependencies_to_depth()
56+
57+
@reader(cache=False)
58+
def get_dependencies_to_depth(self, depth: Optional[int] = None) -> list[Union["Symbol", "Import"]]:
59+
"""Returns a list of symbols that this symbol depends on up to a specified depth.
60+
61+
Args:
62+
depth: Optional[int], maximum depth to traverse in the dependency graph.
63+
None means unlimited depth (all dependencies).
64+
1 means direct dependencies only.
65+
Default is None (all dependencies).
66+
67+
Returns:
68+
list[Union[Symbol, Import]]: A list of symbols and imports that this symbol depends on,
69+
sorted by file location.
70+
"""
71+
if depth == 1:
72+
return self.get_dependencies(UsageType.DIRECT)
73+
return self._get_all_dependencies(depth)
5674

5775
@reader(cache=False)
5876
@noapidoc
5977
def get_dependencies(self, usage_types: UsageType) -> list[Union["Symbol", "Import"]]:
60-
"""Returns Symbols and Importsthat this symbol depends on.
61-
78+
"""Internal method for getting dependencies by usage type.
79+
80+
This is kept for backward compatibility and internal use.
81+
New code should use the dependencies property or get_dependencies_to_depth method.
82+
6283
Opposite of `usages`
6384
"""
6485
avoid = set(self.descendant_symbols)
@@ -113,6 +134,34 @@ def _remove_internal_edges(self, edge_type: EdgeType | None = None) -> None:
113134
for v in self.G.successors(self.node_id, edge_type=edge_type):
114135
self.G.remove_edge(self.node_id, v.node_id, edge_type=edge_type)
115136

137+
@reader(cache=False)
138+
@noapidoc
139+
def _get_all_dependencies(self, max_depth: Optional[int] = None) -> list[Union["Symbol", "Import"]]:
140+
"""Gets all dependencies up to specified depth.
141+
142+
Internal implementation of depth-based dependency traversal.
143+
"""
144+
dependency_map: dict[Self, list[Union["Symbol", "Import"]]] = {}
145+
current_depth = 1
146+
147+
def _collect_deps(symbol: Self, depth: int) -> None:
148+
if depth > max_depth if max_depth is not None else False:
149+
return
150+
151+
direct_deps = symbol.get_dependencies(UsageType.DIRECT)
152+
if symbol not in dependency_map:
153+
dependency_map[symbol] = direct_deps
154+
155+
for dep in direct_deps:
156+
if isinstance(dep, Symbol):
157+
_collect_deps(dep, depth + 1)
158+
159+
_collect_deps(self, current_depth)
160+
all_deps = []
161+
for deps in dependency_map.values():
162+
all_deps.extend(deps)
163+
return sort_editables(list(set(all_deps)), by_file=True)
164+
116165
@property
117166
@noapidoc
118167
def descendant_symbols(self) -> list[Self]:

src/codegen/sdk/core/symbol.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def _move_to_file(
312312

313313
if include_dependencies:
314314
# =====[ Move over dependencies recursively ]=====
315-
for dep in self.dependencies:
315+
for dep in self.get_dependencies_to_depth(depth=1): # Only get direct dependencies
316316
if dep in encountered_symbols:
317317
continue
318318

@@ -333,7 +333,7 @@ def _move_to_file(
333333
else:
334334
file.add_import_from_import_string(dep.source)
335335
else:
336-
for dep in self.dependencies:
336+
for dep in self.get_dependencies_to_depth(depth=1): # Only get direct dependencies
337337
# =====[ Symbols - add back edge ]=====
338338
if isinstance(dep, Symbol) and dep.is_top_level:
339339
file.add_symbol_import(symbol=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=False)
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
from codegen.sdk.codebase.factory.get_session import get_codebase_session
2+
from codegen.sdk.core.dataclasses.usage import UsageType
3+
from codegen.sdk.enums import ProgrammingLanguage
4+
5+
6+
def test_dependencies_with_depth(tmpdir):
7+
"""Test that dependencies property and get_dependencies_to_depth respect depth parameter."""
8+
py_code = """
9+
class A:
10+
def method_a(self):
11+
pass
12+
13+
class B(A):
14+
def method_b(self):
15+
self.method_a()
16+
17+
class C(B):
18+
def method_c(self):
19+
self.method_b()
20+
21+
def use_c():
22+
return C()
23+
"""
24+
with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.PYTHON, files={"test.py": py_code}) as G:
25+
use_c_func = G.get_function("use_c")
26+
c_class = G.get_class("C")
27+
b_class = G.get_class("B")
28+
a_class = G.get_class("A")
29+
30+
# Test direct dependencies (depth=1)
31+
direct_deps = c_class.get_dependencies_to_depth(depth=1)
32+
assert len(direct_deps) == 1
33+
assert direct_deps[0] == b_class
34+
35+
# Test two levels of dependencies (depth=2)
36+
two_level_deps = c_class.get_dependencies_to_depth(depth=2)
37+
assert len(two_level_deps) == 2
38+
assert b_class in two_level_deps
39+
assert a_class in two_level_deps
40+
41+
# Test unlimited depth (depth=None)
42+
all_deps = c_class.get_dependencies_to_depth()
43+
assert len(all_deps) == 2
44+
assert b_class in all_deps
45+
assert a_class in all_deps
46+
47+
# Test that default dependencies property returns all dependencies
48+
default_deps = c_class.dependencies
49+
assert set(default_deps) == set(all_deps)
50+
51+
52+
def test_dependencies_with_imports(tmpdir):
53+
"""Test that dependencies property and get_dependencies_to_depth handle imports correctly."""
54+
py_code = """
55+
from typing import Optional
56+
57+
class MyClass:
58+
def __init__(self, value: Optional[str] = None):
59+
self.value = value
60+
"""
61+
with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.PYTHON, files={"test.py": py_code}) as G:
62+
my_class = G.get_class("MyClass")
63+
64+
# Test that imports are included in dependencies
65+
deps = my_class.get_dependencies_to_depth(depth=1)
66+
assert len(deps) == 2 # Optional and str
67+
assert any(dep.name == "Optional" for dep in deps)
68+
assert any(dep.name == "str" for dep in deps)
69+
70+
# Test that default dependencies property includes imports
71+
all_deps = my_class.dependencies
72+
assert len(all_deps) == 2
73+
assert any(dep.name == "Optional" for dep in all_deps)
74+
assert any(dep.name == "str" for dep in all_deps)
75+
76+
77+
def test_backward_compatibility(tmpdir):
78+
"""Test that get_dependencies method still works for backward compatibility."""
79+
py_code = """
80+
from typing import Optional
81+
82+
class MyClass:
83+
def __init__(self, value: Optional[str] = None):
84+
self.value = value
85+
"""
86+
with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.PYTHON, files={"test.py": py_code}) as G:
87+
my_class = G.get_class("MyClass")
88+
89+
# Test that get_dependencies with DIRECT usage type works
90+
direct_deps = my_class.get_dependencies(UsageType.DIRECT)
91+
assert len(direct_deps) == 2
92+
assert any(dep.name == "Optional" for dep in direct_deps)
93+
assert any(dep.name == "str" for dep in direct_deps)
94+
95+
# Test that get_dependencies with CHAINED usage type works
96+
chained_deps = my_class.get_dependencies(UsageType.CHAINED)
97+
assert len(chained_deps) == 0 # No chained dependencies in this example
98+
99+
# Test that get_dependencies with combined usage types works
100+
all_deps = my_class.get_dependencies(UsageType.DIRECT | UsageType.CHAINED)
101+
assert len(all_deps) == 2
102+
assert any(dep.name == "Optional" for dep in all_deps)
103+
assert any(dep.name == "str" for dep in all_deps)
104+
105+
106+
def test_zero_depth(tmpdir):
107+
"""Test that depth=0 returns empty list."""
108+
py_code = """
109+
class A:
110+
pass
111+
112+
class B(A):
113+
pass
114+
"""
115+
with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.PYTHON, files={"test.py": py_code}) as G:
116+
b_class = G.get_class("B")
117+
118+
# Test that depth=0 returns empty list
119+
deps = b_class.get_dependencies_to_depth(depth=0)
120+
assert len(deps) == 0
121+
122+
123+
def test_cyclic_dependencies(tmpdir):
124+
"""Test that dependencies property and get_dependencies_to_depth handle cyclic dependencies."""
125+
py_code = """
126+
class A:
127+
def method_a(self):
128+
return B()
129+
130+
class B:
131+
def method_b(self):
132+
return A()
133+
"""
134+
with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.PYTHON, files={"test.py": py_code}) as G:
135+
a_class = G.get_class("A")
136+
b_class = G.get_class("B")
137+
138+
# Test direct dependencies
139+
a_direct_deps = a_class.get_dependencies_to_depth(depth=1)
140+
assert len(a_direct_deps) == 1
141+
assert a_direct_deps[0] == b_class
142+
143+
# Test that cyclic dependencies are handled properly
144+
a_all_deps = a_class.get_dependencies_to_depth()
145+
assert len(a_all_deps) == 1
146+
assert b_class in a_all_deps
147+
148+
# Test B's dependencies
149+
b_direct_deps = b_class.get_dependencies_to_depth(depth=1)
150+
assert len(b_direct_deps) == 1
151+
assert b_direct_deps[0] == a_class

0 commit comments

Comments
 (0)