Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/codegen/sdk/core/detached_symbols/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ def __init__(self, node: TSNode, positional_idx: int, parent: FunctionCall) -> N
self._name_node = self._parse_expression(name_node, default=Name)
self._value_node = self._parse_expression(_value_node)

def __repr__(self) -> str:
keyword = f"keyword={self.name}, " if self.name else ""
value = f"value='{self.value}', " if self.value else ""
type = f"type={self.type}" if self.type else ""

return f"Argument({keyword}{value}{type})"

@noapidoc
@classmethod
def from_argument_list(cls, node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: FunctionCall) -> MultiExpression[Parent, Argument]:
Expand Down
103 changes: 90 additions & 13 deletions src/codegen/sdk/core/detached_symbols/function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,28 @@ def __init__(self, node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent:
args = [Argument(x, i, self) for i, x in enumerate(arg_list_node.named_children) if x.type != "comment"]
self._arg_list = Collection(arg_list_node, self.file_node_id, self.G, self, children=args)

def __repr__(self) -> str:
"""Custom string representation showing the function call chain structure.

Format: FunctionCall(name=current, pred=pred_name, succ=succ_name, base=base_name)

It will only print out predecessor, successor, and base that are of type FunctionCall. If it's a property, it will not be logged
"""
# Helper to safely get name

# Get names for each part
parts = [f"name='{self.name}'"]

if self.predecessor and isinstance(self.predecessor, FunctionCall):
parts.append(f"predecessor=FunctionCall(name='{self.predecessor.name}')")

if self.successor and isinstance(self.successor, FunctionCall):
parts.append(f"successor=FunctionCall(name='{self.successor.name}')")

parts.append(f"filepath='{self.file.filepath}'")

return f"FunctionCall({', '.join(parts)})"

@classmethod
def from_usage(cls, node: Editable[Parent], parent: Parent | None = None) -> Self | None:
"""Creates a FunctionCall object from an Editable instance that represents a function call.
Expand Down Expand Up @@ -210,9 +232,33 @@ def predecessor(self) -> FunctionCall[Parent] | None:
or if the predecessor is not a function call.
"""
# Recursively travel down the tree to find the previous function call (child nodes are previous calls)
return self.call_chain[-2] if len(self.call_chain) > 1 else None
name = self.get_name()
while name:
if isinstance(name, FunctionCall):
return name
elif isinstance(name, ChainedAttribute):
name = name.object
else:
break
return None

@property
@reader
def successor(self) -> FunctionCall[Parent] | None:
"""Returns the next function call in a function call chain.

Returns the next function call in a function call chain. This method is useful for traversing function call chains
to analyze or modify sequences of chained function calls.

Returns:
FunctionCall[Parent] | None: The next function call in the chain, or None if there is no successor
or if the successor is not a function call.
"""
# this will avoid parent function calls in tree-sitter that are NOT part of the chained calls
if not isinstance(self.parent, ChainedAttribute):
return None

# TODO: also define a successor?
return self.parent_of_type(FunctionCall)

@property
@noapidoc
Expand Down Expand Up @@ -581,6 +627,26 @@ def function_calls(self) -> list[FunctionCall]:
# calls.append(call)
return sort_editables(calls, dedupe=False)

@property
@reader
def attribute_chain(self) -> list[FunctionCall | Name]:
"""Returns a list of elements in the chainedAttribute that the function call belongs in.

Breaks down chained expressions into individual components in order of appearance.
For example: `a.b.c().d` -> [Name("a"), Name("b"), FunctionCall("c"), Name("d")]

Returns:
list[FunctionCall | Name]: List of Name nodes (property access) and FunctionCall nodes (method calls)
"""
if isinstance(self.get_name(), ChainedAttribute): # child is chainedAttribute. MEANING that this is likely in the middle or the last function call of a chained function call chain.
return self.get_name().attribute_chain
elif isinstance(
self.parent, ChainedAttribute
): # does not have child chainedAttribute, but parent is chainedAttribute. MEANING that this is likely the TOP function call of a chained function call chain.
return self.parent.attribute_chain
else: # this is a standalone function call
return [self]

@property
@noapidoc
def descendant_symbols(self) -> list[Importable]:
Expand All @@ -603,24 +669,35 @@ def register_api_call(self, url: str):
@property
@reader
def call_chain(self) -> list[FunctionCall]:
"""Returns a list of all function calls in this function call chain, including this call. Does not include calls made after this one."""
"""Returns a list of all function calls in this function call chain, including this call. Does not include calls made after this one."""
ret = []
name = self.get_name()
while name:
if isinstance(name, FunctionCall):
ret.extend(name.call_chain)
break
elif isinstance(name, ChainedAttribute):
name = name.object
else:
break

# backward traversal
curr = self
pred = curr.predecessor
while pred is not None and isinstance(pred, FunctionCall):
ret.insert(0, pred)
pred = pred.predecessor

ret.append(self)

# forward traversal
curr = self
succ = curr.successor
while succ is not None and isinstance(succ, FunctionCall):
ret.append(succ)
succ = succ.successor

return ret

@property
@reader
def base(self) -> Editable | None:
"""Returns the base object of this function call chain."""
"""Returns the base object of this function call chain.

Args:
Editable | None: The base object of this function call chain.
"""
name = self.get_name()
while isinstance(name, ChainedAttribute):
if isinstance(name.object, FunctionCall):
Expand Down
45 changes: 45 additions & 0 deletions src/codegen/sdk/core/expressions/chained_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
from codegen.shared.decorators.docs import apidoc, noapidoc

if TYPE_CHECKING:
from codegen.sdk.core.detached_symbols.function_call import FunctionCall
from codegen.sdk.core.interfaces.has_name import HasName
from codegen.sdk.core.interfaces.importable import Importable


Object = TypeVar("Object", bound="Chainable")
Attribute = TypeVar("Attribute", bound="Resolvable")
Parent = TypeVar("Parent", bound="Expression")
Expand Down Expand Up @@ -74,6 +76,49 @@ def attribute(self) -> Attribute:
"""
return self._attribute

@property
@reader
def attribute_chain(self) -> list["FunctionCall | Name"]:
"""Returns a list of elements in a chained attribute expression.

Breaks down chained expressions into individual components in order of appearance.
For example: `a.b.c().d` -> [Name("a"), Name("b"), FunctionCall("c"), Name("d")]

Returns:
list[FunctionCall | Name]: List of Name nodes (property access) and FunctionCall nodes (method calls)
"""
from codegen.sdk.core.detached_symbols.function_call import FunctionCall

ret = []
curr = self

# Traverse backwards in code (children of tree node)
while isinstance(curr, ChainedAttribute):
curr = curr.object

if isinstance(curr, FunctionCall):
ret.insert(0, curr)
curr = curr.get_name()
elif isinstance(curr, ChainedAttribute):
ret.insert(0, curr.attribute)

# This means that we have reached the base of the chain and the first item was an attribute (i.e a.b.c.func())
if isinstance(curr, Name) and not isinstance(curr.parent, FunctionCall):
ret.insert(0, curr)

curr = self

# Traversing forward in code (parents of tree node). Will add the current node as well
while isinstance(curr, ChainedAttribute) or isinstance(curr, FunctionCall):
if isinstance(curr, FunctionCall):
ret.append(curr)
elif isinstance(curr, ChainedAttribute) and not isinstance(curr.parent, FunctionCall):
ret.append(curr.attribute)

curr = curr.parent

return ret

@property
def object(self) -> Object:
"""Returns the object that contains the attribute being looked up.
Expand Down
4 changes: 4 additions & 0 deletions src/codegen/sdk/core/interfaces/editable.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@
"resolved_types",
"valid_symbol_names",
"valid_import_names",
"predecessor",
"successor",
"base",
"call_chain",
"code_block",
"parent_statement",
"symbol_usages",
Expand Down Expand Up @@ -404,7 +408,7 @@
"""
matches: list[Editable[Self]] = []
for node in self.extended_nodes:
matches.extend(node._find_string_literals(strings_to_match, fuzzy_match))

Check failure on line 411 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 1 to "extend" of "list" has incompatible type "Sequence[Editable[Editable[Any]]]"; expected "Iterable[Editable[Self]]" [arg-type]
return matches

@noapidoc
Expand Down Expand Up @@ -507,7 +511,7 @@
# Use search to find string
search_results = itertools.chain.from_iterable(map(self._search, map(re.escape, strings_to_match)))
if exact:
search_results = filter(lambda result: result.source in strings_to_match, search_results)

Check failure on line 514 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible types in assignment (expression has type "filter[Editable[Any]]", variable has type "chain[Editable[Any]]") [assignment]

# Combine and deduplicate results
return list(search_results)
Expand Down Expand Up @@ -897,9 +901,9 @@
if arguments is not None and any(identifier == arg.child_by_field_name("left") for arg in arguments.named_children):
continue

usages.append(self._parse_expression(identifier))

Check failure on line 904 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: "Sequence[Editable[Self]]" has no attribute "append" [attr-defined]

return usages

Check failure on line 906 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible return value type (got "Sequence[Editable[Self]]", expected "list[Editable[Any]]") [return-value]

@reader
def get_variable_usages(self, var_name: str, fuzzy_match: bool = False) -> Sequence[Editable[Self]]:
Expand Down Expand Up @@ -954,14 +958,14 @@

@commiter
@noapidoc
def _add_symbol_usages(self: HasName, identifiers: list[TSNode], usage_type: UsageKind, dest: HasName | None = None) -> None:

Check failure on line 961 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: The erased type of self "codegen.sdk.core.interfaces.has_name.HasName" is not a supertype of its class "codegen.sdk.core.interfaces.editable.Editable[Parent`1]" [misc]
from codegen.sdk.core.expressions import Name
from codegen.sdk.core.interfaces.resolvable import Resolvable

if dest is None:
dest = self
for x in identifiers:
if dep := self._parse_expression(x, default=Name):

Check failure on line 968 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: "HasName" has no attribute "_parse_expression" [attr-defined]
assert isinstance(dep, Resolvable)
dep._compute_dependencies(usage_type, dest)

Expand All @@ -971,7 +975,7 @@
id_types = self.G.node_classes.resolvables
# Skip identifiers that are part of a property
identifiers = find_all_descendants(self.ts_node, id_types, nested=False)
return self._add_symbol_usages(identifiers, usage_type, dest)

Check failure on line 978 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: Invalid self argument "Editable[Parent]" to attribute function "_add_symbol_usages" with type "Callable[[HasName, list[Node], UsageKind, HasName | None], None]" [misc]

@commiter
@noapidoc
Expand All @@ -980,7 +984,7 @@
id_types = self.G.node_classes.resolvables
# Skip identifiers that are part of a property
identifiers = find_all_descendants(child, id_types, nested=False)
return self._add_symbol_usages(identifiers, usage_type, dest)

Check failure on line 987 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: Invalid self argument "Editable[Parent]" to attribute function "_add_symbol_usages" with type "Callable[[HasName, list[Node], UsageKind, HasName | None], None]" [misc]

@noapidoc
def _log_parse(self, msg: str, *args, **kwargs):
Expand All @@ -1006,7 +1010,7 @@

@cached_property
@noapidoc
def github_url(self) -> str | None:

Check failure on line 1013 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: Missing return statement [return]
if self.file.github_url:
return self.file.github_url + f"#L{self.start_point[0] + 1}-L{self.end_point[0] + 1}"

Expand Down Expand Up @@ -1067,7 +1071,7 @@
dest = self
while dest and not isinstance(dest, Importable):
dest = dest.parent
return dest

Check failure on line 1074 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible return value type (got "Editable[Parent] | Importable[Any]", expected "Importable[Any]") [return-value]

@cached_property
@noapidoc
Expand Down
3 changes: 3 additions & 0 deletions src/codegen/sdk/core/symbol_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def __init__(self, file_node_id: NodeId, G: CodebaseGraph, parent: Parent, node:
node = children[0].ts_node
super().__init__(node, file_node_id, G, parent)

def __repr__(self) -> str:
return f"Collection({self.symbols})" if self.symbols is not None else super().__repr__()

def _init_children(self): ...

@repr_func # HACK
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,8 @@ def baz():

# Check call chain
assert c.call_chain == [a, b, c]
assert b.call_chain == [a, b]
assert a.call_chain == [a]
assert b.call_chain == [a, b, c]
assert a.call_chain == [a, b, c]

# Check base
assert c.base == a.get_name()
Expand All @@ -530,15 +530,103 @@ def baz():

# Check call chain
assert c.call_chain == [a, b, c]
assert b.call_chain == [a, b]
assert a.call_chain == [a]
assert b.call_chain == [a, b, c]
assert a.call_chain == [a, b, c]

# Check base
assert c.base.source == "x"
assert b.base.source == "x"
assert a.base.source == "x"


def test_function_call_chain_nested(tmpdir) -> None:
# language=python
content = """
def foo():
# Nested function calls - each call should be independent
a(b(c()))
"""
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
file = codebase.get_file("test.py")
foo = file.get_function("foo")
calls = foo.function_calls
assert len(calls) == 3
a = calls[0]
b = calls[1]
c = calls[2]

# Each call should be independent - no predecessors
assert a.predecessor is None
assert b.predecessor is None
assert c.predecessor is None

# No successors since they're nested, not chained
assert a.successor is None
assert b.successor is None
assert c.successor is None

# Call chain for each should only include itself
assert a.call_chain == [a]
assert b.call_chain == [b]
assert c.call_chain == [c]

# Verify source strings are correct
assert a.source == "a(b(c()))"
assert b.source == "b(c())"
assert c.source == "c()"


def test_function_call_chain_successor(tmpdir) -> None:
# language=python
content = """
def foo():
a().b().c()

def bat():
x.y.z.func()

def baz():
x.a().y.b().z.c()
"""
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
file = codebase.get_file("test.py")

# Check foo
foo = file.get_function("foo")
calls = foo.function_calls
assert len(calls) == 3
c = calls[0]
b = calls[1]
a = calls[2]

# Check successors
assert a.successor == b
assert b.successor == c
assert c.successor is None

# Check bat
bat = file.get_function("bat")
calls = bat.function_calls
assert len(calls) == 1
func = calls[0]

# No successor since it's a single function call
assert func.successor is None

# Check baz
baz = file.get_function("baz")
calls = baz.function_calls
assert len(calls) == 3
c = calls[0]
b = calls[1]
a = calls[2]

# Check successors
assert a.successor == b
assert b.successor == c
assert c.successor is None


def test_function_call_chain_hard(tmpdir) -> None:
# language=python
content = """
Expand Down
Loading