diff --git a/src/codegen/sdk/codebase/node_classes/node_classes.py b/src/codegen/sdk/codebase/node_classes/node_classes.py
index d2fa805ec..f439dc1c3 100644
--- a/src/codegen/sdk/codebase/node_classes/node_classes.py
+++ b/src/codegen/sdk/codebase/node_classes/node_classes.py
@@ -16,6 +16,7 @@
from codegen.sdk.core.file import SourceFile
from codegen.sdk.core.function import Function
from codegen.sdk.core.import_resolution import Import
+ from codegen.sdk.core.interfaces.editable import Editable
from codegen.sdk.core.statements.comment import Comment
from codegen.sdk.core.symbol import Symbol
@@ -33,7 +34,7 @@ class NodeClasses:
function_call_cls: type[FunctionCall]
comment_cls: type[Comment]
bool_conversion: dict[bool, str]
- dynamic_import_parent_types: set[str]
+ dynamic_import_parent_types: set[type[Editable]]
symbol_map: dict[str, type[Symbol]] = field(default_factory=dict)
expression_map: dict[str, type[Expression]] = field(default_factory=dict)
type_map: dict[str, type[Type] | dict[str, type[Type]]] = field(default_factory=dict)
diff --git a/src/codegen/sdk/codebase/node_classes/py_node_classes.py b/src/codegen/sdk/codebase/node_classes/py_node_classes.py
index 50cb13df4..7f2203f75 100644
--- a/src/codegen/sdk/codebase/node_classes/py_node_classes.py
+++ b/src/codegen/sdk/codebase/node_classes/py_node_classes.py
@@ -14,7 +14,13 @@
from codegen.sdk.core.expressions.subscript_expression import SubscriptExpression
from codegen.sdk.core.expressions.unary_expression import UnaryExpression
from codegen.sdk.core.expressions.unpack import Unpack
+from codegen.sdk.core.function import Function
from codegen.sdk.core.statements.comment import Comment
+from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement
+from codegen.sdk.core.statements.if_block_statement import IfBlockStatement
+from codegen.sdk.core.statements.switch_statement import SwitchStatement
+from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement
+from codegen.sdk.core.statements.while_statement import WhileStatement
from codegen.sdk.core.symbol_groups.dict import Dict
from codegen.sdk.core.symbol_groups.list import List
from codegen.sdk.core.symbol_groups.tuple import Tuple
@@ -29,6 +35,8 @@
from codegen.sdk.python.expressions.string import PyString
from codegen.sdk.python.expressions.union_type import PyUnionType
from codegen.sdk.python.statements.import_statement import PyImportStatement
+from codegen.sdk.python.statements.match_case import PyMatchCase
+from codegen.sdk.python.statements.with_statement import WithStatement
def parse_subscript(node: TSNode, file_node_id, ctx, parent):
@@ -110,16 +118,13 @@ def parse_subscript(node: TSNode, file_node_id, ctx, parent):
False: "False",
},
dynamic_import_parent_types={
- "function_definition",
- "if_statement",
- "try_statement",
- "with_statement",
- "else_clause",
- "for_statement",
- "except_clause",
- "while_statement",
- "match_statement",
- "case_clause",
- "finally_clause",
+ Function,
+ IfBlockStatement,
+ TryCatchStatement,
+ WithStatement,
+ ForLoopStatement,
+ WhileStatement,
+ SwitchStatement,
+ PyMatchCase,
},
)
diff --git a/src/codegen/sdk/codebase/node_classes/ts_node_classes.py b/src/codegen/sdk/codebase/node_classes/ts_node_classes.py
index ac3690c22..e1d4515c2 100644
--- a/src/codegen/sdk/codebase/node_classes/ts_node_classes.py
+++ b/src/codegen/sdk/codebase/node_classes/ts_node_classes.py
@@ -15,7 +15,14 @@
from codegen.sdk.core.expressions.unary_expression import UnaryExpression
from codegen.sdk.core.expressions.unpack import Unpack
from codegen.sdk.core.expressions.value import Value
+from codegen.sdk.core.function import Function
from codegen.sdk.core.statements.comment import Comment
+from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement
+from codegen.sdk.core.statements.if_block_statement import IfBlockStatement
+from codegen.sdk.core.statements.switch_case import SwitchCase
+from codegen.sdk.core.statements.switch_statement import SwitchStatement
+from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement
+from codegen.sdk.core.statements.while_statement import WhileStatement
from codegen.sdk.core.symbol_groups.list import List
from codegen.sdk.core.symbol_groups.type_parameters import TypeParameters
from codegen.sdk.typescript.class_definition import TSClass
@@ -166,18 +173,12 @@ def parse_new(node: TSNode, *args):
False: "false",
},
dynamic_import_parent_types={
- "function_declaration",
- "method_definition",
- "arrow_function",
- "if_statement",
- "try_statement",
- "else_clause",
- "catch_clause",
- "finally_clause",
- "while_statement",
- "for_statement",
- "do_statement",
- "switch_case",
- "switch_statement",
+ Function,
+ IfBlockStatement,
+ TryCatchStatement,
+ ForLoopStatement,
+ WhileStatement,
+ SwitchStatement,
+ SwitchCase,
},
)
diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py
index 7c93a386e..b282942ae 100644
--- a/src/codegen/sdk/core/file.py
+++ b/src/codegen/sdk/core/file.py
@@ -459,15 +459,10 @@ def parse(self, ctx: CodebaseContext) -> None:
self.code_block = self._parse_code_block(self.ts_node)
self.code_block.parse()
- self._parse_imports()
# We need to clear the valid symbol/import names before we start resolving exports since these can be outdated.
self.invalidate()
sort_editables(self._nodes)
- @abstractmethod
- @commiter
- def _parse_imports(self) -> None: ...
-
@noapidoc
@commiter
def remove_internal_edges(self) -> None:
diff --git a/src/codegen/sdk/core/import_resolution.py b/src/codegen/sdk/core/import_resolution.py
index f0afb6e7a..a824d8a72 100644
--- a/src/codegen/sdk/core/import_resolution.py
+++ b/src/codegen/sdk/core/import_resolution.py
@@ -428,15 +428,7 @@ def my_function():
bool: True if the import is dynamic (within a control flow or scope block),
False if it's a top-level import.
"""
- curr = self.ts_node
-
- # always traverses upto the module level
- while curr:
- if curr.type in self.ctx.node_classes.dynamic_import_parent_types:
- return True
- curr = curr.parent
-
- return False
+ return self.parent_of_types(self.ctx.node_classes.dynamic_import_parent_types) is not None
####################################################################################################################
# MANIPULATIONS
diff --git a/src/codegen/sdk/core/interfaces/editable.py b/src/codegen/sdk/core/interfaces/editable.py
index 57c6a6ba1..28da7a6c3 100644
--- a/src/codegen/sdk/core/interfaces/editable.py
+++ b/src/codegen/sdk/core/interfaces/editable.py
@@ -823,7 +823,7 @@ def children_by_field_types(self, field_types: str | Iterable[str]) -> Generator
@reader
@noapidoc
def child_by_field_types(self, field_types: str | Iterable[str]) -> Expression[Self] | None:
- """Get child by field types."""
+ """Get child by fiexld types."""
return next(self.children_by_field_types(field_types), None)
@property
@@ -1097,6 +1097,14 @@ def parent_of_type(self, type: type[T]) -> T | None:
return self.parent.parent_of_type(type)
return None
+ def parent_of_types(self, types: set[type[T]]) -> T | None:
+ """Find the first ancestor of the node of the given type. Does not return itself"""
+ if self.parent and any(isinstance(self.parent, t) for t in types):
+ return self.parent
+ if self.parent is not self and self.parent is not None:
+ return self.parent.parent_of_types(types)
+ return None
+
@reader
def ancestors(self, type: type[T]) -> list[T]:
"""Find all ancestors of the node of the given type. Does not return itself"""
diff --git a/src/codegen/sdk/core/parser.py b/src/codegen/sdk/core/parser.py
index cfa5f64f3..5ea44f27e 100644
--- a/src/codegen/sdk/core/parser.py
+++ b/src/codegen/sdk/core/parser.py
@@ -8,7 +8,7 @@
from codegen.sdk.core.expressions.placeholder_type import PlaceholderType
from codegen.sdk.core.expressions.value import Value
from codegen.sdk.core.statements.symbol_statement import SymbolStatement
-from codegen.sdk.utils import find_first_function_descendant
+from codegen.sdk.utils import find_first_function_descendant, find_import_node
if TYPE_CHECKING:
from tree_sitter import Node as TSNode
@@ -108,6 +108,7 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
from codegen.sdk.typescript.statements.comment import TSComment
from codegen.sdk.typescript.statements.for_loop_statement import TSForLoopStatement
from codegen.sdk.typescript.statements.if_block_statement import TSIfBlockStatement
+ from codegen.sdk.typescript.statements.import_statement import TSImportStatement
from codegen.sdk.typescript.statements.labeled_statement import TSLabeledStatement
from codegen.sdk.typescript.statements.switch_statement import TSSwitchStatement
from codegen.sdk.typescript.statements.try_catch_statement import TSTryCatchStatement
@@ -117,11 +118,13 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
if node.type in self.expressions or node.type == "expression_statement":
return [ExpressionStatement(node, file_node_id, ctx, parent, 0, expression_node=node)]
+
for child in node.named_children:
# =====[ Functions + Methods ]=====
if child.type in _VALID_TYPE_NAMES:
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements)))
-
+ elif child.type == "import_statement":
+ statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements)))
# =====[ Classes ]=====
elif child.type in ("class_declaration", "abstract_class_declaration"):
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements)))
@@ -132,7 +135,10 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
# =====[ Type Alias Declarations ]=====
elif child.type == "type_alias_declaration":
- statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements)))
+ if import_node := find_import_node(child):
+ statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements), source_node=import_node))
+ else:
+ statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements)))
# =====[ Enum Declarations ]=====
elif child.type == "enum_declaration":
@@ -142,11 +148,6 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
elif child.type == "export_statement" or child.text.decode("utf-8") == "export *;":
statements.append(ExportStatement(child, file_node_id, ctx, parent, len(statements)))
- # =====[ Imports ] =====
- elif child.type == "import_statement":
- # statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements)))
- pass # Temporarily opting to identify all imports using find_all_descendants
-
# =====[ Non-symbol statements ] =====
elif child.type == "comment":
statements.append(TSComment.from_code_block(child, parent, pos=len(statements)))
@@ -167,6 +168,8 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
elif child.type in ["lexical_declaration", "variable_declaration"]:
if function_node := find_first_function_descendant(child):
statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements), function_node))
+ elif import_node := find_import_node(child):
+ statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements), source_node=import_node))
else:
statements.append(
TSAssignmentStatement.from_assignment(
@@ -176,6 +179,10 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
elif child.type in ["public_field_definition", "property_signature", "enum_assignment"]:
statements.append(TSAttribute(child, file_node_id, ctx, parent, pos=len(statements)))
elif child.type == "expression_statement":
+ if import_node := find_import_node(child):
+ statements.append(TSImportStatement(child, file_node_id, ctx, parent, pos=len(statements), source_node=import_node))
+ continue
+
for var in child.named_children:
if var.type == "string":
statements.append(TSComment.from_code_block(var, parent, pos=len(statements)))
@@ -185,7 +192,6 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
statements.append(ExpressionStatement(child, file_node_id, ctx, parent, pos=len(statements), expression_node=var))
elif child.type in self.expressions:
statements.append(ExpressionStatement(child, file_node_id, ctx, parent, len(statements), expression_node=child))
-
else:
self.log("Couldn't parse statement with type: %s", child.type)
statements.append(Statement.from_code_block(child, parent, pos=len(statements)))
@@ -204,6 +210,7 @@ def parse_py_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
from codegen.sdk.python.statements.comment import PyComment
from codegen.sdk.python.statements.for_loop_statement import PyForLoopStatement
from codegen.sdk.python.statements.if_block_statement import PyIfBlockStatement
+ from codegen.sdk.python.statements.import_statement import PyImportStatement
from codegen.sdk.python.statements.match_statement import PyMatchStatement
from codegen.sdk.python.statements.pass_statement import PyPassStatement
from codegen.sdk.python.statements.try_catch_statement import PyTryCatchStatement
@@ -237,9 +244,7 @@ def parse_py_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC
# =====[ Imports ] =====
elif child.type in ["import_statement", "import_from_statement", "future_import_statement"]:
- # statements.append(PyImportStatement(child, file_node_id, ctx, parent, len(statements)))
- pass # Temporarily opting to identify all imports using find_all_descendants
-
+ statements.append(PyImportStatement(child, file_node_id, ctx, parent, len(statements)))
# =====[ Non-symbol statements ] =====
elif child.type == "comment":
statements.append(PyComment.from_code_block(child, parent, pos=len(statements)))
diff --git a/src/codegen/sdk/extensions/utils.pyi b/src/codegen/sdk/extensions/utils.pyi
index 952bdd0ef..c75f0ce03 100644
--- a/src/codegen/sdk/extensions/utils.pyi
+++ b/src/codegen/sdk/extensions/utils.pyi
@@ -13,6 +13,7 @@ def find_all_descendants(
type_names: Iterable[str] | str,
max_depth: int | None = None,
nested: bool = True,
+ stop_at_first: str | None = None,
) -> list[TSNode]: ...
def find_line_start_and_end_nodes(node: TSNode) -> list[tuple[TSNode, TSNode]]:
"""Returns a list of tuples of the start and end nodes of each line in the node"""
diff --git a/src/codegen/sdk/extensions/utils.pyx b/src/codegen/sdk/extensions/utils.pyx
index 992db3663..73da6ce59 100644
--- a/src/codegen/sdk/extensions/utils.pyx
+++ b/src/codegen/sdk/extensions/utils.pyx
@@ -31,7 +31,7 @@ def get_all_identifiers(node: TSNode) -> list[TSNode]:
return sorted(dict.fromkeys(identifiers), key=lambda x: x.start_byte)
-def find_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_depth: int | None = None, nested: bool = True) -> list[TSNode]:
+def find_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_depth: int | None = None, nested: bool = True, stop_at_first: str | None = None) -> list[TSNode]:
if isinstance(type_names, str):
type_names = [type_names]
descendants = []
@@ -45,6 +45,9 @@ def find_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_dept
if not nested and current_node != node:
return
+ if stop_at_first and current_node.type == stop_at_first:
+ return
+
for child in current_node.children:
traverse(child, depth + 1)
diff --git a/src/codegen/sdk/python/file.py b/src/codegen/sdk/python/file.py
index cec7dc4d3..3c92feaef 100644
--- a/src/codegen/sdk/python/file.py
+++ b/src/codegen/sdk/python/file.py
@@ -2,11 +2,11 @@
from typing import TYPE_CHECKING
-from codegen.sdk.core.autocommit import commiter, reader, writer
+from codegen.sdk.core.autocommit import reader, writer
from codegen.sdk.core.file import SourceFile
from codegen.sdk.core.interface import Interface
from codegen.sdk.enums import ImportType
-from codegen.sdk.extensions.utils import cached_property, iter_all_descendants
+from codegen.sdk.extensions.utils import cached_property
from codegen.sdk.python import PyAssignment
from codegen.sdk.python.class_definition import PyClass
from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock
@@ -15,7 +15,6 @@
from codegen.sdk.python.import_resolution import PyImport
from codegen.sdk.python.interfaces.has_block import PyHasBlock
from codegen.sdk.python.statements.attribute import PyAttribute
-from codegen.sdk.python.statements.import_statement import PyImportStatement
from codegen.shared.decorators.docs import noapidoc, py_apidoc
from codegen.shared.enums.programming_language import ProgrammingLanguage
@@ -59,12 +58,6 @@ def symbol_can_be_added(self, symbol: PySymbol) -> bool:
"""
return True
- @noapidoc
- @commiter
- def _parse_imports(self) -> None:
- for import_node in iter_all_descendants(self.ts_node, frozenset({"import_statement", "import_from_statement", "future_import_statement"})):
- PyImportStatement(import_node, self.node_id, self.ctx, self.code_block, 0)
-
####################################################################################################################
# GETTERS
####################################################################################################################
diff --git a/src/codegen/sdk/typescript/file.py b/src/codegen/sdk/typescript/file.py
index 810a708f9..4c937292d 100644
--- a/src/codegen/sdk/typescript/file.py
+++ b/src/codegen/sdk/typescript/file.py
@@ -3,7 +3,7 @@
import os
from typing import TYPE_CHECKING
-from codegen.sdk.core.autocommit import commiter, mover, reader, writer
+from codegen.sdk.core.autocommit import mover, reader, writer
from codegen.sdk.core.file import SourceFile
from codegen.sdk.core.interfaces.exportable import Exportable
from codegen.sdk.enums import ImportType, NodeType, SymbolType
@@ -18,8 +18,7 @@
from codegen.sdk.typescript.interface import TSInterface
from codegen.sdk.typescript.interfaces.has_block import TSHasBlock
from codegen.sdk.typescript.namespace import TSNamespace
-from codegen.sdk.typescript.statements.import_statement import TSImportStatement
-from codegen.sdk.utils import calculate_base_path, find_all_descendants
+from codegen.sdk.utils import calculate_base_path
from codegen.shared.decorators.docs import noapidoc, ts_apidoc
from codegen.shared.enums.programming_language import ProgrammingLanguage
@@ -228,18 +227,6 @@ def add_export_to_symbol(self, symbol: TSSymbol) -> None:
# TODO: this should be in symbol.py class. Rename as `add_export`
symbol.add_keyword("export")
- @noapidoc
- @commiter
- def _parse_imports(self) -> None:
- import_nodes = find_all_descendants(self.ts_node, {"import_statement", "call_expression"})
- for import_node in import_nodes:
- if import_node.type == "import_statement":
- TSImportStatement(import_node, self.node_id, self.ctx, self.code_block, 0)
- elif import_node.type == "call_expression":
- function = import_node.child_by_field_name("function")
- if function.type == "import" or (function.type == "identifier" and function.text.decode("utf-8") == "require"):
- TSImportStatement(import_node, self.node_id, self.ctx, self.code_block, 0)
-
@writer
def remove_unused_exports(self) -> None:
"""Removes unused exports from the file.
diff --git a/src/codegen/sdk/typescript/import_resolution.py b/src/codegen/sdk/typescript/import_resolution.py
index 5954c5304..090c63b24 100644
--- a/src/codegen/sdk/typescript/import_resolution.py
+++ b/src/codegen/sdk/typescript/import_resolution.py
@@ -451,7 +451,10 @@ def from_dynamic_import_statement(cls, import_call_node: TSNode, module_node: TS
return imports
# If import statement is a variable declaration, capture the variable scoping keyword (const, let, var, etc)
- statement_node = import_statement_node.parent if import_statement_node.type in ["variable_declarator", "assignment_expression"] else import_statement_node
+ if import_statement_node.type == "lexical_declaration":
+ statement_node = import_statement_node
+ else:
+ statement_node = import_statement_node.parent if import_statement_node.type in ["variable_declarator", "assignment_expression"] else import_statement_node
# ==== [ Named dynamic import ] ====
if name_node.type == "property_identifier":
diff --git a/src/codegen/sdk/typescript/statements/import_statement.py b/src/codegen/sdk/typescript/statements/import_statement.py
index cde889d8e..a54070588 100644
--- a/src/codegen/sdk/typescript/statements/import_statement.py
+++ b/src/codegen/sdk/typescript/statements/import_statement.py
@@ -35,9 +35,9 @@ def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext,
imports = []
if ts_node.type == "import_statement":
imports.extend(TSImport.from_import_statement(ts_node, file_node_id, ctx, self))
- elif ts_node.type == "call_expression":
- import_call_node = ts_node.child_by_field_name("function")
- arguments = ts_node.child_by_field_name("arguments")
+ elif ts_node.type in ["call_expression", "lexical_declaration", "expression_statement", "type_alias_declaration"]:
+ import_call_node = source_node.child_by_field_name("function")
+ arguments = source_node.child_by_field_name("arguments")
imports.extend(TSImport.from_dynamic_import_statement(import_call_node, arguments, file_node_id, ctx, self))
elif ts_node.type == "export_statement":
imports.extend(TSImport.from_export_statement(source_node, file_node_id, ctx, self))
diff --git a/src/codegen/sdk/utils.py b/src/codegen/sdk/utils.py
index 4049ae118..913782f1f 100644
--- a/src/codegen/sdk/utils.py
+++ b/src/codegen/sdk/utils.py
@@ -87,6 +87,45 @@ def find_first_function_descendant(node: TSNode) -> TSNode:
return find_first_descendant(node=node, type_names=type_names, max_depth=2)
+def find_import_node(node: TSNode) -> TSNode | None:
+ """Get the import node from a node that may contain an import.
+ Returns None if the node does not contain an import.
+
+ Returns:
+ TSNode | None: The import_statement or call_expression node if it's an import, None otherwise
+ """
+ # Static imports
+ if node.type == "import_statement":
+ return node
+
+ # Dynamic imports and requires can be either:
+ # 1. Inside expression_statement -> call_expression
+ # 2. Direct call_expression
+
+ # we only parse imports inside expressions and variable declarations
+
+ # import_nodes = [_node for _node in find_all_descendants(node, ["call_expression", "statement_block"], nested=False) if _node.type == "call_expression"]
+
+ if member_expression := find_first_descendant(node, ["member_expression"]):
+ # there may be multiple call expressions (for cases such as import(a).then(module => module).then(module => module)
+ descendants = find_all_descendants(member_expression, ["call_expression"], stop_at_first="statement_block")
+ if descendants:
+ import_node = descendants[-1]
+ else:
+ # this means this is NOT a dynamic import()
+ return None
+ else:
+ import_node = find_first_descendant(node, ["call_expression"])
+
+ # thus we only consider the deepest one
+ if import_node:
+ function = import_node.child_by_field_name("function")
+ if function and (function.type == "import" or (function.type == "identifier" and function.text.decode("utf-8") == "require")):
+ return import_node
+
+ return None
+
+
def find_index(target: TSNode, siblings: list[TSNode]) -> int:
"""Returns the index of the target node in the list of siblings, or -1 if not found. Recursive implementation."""
if target in siblings:
diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_is_dynamic.py b/tests/unit/codegen/sdk/python/import_resolution/test_is_dynamic.py
index f8625ee7c..705a84c5a 100644
--- a/tests/unit/codegen/sdk/python/import_resolution/test_is_dynamic.py
+++ b/tests/unit/codegen/sdk/python/import_resolution/test_is_dynamic.py
@@ -1,4 +1,10 @@
+from codegen import Codebase
from codegen.sdk.codebase.factory.get_session import get_codebase_session
+from codegen.sdk.core.function import Function
+from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement
+from codegen.sdk.core.statements.if_block_statement import IfBlockStatement
+from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement
+from codegen.sdk.python.statements.with_statement import WithStatement
from codegen.shared.enums.programming_language import ProgrammingLanguage
@@ -225,3 +231,83 @@ def test_py_import_is_dynamic_in_match_case(tmpdir):
assert imports[1].is_dynamic # dynamic_in_case import
assert imports[2].is_dynamic # from x import y
assert imports[3].is_dynamic # another_dynamic import
+
+
+def test_parent_of_types_function():
+ codebase = Codebase.from_string(
+ """
+ def hello():
+ import foo
+ """,
+ language="python",
+ )
+ import_stmt = codebase.files[0].imports[0]
+ assert import_stmt.parent_of_types({Function}) is not None
+ assert import_stmt.parent_of_types({IfBlockStatement}) is None
+
+
+def test_parent_of_types_if_statement():
+ codebase = Codebase.from_string(
+ """
+ if True:
+ import foo
+ """,
+ language="python",
+ )
+ import_stmt = codebase.files[0].imports[0]
+ assert import_stmt.parent_of_types({IfBlockStatement}) is not None
+ assert import_stmt.parent_of_types({Function}) is None
+
+
+def test_parent_of_types_multiple():
+ codebase = Codebase.from_string(
+ """
+ def hello():
+ if True:
+ import foo
+ """,
+ language="python",
+ )
+ import_stmt = codebase.files[0].imports[0]
+ # Should find both Function and IfBlockStatement parents
+ assert import_stmt.parent_of_types({Function, IfBlockStatement}) is not None
+ # Should find closest parent first (IfBlockStatement)
+ assert isinstance(import_stmt.parent_of_types({Function, IfBlockStatement}), IfBlockStatement)
+
+
+def test_parent_of_types_try_catch():
+ codebase = Codebase.from_string(
+ """
+ try:
+ import foo
+ except:
+ pass
+ """,
+ language="python",
+ )
+ import_stmt = codebase.files[0].imports[0]
+ assert import_stmt.parent_of_types({TryCatchStatement}) is not None
+
+
+def test_parent_of_types_with():
+ codebase = Codebase.from_string(
+ """
+ with open('file.txt') as f:
+ import foo
+ """,
+ language="python",
+ )
+ import_stmt = codebase.files[0].imports[0]
+ assert import_stmt.parent_of_types({WithStatement}) is not None
+
+
+def test_parent_of_types_for_loop():
+ codebase = Codebase.from_string(
+ """
+ for i in range(10):
+ import foo
+ """,
+ language="python",
+ )
+ import_stmt = codebase.files[0].imports[0]
+ assert import_stmt.parent_of_types({ForLoopStatement}) is not None
diff --git a/tests/unit/codegen/sdk/typescript/import_resolution/test_import_resolution_resolve_import.py b/tests/unit/codegen/sdk/typescript/import_resolution/test_import_resolution_resolve_import.py
index 8855e531b..5cbfcc7f6 100644
--- a/tests/unit/codegen/sdk/typescript/import_resolution/test_import_resolution_resolve_import.py
+++ b/tests/unit/codegen/sdk/typescript/import_resolution/test_import_resolution_resolve_import.py
@@ -1,5 +1,7 @@
from typing import TYPE_CHECKING
+import pytest
+
from codegen.sdk.codebase.factory.get_session import get_codebase_session
from codegen.sdk.core.dataclasses.usage import UsageType
from codegen.sdk.enums import ImportType
@@ -257,6 +259,7 @@ def test_dynamic_import_type_alias(tmpdir) -> None:
assert file.get_import("RequiredDefaultType").resolved_symbol == m_file.get_interface("SomeInterface")
+@pytest.mark.xfail(reason="Currently dynamic imports not supported for type parameters")
def test_dynamic_import_function_param_type(tmpdir) -> None:
# language=typescript
content = """
diff --git a/tests/unit/codegen/sdk/typescript/import_resolution/test_is_dynamic.py b/tests/unit/codegen/sdk/typescript/import_resolution/test_is_dynamic.py
index 3c221bfe4..2984953d1 100644
--- a/tests/unit/codegen/sdk/typescript/import_resolution/test_is_dynamic.py
+++ b/tests/unit/codegen/sdk/typescript/import_resolution/test_is_dynamic.py
@@ -1,4 +1,9 @@
+from codegen import Codebase
from codegen.sdk.codebase.factory.get_session import get_codebase_session
+from codegen.sdk.core.function import Function
+from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement
+from codegen.sdk.core.statements.if_block_statement import IfBlockStatement
+from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement
from codegen.shared.enums.programming_language import ProgrammingLanguage
@@ -35,6 +40,7 @@ class MyComponent {
async decoratedMethod() {
const module = await import('./decorated');
}
+
}
"""
with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase:
@@ -53,7 +59,9 @@ def test_ts_import_is_dynamic_in_arrow_function(tmpdir):
const MyComponent = () => {
const loadModule = async () => {
- const module = await import('./lazy');
+ const module = await import('./lazy').then((module) => {
+ return module.default;
+ });
};
return ;
@@ -190,51 +198,73 @@ def test_ts_import_is_dynamic_in_for_statement(tmpdir):
assert imports[1].is_dynamic # dynamic import in for loop
-def test_ts_import_is_dynamic_in_do_statement(tmpdir):
- # language=typescript
- content = """
- import { shouldContinue } from './utils';
-
- do {
- const module = await import('./dynamic-module');
- module.process();
- } while (shouldContinue());
- """
- with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase:
- file = codebase.get_file("test.ts")
- imports = file.imports
-
- assert not imports[0].is_dynamic # static import
- assert imports[1].is_dynamic # dynamic import in do-while loop
-
-
-def test_ts_import_is_dynamic_in_switch_statement(tmpdir):
- # language=typescript
- content = """
- import { getFeatureFlag } from './utils';
-
- switch (getFeatureFlag()) {
- case 'feature1':
- import('./feature1').then(module => {
- module.init();
- });
- break;
- case 'feature2':
- import('./feature2').then(module => {
- module.init();
- });
- break;
- default:
- import('./default').then(module => {
- module.init();
- });
- }
- """
- with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase:
- file = codebase.get_file("test.ts")
- imports = file.imports
-
- assert not imports[0].is_dynamic # static import
- assert imports[1].is_dynamic # dynamic import in first case
- assert imports[2].is_dynamic # dynamic import in second case
- assert imports[3].is_dynamic # dynamic import in default case
+def test_parent_of_types_function():
+ codebase = Codebase.from_string(
+ """
+ function hello() {
+ import { foo } from 'bar';
+ }
+ """,
+ language="typescript",
+ )
+ import_stmt = codebase.files[0].imports[0]
+ assert import_stmt.parent_of_types({Function}) is not None
+ assert import_stmt.parent_of_types({IfBlockStatement}) is None
+
+
+def test_parent_of_types_if_statement():
+ codebase = Codebase.from_string(
+ """
+ if (true) {
+ import { foo } from 'bar';
+ }
+ """,
+ language="typescript",
+ )
+ import_stmt = codebase.files[0].imports[0]
+ assert import_stmt.parent_of_types({IfBlockStatement}) is not None
+ assert import_stmt.parent_of_types({Function}) is None
+
+
+def test_parent_of_types_multiple():
+ codebase = Codebase.from_string(
+ """
+ function hello() {
+ if (true) {
+ import { foo } from 'bar';
+ }
+ }
+ """,
+ language="typescript",
+ )
+ import_stmt = codebase.files[0].imports[0]
+ # Should find both Function and IfBlockStatement parents
+ assert import_stmt.parent_of_types({Function, IfBlockStatement}) is not None
+ # Should find closest parent first (IfBlockStatement)
+ assert isinstance(import_stmt.parent_of_types({Function, IfBlockStatement}), IfBlockStatement)
+
+
+def test_parent_of_types_try_catch():
+ codebase = Codebase.from_string(
+ """
+ try {
+ import { foo } from 'bar';
+ } catch (e) {}
+ """,
+ language="typescript",
+ )
+ import_stmt = codebase.files[0].imports[0]
+ assert import_stmt.parent_of_types({TryCatchStatement}) is not None
+
+
+def test_parent_of_types_for_loop():
+ codebase = Codebase.from_string(
+ """
+ for (let i = 0; i < 10; i++) {
+ import { foo } from 'bar';
+ }
+ """,
+ language="typescript",
+ )
+ import_stmt = codebase.files[0].imports[0]
+ assert import_stmt.parent_of_types({ForLoopStatement}) is not None
diff --git a/tests/unit/skills/snapshots/test_skills/test_all_example_skills/call_graph_filter-PYTHON-case-0/call_graph_filter_unnamed.json b/tests/unit/skills/snapshots/test_skills/test_all_example_skills/call_graph_filter-PYTHON-case-0/call_graph_filter_unnamed.json
index 622bdb147..7eae0a269 100644
--- a/tests/unit/skills/snapshots/test_skills/test_all_example_skills/call_graph_filter-PYTHON-case-0/call_graph_filter_unnamed.json
+++ b/tests/unit/skills/snapshots/test_skills/test_all_example_skills/call_graph_filter-PYTHON-case-0/call_graph_filter_unnamed.json
@@ -22,7 +22,7 @@
],
"file_path": "path/to/file1.py",
"symbol_name": "PyFunction",
- "id": 17
+ "id": 18
},
{
"name": "MyClass.get",
@@ -157,7 +157,7 @@
],
"file_path": "path/to/file.py",
"symbol_name": "PyClass",
- "source": 17,
+ "source": 18,
"target": "range= filepath='path/to/file1.py'"
},
{
@@ -177,7 +177,7 @@
],
"file_path": "path/to/file.py",
"symbol_name": "PyClass",
- "source": 17,
+ "source": 18,
"target": "range= filepath='path/to/file1.py'"
},
{
@@ -197,7 +197,7 @@
],
"file_path": "path/to/file.py",
"symbol_name": "PyClass",
- "source": 17,
+ "source": 18,
"target": "range= filepath='path/to/file1.py'"
},
{
@@ -217,7 +217,7 @@
],
"file_path": "path/to/file.py",
"symbol_name": "PyClass",
- "source": 17,
+ "source": 18,
"target": "range= filepath='path/to/file1.py'"
},
{
diff --git a/tests/unit/skills/snapshots/test_skills/test_all_example_skills/dead_code-PYTHON-case-0/dead_code_unnamed.json b/tests/unit/skills/snapshots/test_skills/test_all_example_skills/dead_code-PYTHON-case-0/dead_code_unnamed.json
index 926ee5c12..48290f7dd 100644
--- a/tests/unit/skills/snapshots/test_skills/test_all_example_skills/dead_code-PYTHON-case-0/dead_code_unnamed.json
+++ b/tests/unit/skills/snapshots/test_skills/test_all_example_skills/dead_code-PYTHON-case-0/dead_code_unnamed.json
@@ -41,7 +41,7 @@
],
"file_path": "decorated_functions.py",
"symbol_name": "PyFunction",
- "id": 3
+ "id": 4
},
{
"name": "unused_function",