diff --git a/src/codegen/sdk/core/detached_symbols/argument.py b/src/codegen/sdk/core/detached_symbols/argument.py index beacbc4bf..c6d771397 100644 --- a/src/codegen/sdk/core/detached_symbols/argument.py +++ b/src/codegen/sdk/core/detached_symbols/argument.py @@ -52,6 +52,13 @@ def __init__(self, node: TSNode, positional_idx: int, parent: FunctionCall) -> N self._name_node = self._parse_expression(name_node, default=Name) self._value_node = self._parse_expression(_value_node) + def __repr__(self) -> str: + keyword = f"keyword={self.name}, " if self.name else "" + value = f"value='{self.value}', " if self.value else "" + type = f"type={self.type}" if self.type else "" + + return f"Argument({keyword}{value}{type})" + @noapidoc @classmethod def from_argument_list(cls, node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: FunctionCall) -> MultiExpression[Parent, Argument]: diff --git a/src/codegen/sdk/core/detached_symbols/function_call.py b/src/codegen/sdk/core/detached_symbols/function_call.py index 3f30582c0..4646946b0 100644 --- a/src/codegen/sdk/core/detached_symbols/function_call.py +++ b/src/codegen/sdk/core/detached_symbols/function_call.py @@ -62,6 +62,28 @@ def __init__(self, node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: args = [Argument(x, i, self) for i, x in enumerate(arg_list_node.named_children) if x.type != "comment"] self._arg_list = Collection(arg_list_node, self.file_node_id, self.G, self, children=args) + def __repr__(self) -> str: + """Custom string representation showing the function call chain structure. + + Format: FunctionCall(name=current, pred=pred_name, succ=succ_name, base=base_name) + + It will only print out predecessor, successor, and base that are of type FunctionCall. If it's a property, it will not be logged + """ + # Helper to safely get name + + # Get names for each part + parts = [f"name='{self.name}'"] + + if self.predecessor and isinstance(self.predecessor, FunctionCall): + parts.append(f"predecessor=FunctionCall(name='{self.predecessor.name}')") + + if self.successor and isinstance(self.successor, FunctionCall): + parts.append(f"successor=FunctionCall(name='{self.successor.name}')") + + parts.append(f"filepath='{self.file.filepath}'") + + return f"FunctionCall({', '.join(parts)})" + @classmethod def from_usage(cls, node: Editable[Parent], parent: Parent | None = None) -> Self | None: """Creates a FunctionCall object from an Editable instance that represents a function call. @@ -210,9 +232,33 @@ def predecessor(self) -> FunctionCall[Parent] | None: or if the predecessor is not a function call. """ # Recursively travel down the tree to find the previous function call (child nodes are previous calls) - return self.call_chain[-2] if len(self.call_chain) > 1 else None + name = self.get_name() + while name: + if isinstance(name, FunctionCall): + return name + elif isinstance(name, ChainedAttribute): + name = name.object + else: + break + return None + + @property + @reader + def successor(self) -> FunctionCall[Parent] | None: + """Returns the next function call in a function call chain. + + Returns the next function call in a function call chain. This method is useful for traversing function call chains + to analyze or modify sequences of chained function calls. + + Returns: + FunctionCall[Parent] | None: The next function call in the chain, or None if there is no successor + or if the successor is not a function call. + """ + # this will avoid parent function calls in tree-sitter that are NOT part of the chained calls + if not isinstance(self.parent, ChainedAttribute): + return None - # TODO: also define a successor? + return self.parent_of_type(FunctionCall) @property @noapidoc @@ -581,6 +627,26 @@ def function_calls(self) -> list[FunctionCall]: # calls.append(call) return sort_editables(calls, dedupe=False) + @property + @reader + def attribute_chain(self) -> list[FunctionCall | Name]: + """Returns a list of elements in the chainedAttribute that the function call belongs in. + + Breaks down chained expressions into individual components in order of appearance. + For example: `a.b.c().d` -> [Name("a"), Name("b"), FunctionCall("c"), Name("d")] + + Returns: + list[FunctionCall | Name]: List of Name nodes (property access) and FunctionCall nodes (method calls) + """ + if isinstance(self.get_name(), ChainedAttribute): # child is chainedAttribute. MEANING that this is likely in the middle or the last function call of a chained function call chain. + return self.get_name().attribute_chain + elif isinstance( + self.parent, ChainedAttribute + ): # does not have child chainedAttribute, but parent is chainedAttribute. MEANING that this is likely the TOP function call of a chained function call chain. + return self.parent.attribute_chain + else: # this is a standalone function call + return [self] + @property @noapidoc def descendant_symbols(self) -> list[Importable]: @@ -603,24 +669,35 @@ def register_api_call(self, url: str): @property @reader def call_chain(self) -> list[FunctionCall]: - """Returns a list of all function calls in this function call chain, including this call. Does not include calls made after this one.""" + """Returns a list of all function calls in this function call chain, including this call. Does not include calls made after this one.""" ret = [] - name = self.get_name() - while name: - if isinstance(name, FunctionCall): - ret.extend(name.call_chain) - break - elif isinstance(name, ChainedAttribute): - name = name.object - else: - break + + # backward traversal + curr = self + pred = curr.predecessor + while pred is not None and isinstance(pred, FunctionCall): + ret.insert(0, pred) + pred = pred.predecessor + ret.append(self) + + # forward traversal + curr = self + succ = curr.successor + while succ is not None and isinstance(succ, FunctionCall): + ret.append(succ) + succ = succ.successor + return ret @property @reader def base(self) -> Editable | None: - """Returns the base object of this function call chain.""" + """Returns the base object of this function call chain. + + Args: + Editable | None: The base object of this function call chain. + """ name = self.get_name() while isinstance(name, ChainedAttribute): if isinstance(name.object, FunctionCall): diff --git a/src/codegen/sdk/core/expressions/chained_attribute.py b/src/codegen/sdk/core/expressions/chained_attribute.py index 66cc06705..45ee7a90f 100644 --- a/src/codegen/sdk/core/expressions/chained_attribute.py +++ b/src/codegen/sdk/core/expressions/chained_attribute.py @@ -15,9 +15,11 @@ from codegen.shared.decorators.docs import apidoc, noapidoc if TYPE_CHECKING: + from codegen.sdk.core.detached_symbols.function_call import FunctionCall from codegen.sdk.core.interfaces.has_name import HasName from codegen.sdk.core.interfaces.importable import Importable + Object = TypeVar("Object", bound="Chainable") Attribute = TypeVar("Attribute", bound="Resolvable") Parent = TypeVar("Parent", bound="Expression") @@ -74,6 +76,49 @@ def attribute(self) -> Attribute: """ return self._attribute + @property + @reader + def attribute_chain(self) -> list["FunctionCall | Name"]: + """Returns a list of elements in a chained attribute expression. + + Breaks down chained expressions into individual components in order of appearance. + For example: `a.b.c().d` -> [Name("a"), Name("b"), FunctionCall("c"), Name("d")] + + Returns: + list[FunctionCall | Name]: List of Name nodes (property access) and FunctionCall nodes (method calls) + """ + from codegen.sdk.core.detached_symbols.function_call import FunctionCall + + ret = [] + curr = self + + # Traverse backwards in code (children of tree node) + while isinstance(curr, ChainedAttribute): + curr = curr.object + + if isinstance(curr, FunctionCall): + ret.insert(0, curr) + curr = curr.get_name() + elif isinstance(curr, ChainedAttribute): + ret.insert(0, curr.attribute) + + # This means that we have reached the base of the chain and the first item was an attribute (i.e a.b.c.func()) + if isinstance(curr, Name) and not isinstance(curr.parent, FunctionCall): + ret.insert(0, curr) + + curr = self + + # Traversing forward in code (parents of tree node). Will add the current node as well + while isinstance(curr, ChainedAttribute) or isinstance(curr, FunctionCall): + if isinstance(curr, FunctionCall): + ret.append(curr) + elif isinstance(curr, ChainedAttribute) and not isinstance(curr.parent, FunctionCall): + ret.append(curr.attribute) + + curr = curr.parent + + return ret + @property def object(self) -> Object: """Returns the object that contains the attribute being looked up. diff --git a/src/codegen/sdk/core/interfaces/editable.py b/src/codegen/sdk/core/interfaces/editable.py index 625828b74..416b4285c 100644 --- a/src/codegen/sdk/core/interfaces/editable.py +++ b/src/codegen/sdk/core/interfaces/editable.py @@ -75,6 +75,10 @@ def _is_empty_container(text: str) -> bool: "resolved_types", "valid_symbol_names", "valid_import_names", + "predecessor", + "successor", + "base", + "call_chain", "code_block", "parent_statement", "symbol_usages", diff --git a/src/codegen/sdk/core/symbol_group.py b/src/codegen/sdk/core/symbol_group.py index 24d75c15d..6f548cd60 100644 --- a/src/codegen/sdk/core/symbol_group.py +++ b/src/codegen/sdk/core/symbol_group.py @@ -37,6 +37,9 @@ def __init__(self, file_node_id: NodeId, G: CodebaseGraph, parent: Parent, node: node = children[0].ts_node super().__init__(node, file_node_id, G, parent) + def __repr__(self) -> str: + return f"Collection({self.symbols})" if self.symbols is not None else super().__repr__() + def _init_children(self): ... @repr_func # HACK diff --git a/tests/unit/codegen/sdk/python/detached_symbols/function_call/test_function_call.py b/tests/unit/codegen/sdk/python/detached_symbols/function_call/test_function_call.py index d8a735e8d..254756cb7 100644 --- a/tests/unit/codegen/sdk/python/detached_symbols/function_call/test_function_call.py +++ b/tests/unit/codegen/sdk/python/detached_symbols/function_call/test_function_call.py @@ -502,8 +502,8 @@ def baz(): # Check call chain assert c.call_chain == [a, b, c] - assert b.call_chain == [a, b] - assert a.call_chain == [a] + assert b.call_chain == [a, b, c] + assert a.call_chain == [a, b, c] # Check base assert c.base == a.get_name() @@ -530,8 +530,8 @@ def baz(): # Check call chain assert c.call_chain == [a, b, c] - assert b.call_chain == [a, b] - assert a.call_chain == [a] + assert b.call_chain == [a, b, c] + assert a.call_chain == [a, b, c] # Check base assert c.base.source == "x" @@ -539,6 +539,94 @@ def baz(): assert a.base.source == "x" +def test_function_call_chain_nested(tmpdir) -> None: + # language=python + content = """ +def foo(): + # Nested function calls - each call should be independent + a(b(c())) +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + foo = file.get_function("foo") + calls = foo.function_calls + assert len(calls) == 3 + a = calls[0] + b = calls[1] + c = calls[2] + + # Each call should be independent - no predecessors + assert a.predecessor is None + assert b.predecessor is None + assert c.predecessor is None + + # No successors since they're nested, not chained + assert a.successor is None + assert b.successor is None + assert c.successor is None + + # Call chain for each should only include itself + assert a.call_chain == [a] + assert b.call_chain == [b] + assert c.call_chain == [c] + + # Verify source strings are correct + assert a.source == "a(b(c()))" + assert b.source == "b(c())" + assert c.source == "c()" + + +def test_function_call_chain_successor(tmpdir) -> None: + # language=python + content = """ +def foo(): + a().b().c() + +def bat(): + x.y.z.func() + +def baz(): + x.a().y.b().z.c() +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + + # Check foo + foo = file.get_function("foo") + calls = foo.function_calls + assert len(calls) == 3 + c = calls[0] + b = calls[1] + a = calls[2] + + # Check successors + assert a.successor == b + assert b.successor == c + assert c.successor is None + + # Check bat + bat = file.get_function("bat") + calls = bat.function_calls + assert len(calls) == 1 + func = calls[0] + + # No successor since it's a single function call + assert func.successor is None + + # Check baz + baz = file.get_function("baz") + calls = baz.function_calls + assert len(calls) == 3 + c = calls[0] + b = calls[1] + a = calls[2] + + # Check successors + assert a.successor == b + assert b.successor == c + assert c.successor is None + + def test_function_call_chain_hard(tmpdir) -> None: # language=python content = """ diff --git a/tests/unit/codegen/sdk/python/expressions/test_chained_attribute_attribute_chain.py b/tests/unit/codegen/sdk/python/expressions/test_chained_attribute_attribute_chain.py new file mode 100644 index 000000000..54715b949 --- /dev/null +++ b/tests/unit/codegen/sdk/python/expressions/test_chained_attribute_attribute_chain.py @@ -0,0 +1,124 @@ +from codegen.sdk.codebase.factory.get_session import get_codebase_session + + +def test_attribute_chain_query_builder(tmpdir) -> None: + # language=python + content = """ +def query(): + # Test chained method calls with function at start + QueryBuilder().select("name", "age").from_table("users").where("age > 18").order_by("name") +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + query = file.get_function("query") + calls = query.function_calls + assert len(calls) == 5 + order_by = calls[0] # Last call in chain + where = calls[1] + from_table = calls[2] + select = calls[3] + query_builder = calls[4] # First call in chain + + # Test attribute chain from different positions + # From first call (QueryBuilder()) + chain = query_builder.attribute_chain + assert len(chain) == 5 + assert chain[0] == query_builder + assert chain[1] == select + assert chain[2] == from_table + assert chain[3] == where + assert chain[4] == order_by + + # From middle call (from_table()) + chain = from_table.attribute_chain + assert len(chain) == 5 + assert chain[0] == query_builder + assert chain[1] == select + assert chain[2] == from_table + assert chain[3] == where + assert chain[4] == order_by + + # From last call (order_by()) + chain = order_by.attribute_chain + assert len(chain) == 5 + assert chain[0] == query_builder + assert chain[1] == select + assert chain[2] == from_table + assert chain[3] == where + assert chain[4] == order_by + + +def test_attribute_chain_mixed_properties(tmpdir) -> None: + # language=python + content = """ +def query(): + # Test mix of properties and function calls + QueryBuilder().a.select("name", "age").from_table("users").where("age > 18").b.order_by("name").c +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + query = file.get_function("query") + calls = query.function_calls + + # Get function calls in order + order_by = calls[0] # Last function call + where = calls[1] + from_table = calls[2] + select = calls[3] + query_builder = calls[4] # First function call + + # Test from first call + chain = query_builder.attribute_chain + assert len(chain) == 8 # 5 function calls + 3 properties (a, b, c) + assert chain[0] == query_builder + assert chain[1].source == "a" # Property + assert chain[2] == select + assert chain[3] == from_table + assert chain[4] == where + assert chain[5].source == "b" # Property + assert chain[6] == order_by + assert chain[7].source == "c" # Property + + +def test_attribute_chain_only_properties(tmpdir) -> None: + # language=python + content = """ +def test(): + # Test chain with only properties + a.b.c.func() +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + test = file.get_function("test") + calls = test.function_calls + assert len(calls) == 1 + func = calls[0] + + chain = func.attribute_chain + assert len(chain) == 4 + assert chain[0].source == "a" + assert chain[1].source == "b" + assert chain[2].source == "c" + assert chain[3] == func + + +def test_attribute_chain_nested_calls(tmpdir) -> None: + # language=python + content = """ +def test(): + # Test nested function calls (not chained) + a(b(c())) +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + test = file.get_function("test") + calls = test.function_calls + assert len(calls) == 3 + a = calls[0] + b = calls[1] + c = calls[2] + + # Each call should have its own single-element chain + assert a.attribute_chain == [a] + assert b.attribute_chain == [b] + assert c.attribute_chain == [c] diff --git a/tests/unit/codegen/sdk/typescript/expressions/test_chained_attribute_attribute_chain.py b/tests/unit/codegen/sdk/typescript/expressions/test_chained_attribute_attribute_chain.py new file mode 100644 index 000000000..b4c7b36b5 --- /dev/null +++ b/tests/unit/codegen/sdk/typescript/expressions/test_chained_attribute_attribute_chain.py @@ -0,0 +1,202 @@ +from codegen.sdk.codebase.factory.get_session import get_codebase_session +from codegen.sdk.enums import ProgrammingLanguage + + +def test_attribute_chain_query_builder(tmpdir) -> None: + # language=typescript + content = """ +function query() { + // Test chained method calls with function at start + QueryBuilder().select("name", "age").fromTable("users").where("age > 18").orderBy("name"); +} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + query = file.get_function("query") + calls = query.function_calls + assert len(calls) == 5 + order_by = calls[0] # Last call in chain + where = calls[1] + from_table = calls[2] + select = calls[3] + query_builder = calls[4] # First call in chain + + # Test attribute chain from different positions + # From first call (QueryBuilder()) + chain = query_builder.attribute_chain + assert len(chain) == 5 + assert chain[0] == query_builder + assert chain[1] == select + assert chain[2] == from_table + assert chain[3] == where + assert chain[4] == order_by + + # From middle call (from_table()) + chain = from_table.attribute_chain + assert len(chain) == 5 + assert chain[0] == query_builder + assert chain[1] == select + assert chain[2] == from_table + assert chain[3] == where + assert chain[4] == order_by + + # From last call (order_by()) + chain = order_by.attribute_chain + assert len(chain) == 5 + assert chain[0] == query_builder + assert chain[1] == select + assert chain[2] == from_table + assert chain[3] == where + assert chain[4] == order_by + + +def test_attribute_chain_mixed_properties(tmpdir) -> None: + # language=typescript + content = """ +function query() { + // Test mix of properties and function calls + QueryBuilder().a.select("name", "age").fromTable("users").where("age > 18").b.orderBy("name").c; +} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + query = file.get_function("query") + calls = query.function_calls + + # Get function calls in order + order_by = calls[0] # Last function call + where = calls[1] + from_table = calls[2] + select = calls[3] + query_builder = calls[4] # First function call + + # Test from first call + chain = query_builder.attribute_chain + assert len(chain) == 8 # 5 function calls + 3 properties (a, b, c) + assert chain[0] == query_builder + assert chain[1].source == "a" # Property + assert chain[2] == select + assert chain[3] == from_table + assert chain[4] == where + assert chain[5].source == "b" # Property + assert chain[6] == order_by + assert chain[7].source == "c" # Property + + +def test_attribute_chain_only_properties(tmpdir) -> None: + # language=typescript + content = """ +function test() { + // Test chain with only properties + a.b.c.func(); +} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + test = file.get_function("test") + calls = test.function_calls + assert len(calls) == 1 + func = calls[0] + + chain = func.attribute_chain + assert len(chain) == 4 + assert chain[0].source == "a" + assert chain[1].source == "b" + assert chain[2].source == "c" + assert chain[3] == func + + +def test_attribute_chain_nested_calls(tmpdir) -> None: + # language=typescript + content = """ +function test() { + // Test nested function calls (not chained) + a(b(c())); +} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + test = file.get_function("test") + calls = test.function_calls + assert len(calls) == 3 + a = calls[0] + b = calls[1] + c = calls[2] + + # Each call should have its own single-element chain + assert a.attribute_chain == [a] + assert b.attribute_chain == [b] + assert c.attribute_chain == [c] + + +def test_attribute_chain_promise_then(tmpdir) -> None: + # language=typescript + content = """ +function test() { + // Test Promise chain with multiple then calls + fetch("https://api.example.com/data") + .then(response => response.json()) + .then(data => processData(data)) + .then(result => console.log(result)) + .catch(error => handleError(error)); +} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + test = file.get_function("test") + calls = test.function_calls + + # Get function calls in order (last to first) + catch_call = calls[0] + then3 = calls[1] # console.log + then2 = calls[2] # processData + then1 = calls[3] # response.json + fetch = calls[4] # First call in chain + + # Test attribute chain from fetch + chain = fetch.attribute_chain + assert len(chain) == 5 + assert chain[0] == fetch + assert chain[1] == then1 + assert chain[2] == then2 + assert chain[3] == then3 + assert chain[4] == catch_call + + # Test from middle of chain + chain = then2.attribute_chain + assert len(chain) == 5 + assert chain[0] == fetch + assert chain[1] == then1 + assert chain[2] == then2 + assert chain[3] == then3 + assert chain[4] == catch_call + + +def test_attribute_chain_async_await_promise(tmpdir) -> None: + # language=typescript + content = """ +async function test() { + // Test Promise chain with mix of async/await and then + const result = await axios.get("/api/data") + .then(response => response.data) + .then(data => transform(data)); +} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + test = file.get_function("test") + calls = test.function_calls + + # Get function calls in order + then2 = calls[0] # transform + then1 = calls[1] # response.data + get = calls[2] # get + axios = calls[3] # axios + + # Test attribute chain + chain = get.attribute_chain + assert len(chain) == 4 + assert chain[0].source == "axios" + assert chain[1] == get + assert chain[2] == then1 + assert chain[3] == then2