Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/codegen/sdk/core/detached_symbols/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ def __init__(self, node: TSNode, positional_idx: int, parent: FunctionCall) -> N
self._name_node = self._parse_expression(name_node, default=Name)
self._value_node = self._parse_expression(_value_node)

def __repr__(self) -> str:
keyword = f"keyword={self.name}, " if self.name else ""
value = f"value='{self.value}', " if self.value else ""
type = f"type={self.type}" if self.type else ""

return f"Argument({keyword}{value}{type})"

@noapidoc
@classmethod
def from_argument_list(cls, node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: FunctionCall) -> MultiExpression[Parent, Argument]:
Expand Down
98 changes: 85 additions & 13 deletions src/codegen/sdk/core/detached_symbols/function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,28 @@ def __init__(self, node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent:
args = [Argument(x, i, self) for i, x in enumerate(arg_list_node.named_children) if x.type != "comment"]
self._arg_list = Collection(arg_list_node, self.file_node_id, self.G, self, children=args)

def __repr__(self) -> str:
"""Custom string representation showing the function call chain structure.

Format: FunctionCall(name=current, pred=pred_name, succ=succ_name, base=base_name)

It will only print out predecessor, successor, and base that are of type FunctionCall. If it's a property, it will not be logged
"""
# Helper to safely get name

# Get names for each part
parts = [f"name='{self.name}'"]

if self.predecessor and isinstance(self.predecessor, FunctionCall):
parts.append(f"predecessor=FunctionCall(name='{self.predecessor.name}')")

if self.successor and isinstance(self.successor, FunctionCall):
parts.append(f"successor=FunctionCall(name='{self.successor.name}')")

parts.append(f"filepath='{self.file.filepath}'")

return f"FunctionCall({', '.join(parts)})"

@classmethod
def from_usage(cls, node: Editable[Parent], parent: Parent | None = None) -> Self | None:
"""Creates a FunctionCall object from an Editable instance that represents a function call.
Expand Down Expand Up @@ -210,9 +232,36 @@ def predecessor(self) -> FunctionCall[Parent] | None:
or if the predecessor is not a function call.
"""
# Recursively travel down the tree to find the previous function call (child nodes are previous calls)
return self.call_chain[-2] if len(self.call_chain) > 1 else None
name = self.get_name()
while name:
if isinstance(name, FunctionCall):
return name
elif isinstance(name, ChainedAttribute):
name = name.object
else:
break
return None

@property
@reader
def successor(self) -> FunctionCall[Parent] | Name | None:
"""Returns the next function call in a function call chain.

Returns the next function call in a function call chain. This method is useful for traversing function call chains
to analyze or modify sequences of chained function calls.

# TODO: also define a successor?
Args:
None

Returns:
FunctionCall[Parent] | Name | None: The next function call in the chain, or None if there is no successor
or if the successor is not a function call.
"""
# this will avoid parent function calls in tree-sitter that are NOT part of the chained calls
if not isinstance(self.parent, ChainedAttribute):
return None

return self.parent_of_type(FunctionCall)

@property
@noapidoc
Expand Down Expand Up @@ -579,6 +628,18 @@ def function_calls(self) -> list[FunctionCall]:
# calls.append(call)
return sort_editables(calls, dedupe=False)

@property
@reader
def attribute_chain(self) -> list[FunctionCall | Name]:
if isinstance(self.get_name(), ChainedAttribute): # child is chainedAttribute. MEANING that this is likely in the middle or the last function call of a chained function call chain.
return self.get_name().attribute_chain
elif isinstance(
self.parent, ChainedAttribute
): # does not have child chainedAttribute, but parent is chainedAttribute. MEANING that this is likely the TOP function call of a chained function call chain.
return self.parent.attribute_chain
else: # this is a standalone function call
return [self]

@property
@noapidoc
def descendant_symbols(self) -> list[Importable]:
Expand All @@ -601,24 +662,35 @@ def register_api_call(self, url: str):
@property
@reader
def call_chain(self) -> list[FunctionCall]:
"""Returns a list of all function calls in this function call chain, including this call. Does not include calls made after this one."""
"""Returns a list of all function calls in this function call chain, including this call. Does not include calls made after this one."""
ret = []
name = self.get_name()
while name:
if isinstance(name, FunctionCall):
ret.extend(name.call_chain)
break
elif isinstance(name, ChainedAttribute):
name = name.object
else:
break

# backward traversal
curr = self
pred = curr.predecessor
while pred is not None and isinstance(pred, FunctionCall):
ret.insert(0, pred)
pred = pred.predecessor

ret.append(self)

# forward traversal
curr = self
succ = curr.successor
while succ is not None and isinstance(succ, FunctionCall):
ret.append(succ)
succ = succ.successor

return ret

@property
@reader
def base(self) -> Editable | None:
"""Returns the base object of this function call chain."""
"""Returns the base object of this function call chain.

Args:
Editable | None: The base object of this function call chain.
"""
name = self.get_name()
while isinstance(name, ChainedAttribute):
if isinstance(name.object, FunctionCall):
Expand Down
37 changes: 37 additions & 0 deletions src/codegen/sdk/core/expressions/chained_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
from codegen.shared.decorators.docs import apidoc, noapidoc

if TYPE_CHECKING:
from codegen.sdk.core.detached_symbols.function_call import FunctionCall
from codegen.sdk.core.interfaces.has_name import HasName
from codegen.sdk.core.interfaces.importable import Importable


Object = TypeVar("Object", bound="Chainable")
Attribute = TypeVar("Attribute", bound="Resolvable")
Parent = TypeVar("Parent", bound="Expression")
Expand Down Expand Up @@ -74,6 +76,41 @@ def attribute(self) -> Attribute:
"""
return self._attribute

@property
@reader
def attribute_chain(self) -> list["FunctionCall | Name"]:
from codegen.sdk.core.detached_symbols.function_call import FunctionCall

ret = []
curr = self

# Traverse backwards in code (children of tree node)
while isinstance(curr, ChainedAttribute):
curr = curr.object

if isinstance(curr, FunctionCall):
ret.insert(0, curr)
curr = curr.get_name()
elif isinstance(curr, ChainedAttribute):
ret.insert(0, curr.attribute)

# This means that we have reached the base of the chain and the first item was an attribute (i.e a.b.c.func())
if isinstance(curr, Name) and not isinstance(curr.parent, FunctionCall):
ret.insert(0, curr)

curr = self

# Traversing forward in code (parents of tree node). Will add the current node as well
while isinstance(curr, ChainedAttribute) or isinstance(curr, FunctionCall):
if isinstance(curr, FunctionCall):
ret.append(curr)
elif isinstance(curr, ChainedAttribute) and not isinstance(curr.parent, FunctionCall):
ret.append(curr.attribute)

curr = curr.parent

return ret

@property
def object(self) -> Object:
"""Returns the object that contains the attribute being looked up.
Expand Down
4 changes: 4 additions & 0 deletions src/codegen/sdk/core/interfaces/editable.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@
"resolved_types",
"valid_symbol_names",
"valid_import_names",
"predecessor",
"successor",
"base",
"call_chain",
"code_block",
"parent_statement",
"symbol_usages",
Expand Down Expand Up @@ -404,7 +408,7 @@
"""
matches: list[Editable[Self]] = []
for node in self.extended_nodes:
matches.extend(node._find_string_literals(strings_to_match, fuzzy_match))

Check failure on line 411 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: Argument 1 to "extend" of "list" has incompatible type "Sequence[Editable[Editable[Any]]]"; expected "Iterable[Editable[Self]]" [arg-type]
return matches

@noapidoc
Expand Down Expand Up @@ -507,7 +511,7 @@
# Use search to find string
search_results = itertools.chain.from_iterable(map(self._search, map(re.escape, strings_to_match)))
if exact:
search_results = filter(lambda result: result.source in strings_to_match, search_results)

Check failure on line 514 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible types in assignment (expression has type "filter[Editable[Any]]", variable has type "chain[Editable[Any]]") [assignment]

# Combine and deduplicate results
return list(search_results)
Expand Down Expand Up @@ -897,9 +901,9 @@
if arguments is not None and any(identifier == arg.child_by_field_name("left") for arg in arguments.named_children):
continue

usages.append(self._parse_expression(identifier))

Check failure on line 904 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: "Sequence[Editable[Self]]" has no attribute "append" [attr-defined]

return usages

Check failure on line 906 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible return value type (got "Sequence[Editable[Self]]", expected "list[Editable[Any]]") [return-value]

@reader
def get_variable_usages(self, var_name: str, fuzzy_match: bool = False) -> Sequence[Editable[Self]]:
Expand Down Expand Up @@ -954,14 +958,14 @@

@commiter
@noapidoc
def _add_symbol_usages(self: HasName, identifiers: list[TSNode], usage_type: UsageKind, dest: HasName | None = None) -> None:

Check failure on line 961 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: The erased type of self "codegen.sdk.core.interfaces.has_name.HasName" is not a supertype of its class "codegen.sdk.core.interfaces.editable.Editable[Parent`1]" [misc]
from codegen.sdk.core.expressions import Name
from codegen.sdk.core.interfaces.resolvable import Resolvable

if dest is None:
dest = self
for x in identifiers:
if dep := self._parse_expression(x, default=Name):

Check failure on line 968 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: "HasName" has no attribute "_parse_expression" [attr-defined]
assert isinstance(dep, Resolvable)
dep._compute_dependencies(usage_type, dest)

Expand All @@ -971,7 +975,7 @@
id_types = self.G.node_classes.resolvables
# Skip identifiers that are part of a property
identifiers = find_all_descendants(self.ts_node, id_types, nested=False)
return self._add_symbol_usages(identifiers, usage_type, dest)

Check failure on line 978 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: Invalid self argument "Editable[Parent]" to attribute function "_add_symbol_usages" with type "Callable[[HasName, list[Node], UsageKind, HasName | None], None]" [misc]

@commiter
@noapidoc
Expand All @@ -980,7 +984,7 @@
id_types = self.G.node_classes.resolvables
# Skip identifiers that are part of a property
identifiers = find_all_descendants(child, id_types, nested=False)
return self._add_symbol_usages(identifiers, usage_type, dest)

Check failure on line 987 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: Invalid self argument "Editable[Parent]" to attribute function "_add_symbol_usages" with type "Callable[[HasName, list[Node], UsageKind, HasName | None], None]" [misc]

@noapidoc
def _log_parse(self, msg: str, *args, **kwargs):
Expand All @@ -1006,7 +1010,7 @@

@cached_property
@noapidoc
def github_url(self) -> str | None:

Check failure on line 1013 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: Missing return statement [return]
if self.file.github_url:
return self.file.github_url + f"#L{self.start_point[0] + 1}-L{self.end_point[0] + 1}"

Expand Down Expand Up @@ -1067,7 +1071,7 @@
dest = self
while dest and not isinstance(dest, Importable):
dest = dest.parent
return dest

Check failure on line 1074 in src/codegen/sdk/core/interfaces/editable.py

View workflow job for this annotation

GitHub Actions / mypy

error: Incompatible return value type (got "Editable[Parent] | Importable[Any]", expected "Importable[Any]") [return-value]

@cached_property
@noapidoc
Expand Down
3 changes: 3 additions & 0 deletions src/codegen/sdk/core/symbol_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def __init__(self, file_node_id: NodeId, G: CodebaseGraph, parent: Parent, node:
node = children[0].ts_node
super().__init__(node, file_node_id, G, parent)

def __repr__(self) -> str:
return f"Collection({self.symbols})" if self.symbols is not None else super().__repr__()

def _init_children(self): ...

@repr_func # HACK
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,8 @@ def baz():

# Check call chain
assert c.call_chain == [a, b, c]
assert b.call_chain == [a, b]
assert a.call_chain == [a]
assert b.call_chain == [a, b, c]
assert a.call_chain == [a, b, c]

# Check base
assert c.base == a.get_name()
Expand All @@ -530,15 +530,103 @@ def baz():

# Check call chain
assert c.call_chain == [a, b, c]
assert b.call_chain == [a, b]
assert a.call_chain == [a]
assert b.call_chain == [a, b, c]
assert a.call_chain == [a, b, c]

# Check base
assert c.base.source == "x"
assert b.base.source == "x"
assert a.base.source == "x"


def test_function_call_chain_nested(tmpdir) -> None:
# language=python
content = """
def foo():
# Nested function calls - each call should be independent
a(b(c()))
"""
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
file = codebase.get_file("test.py")
foo = file.get_function("foo")
calls = foo.function_calls
assert len(calls) == 3
a = calls[0]
b = calls[1]
c = calls[2]

# Each call should be independent - no predecessors
assert a.predecessor is None
assert b.predecessor is None
assert c.predecessor is None

# No successors since they're nested, not chained
assert a.successor is None
assert b.successor is None
assert c.successor is None

# Call chain for each should only include itself
assert a.call_chain == [a]
assert b.call_chain == [b]
assert c.call_chain == [c]

# Verify source strings are correct
assert a.source == "a(b(c()))"
assert b.source == "b(c())"
assert c.source == "c()"


def test_function_call_chain_successor(tmpdir) -> None:
# language=python
content = """
def foo():
a().b().c()

def bat():
x.y.z.func()

def baz():
x.a().y.b().z.c()
"""
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
file = codebase.get_file("test.py")

# Check foo
foo = file.get_function("foo")
calls = foo.function_calls
assert len(calls) == 3
c = calls[0]
b = calls[1]
a = calls[2]

# Check successors
assert a.successor == b
assert b.successor == c
assert c.successor is None

# Check bat
bat = file.get_function("bat")
calls = bat.function_calls
assert len(calls) == 1
func = calls[0]

# No successor since it's a single function call
assert func.successor is None

# Check baz
baz = file.get_function("baz")
calls = baz.function_calls
assert len(calls) == 3
c = calls[0]
b = calls[1]
a = calls[2]

# Check successors
assert a.successor == b
assert b.successor == c
assert c.successor is None


def test_function_call_chain_hard(tmpdir) -> None:
# language=python
content = """
Expand Down
Loading
Loading