diff --git a/src/codegen/sdk/codebase/node_classes/node_classes.py b/src/codegen/sdk/codebase/node_classes/node_classes.py index d2fa805ec..f439dc1c3 100644 --- a/src/codegen/sdk/codebase/node_classes/node_classes.py +++ b/src/codegen/sdk/codebase/node_classes/node_classes.py @@ -16,6 +16,7 @@ from codegen.sdk.core.file import SourceFile from codegen.sdk.core.function import Function from codegen.sdk.core.import_resolution import Import + from codegen.sdk.core.interfaces.editable import Editable from codegen.sdk.core.statements.comment import Comment from codegen.sdk.core.symbol import Symbol @@ -33,7 +34,7 @@ class NodeClasses: function_call_cls: type[FunctionCall] comment_cls: type[Comment] bool_conversion: dict[bool, str] - dynamic_import_parent_types: set[str] + dynamic_import_parent_types: set[type[Editable]] symbol_map: dict[str, type[Symbol]] = field(default_factory=dict) expression_map: dict[str, type[Expression]] = field(default_factory=dict) type_map: dict[str, type[Type] | dict[str, type[Type]]] = field(default_factory=dict) diff --git a/src/codegen/sdk/codebase/node_classes/py_node_classes.py b/src/codegen/sdk/codebase/node_classes/py_node_classes.py index 50cb13df4..7f2203f75 100644 --- a/src/codegen/sdk/codebase/node_classes/py_node_classes.py +++ b/src/codegen/sdk/codebase/node_classes/py_node_classes.py @@ -14,7 +14,13 @@ from codegen.sdk.core.expressions.subscript_expression import SubscriptExpression from codegen.sdk.core.expressions.unary_expression import UnaryExpression from codegen.sdk.core.expressions.unpack import Unpack +from codegen.sdk.core.function import Function from codegen.sdk.core.statements.comment import Comment +from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement +from codegen.sdk.core.statements.if_block_statement import IfBlockStatement +from codegen.sdk.core.statements.switch_statement import SwitchStatement +from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement +from codegen.sdk.core.statements.while_statement import WhileStatement from codegen.sdk.core.symbol_groups.dict import Dict from codegen.sdk.core.symbol_groups.list import List from codegen.sdk.core.symbol_groups.tuple import Tuple @@ -29,6 +35,8 @@ from codegen.sdk.python.expressions.string import PyString from codegen.sdk.python.expressions.union_type import PyUnionType from codegen.sdk.python.statements.import_statement import PyImportStatement +from codegen.sdk.python.statements.match_case import PyMatchCase +from codegen.sdk.python.statements.with_statement import WithStatement def parse_subscript(node: TSNode, file_node_id, ctx, parent): @@ -110,16 +118,13 @@ def parse_subscript(node: TSNode, file_node_id, ctx, parent): False: "False", }, dynamic_import_parent_types={ - "function_definition", - "if_statement", - "try_statement", - "with_statement", - "else_clause", - "for_statement", - "except_clause", - "while_statement", - "match_statement", - "case_clause", - "finally_clause", + Function, + IfBlockStatement, + TryCatchStatement, + WithStatement, + ForLoopStatement, + WhileStatement, + SwitchStatement, + PyMatchCase, }, ) diff --git a/src/codegen/sdk/codebase/node_classes/ts_node_classes.py b/src/codegen/sdk/codebase/node_classes/ts_node_classes.py index ac3690c22..e1d4515c2 100644 --- a/src/codegen/sdk/codebase/node_classes/ts_node_classes.py +++ b/src/codegen/sdk/codebase/node_classes/ts_node_classes.py @@ -15,7 +15,14 @@ from codegen.sdk.core.expressions.unary_expression import UnaryExpression from codegen.sdk.core.expressions.unpack import Unpack from codegen.sdk.core.expressions.value import Value +from codegen.sdk.core.function import Function from codegen.sdk.core.statements.comment import Comment +from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement +from codegen.sdk.core.statements.if_block_statement import IfBlockStatement +from codegen.sdk.core.statements.switch_case import SwitchCase +from codegen.sdk.core.statements.switch_statement import SwitchStatement +from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement +from codegen.sdk.core.statements.while_statement import WhileStatement from codegen.sdk.core.symbol_groups.list import List from codegen.sdk.core.symbol_groups.type_parameters import TypeParameters from codegen.sdk.typescript.class_definition import TSClass @@ -166,18 +173,12 @@ def parse_new(node: TSNode, *args): False: "false", }, dynamic_import_parent_types={ - "function_declaration", - "method_definition", - "arrow_function", - "if_statement", - "try_statement", - "else_clause", - "catch_clause", - "finally_clause", - "while_statement", - "for_statement", - "do_statement", - "switch_case", - "switch_statement", + Function, + IfBlockStatement, + TryCatchStatement, + ForLoopStatement, + WhileStatement, + SwitchStatement, + SwitchCase, }, ) diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py index 7c93a386e..b282942ae 100644 --- a/src/codegen/sdk/core/file.py +++ b/src/codegen/sdk/core/file.py @@ -459,15 +459,10 @@ def parse(self, ctx: CodebaseContext) -> None: self.code_block = self._parse_code_block(self.ts_node) self.code_block.parse() - self._parse_imports() # We need to clear the valid symbol/import names before we start resolving exports since these can be outdated. self.invalidate() sort_editables(self._nodes) - @abstractmethod - @commiter - def _parse_imports(self) -> None: ... - @noapidoc @commiter def remove_internal_edges(self) -> None: diff --git a/src/codegen/sdk/core/import_resolution.py b/src/codegen/sdk/core/import_resolution.py index f0afb6e7a..a824d8a72 100644 --- a/src/codegen/sdk/core/import_resolution.py +++ b/src/codegen/sdk/core/import_resolution.py @@ -428,15 +428,7 @@ def my_function(): bool: True if the import is dynamic (within a control flow or scope block), False if it's a top-level import. """ - curr = self.ts_node - - # always traverses upto the module level - while curr: - if curr.type in self.ctx.node_classes.dynamic_import_parent_types: - return True - curr = curr.parent - - return False + return self.parent_of_types(self.ctx.node_classes.dynamic_import_parent_types) is not None #################################################################################################################### # MANIPULATIONS diff --git a/src/codegen/sdk/core/interfaces/editable.py b/src/codegen/sdk/core/interfaces/editable.py index 57c6a6ba1..28da7a6c3 100644 --- a/src/codegen/sdk/core/interfaces/editable.py +++ b/src/codegen/sdk/core/interfaces/editable.py @@ -823,7 +823,7 @@ def children_by_field_types(self, field_types: str | Iterable[str]) -> Generator @reader @noapidoc def child_by_field_types(self, field_types: str | Iterable[str]) -> Expression[Self] | None: - """Get child by field types.""" + """Get child by fiexld types.""" return next(self.children_by_field_types(field_types), None) @property @@ -1097,6 +1097,14 @@ def parent_of_type(self, type: type[T]) -> T | None: return self.parent.parent_of_type(type) return None + def parent_of_types(self, types: set[type[T]]) -> T | None: + """Find the first ancestor of the node of the given type. Does not return itself""" + if self.parent and any(isinstance(self.parent, t) for t in types): + return self.parent + if self.parent is not self and self.parent is not None: + return self.parent.parent_of_types(types) + return None + @reader def ancestors(self, type: type[T]) -> list[T]: """Find all ancestors of the node of the given type. Does not return itself""" diff --git a/src/codegen/sdk/core/parser.py b/src/codegen/sdk/core/parser.py index cfa5f64f3..5ea44f27e 100644 --- a/src/codegen/sdk/core/parser.py +++ b/src/codegen/sdk/core/parser.py @@ -8,7 +8,7 @@ from codegen.sdk.core.expressions.placeholder_type import PlaceholderType from codegen.sdk.core.expressions.value import Value from codegen.sdk.core.statements.symbol_statement import SymbolStatement -from codegen.sdk.utils import find_first_function_descendant +from codegen.sdk.utils import find_first_function_descendant, find_import_node if TYPE_CHECKING: from tree_sitter import Node as TSNode @@ -108,6 +108,7 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC from codegen.sdk.typescript.statements.comment import TSComment from codegen.sdk.typescript.statements.for_loop_statement import TSForLoopStatement from codegen.sdk.typescript.statements.if_block_statement import TSIfBlockStatement + from codegen.sdk.typescript.statements.import_statement import TSImportStatement from codegen.sdk.typescript.statements.labeled_statement import TSLabeledStatement from codegen.sdk.typescript.statements.switch_statement import TSSwitchStatement from codegen.sdk.typescript.statements.try_catch_statement import TSTryCatchStatement @@ -117,11 +118,13 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC if node.type in self.expressions or node.type == "expression_statement": return [ExpressionStatement(node, file_node_id, ctx, parent, 0, expression_node=node)] + for child in node.named_children: # =====[ Functions + Methods ]===== if child.type in _VALID_TYPE_NAMES: statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements))) - + elif child.type == "import_statement": + statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements))) # =====[ Classes ]===== elif child.type in ("class_declaration", "abstract_class_declaration"): statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements))) @@ -132,7 +135,10 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC # =====[ Type Alias Declarations ]===== elif child.type == "type_alias_declaration": - statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements))) + if import_node := find_import_node(child): + statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements), source_node=import_node)) + else: + statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements))) # =====[ Enum Declarations ]===== elif child.type == "enum_declaration": @@ -142,11 +148,6 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC elif child.type == "export_statement" or child.text.decode("utf-8") == "export *;": statements.append(ExportStatement(child, file_node_id, ctx, parent, len(statements))) - # =====[ Imports ] ===== - elif child.type == "import_statement": - # statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements))) - pass # Temporarily opting to identify all imports using find_all_descendants - # =====[ Non-symbol statements ] ===== elif child.type == "comment": statements.append(TSComment.from_code_block(child, parent, pos=len(statements))) @@ -167,6 +168,8 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC elif child.type in ["lexical_declaration", "variable_declaration"]: if function_node := find_first_function_descendant(child): statements.append(SymbolStatement(child, file_node_id, ctx, parent, len(statements), function_node)) + elif import_node := find_import_node(child): + statements.append(TSImportStatement(child, file_node_id, ctx, parent, len(statements), source_node=import_node)) else: statements.append( TSAssignmentStatement.from_assignment( @@ -176,6 +179,10 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC elif child.type in ["public_field_definition", "property_signature", "enum_assignment"]: statements.append(TSAttribute(child, file_node_id, ctx, parent, pos=len(statements))) elif child.type == "expression_statement": + if import_node := find_import_node(child): + statements.append(TSImportStatement(child, file_node_id, ctx, parent, pos=len(statements), source_node=import_node)) + continue + for var in child.named_children: if var.type == "string": statements.append(TSComment.from_code_block(var, parent, pos=len(statements))) @@ -185,7 +192,6 @@ def parse_ts_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC statements.append(ExpressionStatement(child, file_node_id, ctx, parent, pos=len(statements), expression_node=var)) elif child.type in self.expressions: statements.append(ExpressionStatement(child, file_node_id, ctx, parent, len(statements), expression_node=child)) - else: self.log("Couldn't parse statement with type: %s", child.type) statements.append(Statement.from_code_block(child, parent, pos=len(statements))) @@ -204,6 +210,7 @@ def parse_py_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC from codegen.sdk.python.statements.comment import PyComment from codegen.sdk.python.statements.for_loop_statement import PyForLoopStatement from codegen.sdk.python.statements.if_block_statement import PyIfBlockStatement + from codegen.sdk.python.statements.import_statement import PyImportStatement from codegen.sdk.python.statements.match_statement import PyMatchStatement from codegen.sdk.python.statements.pass_statement import PyPassStatement from codegen.sdk.python.statements.try_catch_statement import PyTryCatchStatement @@ -237,9 +244,7 @@ def parse_py_statements(self, node: TSNode, file_node_id: NodeId, ctx: CodebaseC # =====[ Imports ] ===== elif child.type in ["import_statement", "import_from_statement", "future_import_statement"]: - # statements.append(PyImportStatement(child, file_node_id, ctx, parent, len(statements))) - pass # Temporarily opting to identify all imports using find_all_descendants - + statements.append(PyImportStatement(child, file_node_id, ctx, parent, len(statements))) # =====[ Non-symbol statements ] ===== elif child.type == "comment": statements.append(PyComment.from_code_block(child, parent, pos=len(statements))) diff --git a/src/codegen/sdk/extensions/utils.pyi b/src/codegen/sdk/extensions/utils.pyi index 952bdd0ef..c75f0ce03 100644 --- a/src/codegen/sdk/extensions/utils.pyi +++ b/src/codegen/sdk/extensions/utils.pyi @@ -13,6 +13,7 @@ def find_all_descendants( type_names: Iterable[str] | str, max_depth: int | None = None, nested: bool = True, + stop_at_first: str | None = None, ) -> list[TSNode]: ... def find_line_start_and_end_nodes(node: TSNode) -> list[tuple[TSNode, TSNode]]: """Returns a list of tuples of the start and end nodes of each line in the node""" diff --git a/src/codegen/sdk/extensions/utils.pyx b/src/codegen/sdk/extensions/utils.pyx index 992db3663..73da6ce59 100644 --- a/src/codegen/sdk/extensions/utils.pyx +++ b/src/codegen/sdk/extensions/utils.pyx @@ -31,7 +31,7 @@ def get_all_identifiers(node: TSNode) -> list[TSNode]: return sorted(dict.fromkeys(identifiers), key=lambda x: x.start_byte) -def find_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_depth: int | None = None, nested: bool = True) -> list[TSNode]: +def find_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_depth: int | None = None, nested: bool = True, stop_at_first: str | None = None) -> list[TSNode]: if isinstance(type_names, str): type_names = [type_names] descendants = [] @@ -45,6 +45,9 @@ def find_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_dept if not nested and current_node != node: return + if stop_at_first and current_node.type == stop_at_first: + return + for child in current_node.children: traverse(child, depth + 1) diff --git a/src/codegen/sdk/python/file.py b/src/codegen/sdk/python/file.py index cec7dc4d3..3c92feaef 100644 --- a/src/codegen/sdk/python/file.py +++ b/src/codegen/sdk/python/file.py @@ -2,11 +2,11 @@ from typing import TYPE_CHECKING -from codegen.sdk.core.autocommit import commiter, reader, writer +from codegen.sdk.core.autocommit import reader, writer from codegen.sdk.core.file import SourceFile from codegen.sdk.core.interface import Interface from codegen.sdk.enums import ImportType -from codegen.sdk.extensions.utils import cached_property, iter_all_descendants +from codegen.sdk.extensions.utils import cached_property from codegen.sdk.python import PyAssignment from codegen.sdk.python.class_definition import PyClass from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock @@ -15,7 +15,6 @@ from codegen.sdk.python.import_resolution import PyImport from codegen.sdk.python.interfaces.has_block import PyHasBlock from codegen.sdk.python.statements.attribute import PyAttribute -from codegen.sdk.python.statements.import_statement import PyImportStatement from codegen.shared.decorators.docs import noapidoc, py_apidoc from codegen.shared.enums.programming_language import ProgrammingLanguage @@ -59,12 +58,6 @@ def symbol_can_be_added(self, symbol: PySymbol) -> bool: """ return True - @noapidoc - @commiter - def _parse_imports(self) -> None: - for import_node in iter_all_descendants(self.ts_node, frozenset({"import_statement", "import_from_statement", "future_import_statement"})): - PyImportStatement(import_node, self.node_id, self.ctx, self.code_block, 0) - #################################################################################################################### # GETTERS #################################################################################################################### diff --git a/src/codegen/sdk/typescript/file.py b/src/codegen/sdk/typescript/file.py index 810a708f9..4c937292d 100644 --- a/src/codegen/sdk/typescript/file.py +++ b/src/codegen/sdk/typescript/file.py @@ -3,7 +3,7 @@ import os from typing import TYPE_CHECKING -from codegen.sdk.core.autocommit import commiter, mover, reader, writer +from codegen.sdk.core.autocommit import mover, reader, writer from codegen.sdk.core.file import SourceFile from codegen.sdk.core.interfaces.exportable import Exportable from codegen.sdk.enums import ImportType, NodeType, SymbolType @@ -18,8 +18,7 @@ from codegen.sdk.typescript.interface import TSInterface from codegen.sdk.typescript.interfaces.has_block import TSHasBlock from codegen.sdk.typescript.namespace import TSNamespace -from codegen.sdk.typescript.statements.import_statement import TSImportStatement -from codegen.sdk.utils import calculate_base_path, find_all_descendants +from codegen.sdk.utils import calculate_base_path from codegen.shared.decorators.docs import noapidoc, ts_apidoc from codegen.shared.enums.programming_language import ProgrammingLanguage @@ -228,18 +227,6 @@ def add_export_to_symbol(self, symbol: TSSymbol) -> None: # TODO: this should be in symbol.py class. Rename as `add_export` symbol.add_keyword("export") - @noapidoc - @commiter - def _parse_imports(self) -> None: - import_nodes = find_all_descendants(self.ts_node, {"import_statement", "call_expression"}) - for import_node in import_nodes: - if import_node.type == "import_statement": - TSImportStatement(import_node, self.node_id, self.ctx, self.code_block, 0) - elif import_node.type == "call_expression": - function = import_node.child_by_field_name("function") - if function.type == "import" or (function.type == "identifier" and function.text.decode("utf-8") == "require"): - TSImportStatement(import_node, self.node_id, self.ctx, self.code_block, 0) - @writer def remove_unused_exports(self) -> None: """Removes unused exports from the file. diff --git a/src/codegen/sdk/typescript/import_resolution.py b/src/codegen/sdk/typescript/import_resolution.py index 5954c5304..090c63b24 100644 --- a/src/codegen/sdk/typescript/import_resolution.py +++ b/src/codegen/sdk/typescript/import_resolution.py @@ -451,7 +451,10 @@ def from_dynamic_import_statement(cls, import_call_node: TSNode, module_node: TS return imports # If import statement is a variable declaration, capture the variable scoping keyword (const, let, var, etc) - statement_node = import_statement_node.parent if import_statement_node.type in ["variable_declarator", "assignment_expression"] else import_statement_node + if import_statement_node.type == "lexical_declaration": + statement_node = import_statement_node + else: + statement_node = import_statement_node.parent if import_statement_node.type in ["variable_declarator", "assignment_expression"] else import_statement_node # ==== [ Named dynamic import ] ==== if name_node.type == "property_identifier": diff --git a/src/codegen/sdk/typescript/statements/import_statement.py b/src/codegen/sdk/typescript/statements/import_statement.py index cde889d8e..a54070588 100644 --- a/src/codegen/sdk/typescript/statements/import_statement.py +++ b/src/codegen/sdk/typescript/statements/import_statement.py @@ -35,9 +35,9 @@ def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, imports = [] if ts_node.type == "import_statement": imports.extend(TSImport.from_import_statement(ts_node, file_node_id, ctx, self)) - elif ts_node.type == "call_expression": - import_call_node = ts_node.child_by_field_name("function") - arguments = ts_node.child_by_field_name("arguments") + elif ts_node.type in ["call_expression", "lexical_declaration", "expression_statement", "type_alias_declaration"]: + import_call_node = source_node.child_by_field_name("function") + arguments = source_node.child_by_field_name("arguments") imports.extend(TSImport.from_dynamic_import_statement(import_call_node, arguments, file_node_id, ctx, self)) elif ts_node.type == "export_statement": imports.extend(TSImport.from_export_statement(source_node, file_node_id, ctx, self)) diff --git a/src/codegen/sdk/utils.py b/src/codegen/sdk/utils.py index 4049ae118..913782f1f 100644 --- a/src/codegen/sdk/utils.py +++ b/src/codegen/sdk/utils.py @@ -87,6 +87,45 @@ def find_first_function_descendant(node: TSNode) -> TSNode: return find_first_descendant(node=node, type_names=type_names, max_depth=2) +def find_import_node(node: TSNode) -> TSNode | None: + """Get the import node from a node that may contain an import. + Returns None if the node does not contain an import. + + Returns: + TSNode | None: The import_statement or call_expression node if it's an import, None otherwise + """ + # Static imports + if node.type == "import_statement": + return node + + # Dynamic imports and requires can be either: + # 1. Inside expression_statement -> call_expression + # 2. Direct call_expression + + # we only parse imports inside expressions and variable declarations + + # import_nodes = [_node for _node in find_all_descendants(node, ["call_expression", "statement_block"], nested=False) if _node.type == "call_expression"] + + if member_expression := find_first_descendant(node, ["member_expression"]): + # there may be multiple call expressions (for cases such as import(a).then(module => module).then(module => module) + descendants = find_all_descendants(member_expression, ["call_expression"], stop_at_first="statement_block") + if descendants: + import_node = descendants[-1] + else: + # this means this is NOT a dynamic import() + return None + else: + import_node = find_first_descendant(node, ["call_expression"]) + + # thus we only consider the deepest one + if import_node: + function = import_node.child_by_field_name("function") + if function and (function.type == "import" or (function.type == "identifier" and function.text.decode("utf-8") == "require")): + return import_node + + return None + + def find_index(target: TSNode, siblings: list[TSNode]) -> int: """Returns the index of the target node in the list of siblings, or -1 if not found. Recursive implementation.""" if target in siblings: diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_is_dynamic.py b/tests/unit/codegen/sdk/python/import_resolution/test_is_dynamic.py index f8625ee7c..705a84c5a 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_is_dynamic.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_is_dynamic.py @@ -1,4 +1,10 @@ +from codegen import Codebase from codegen.sdk.codebase.factory.get_session import get_codebase_session +from codegen.sdk.core.function import Function +from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement +from codegen.sdk.core.statements.if_block_statement import IfBlockStatement +from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement +from codegen.sdk.python.statements.with_statement import WithStatement from codegen.shared.enums.programming_language import ProgrammingLanguage @@ -225,3 +231,83 @@ def test_py_import_is_dynamic_in_match_case(tmpdir): assert imports[1].is_dynamic # dynamic_in_case import assert imports[2].is_dynamic # from x import y assert imports[3].is_dynamic # another_dynamic import + + +def test_parent_of_types_function(): + codebase = Codebase.from_string( + """ + def hello(): + import foo + """, + language="python", + ) + import_stmt = codebase.files[0].imports[0] + assert import_stmt.parent_of_types({Function}) is not None + assert import_stmt.parent_of_types({IfBlockStatement}) is None + + +def test_parent_of_types_if_statement(): + codebase = Codebase.from_string( + """ + if True: + import foo + """, + language="python", + ) + import_stmt = codebase.files[0].imports[0] + assert import_stmt.parent_of_types({IfBlockStatement}) is not None + assert import_stmt.parent_of_types({Function}) is None + + +def test_parent_of_types_multiple(): + codebase = Codebase.from_string( + """ + def hello(): + if True: + import foo + """, + language="python", + ) + import_stmt = codebase.files[0].imports[0] + # Should find both Function and IfBlockStatement parents + assert import_stmt.parent_of_types({Function, IfBlockStatement}) is not None + # Should find closest parent first (IfBlockStatement) + assert isinstance(import_stmt.parent_of_types({Function, IfBlockStatement}), IfBlockStatement) + + +def test_parent_of_types_try_catch(): + codebase = Codebase.from_string( + """ + try: + import foo + except: + pass + """, + language="python", + ) + import_stmt = codebase.files[0].imports[0] + assert import_stmt.parent_of_types({TryCatchStatement}) is not None + + +def test_parent_of_types_with(): + codebase = Codebase.from_string( + """ + with open('file.txt') as f: + import foo + """, + language="python", + ) + import_stmt = codebase.files[0].imports[0] + assert import_stmt.parent_of_types({WithStatement}) is not None + + +def test_parent_of_types_for_loop(): + codebase = Codebase.from_string( + """ + for i in range(10): + import foo + """, + language="python", + ) + import_stmt = codebase.files[0].imports[0] + assert import_stmt.parent_of_types({ForLoopStatement}) is not None diff --git a/tests/unit/codegen/sdk/typescript/import_resolution/test_import_resolution_resolve_import.py b/tests/unit/codegen/sdk/typescript/import_resolution/test_import_resolution_resolve_import.py index 8855e531b..5cbfcc7f6 100644 --- a/tests/unit/codegen/sdk/typescript/import_resolution/test_import_resolution_resolve_import.py +++ b/tests/unit/codegen/sdk/typescript/import_resolution/test_import_resolution_resolve_import.py @@ -1,5 +1,7 @@ from typing import TYPE_CHECKING +import pytest + from codegen.sdk.codebase.factory.get_session import get_codebase_session from codegen.sdk.core.dataclasses.usage import UsageType from codegen.sdk.enums import ImportType @@ -257,6 +259,7 @@ def test_dynamic_import_type_alias(tmpdir) -> None: assert file.get_import("RequiredDefaultType").resolved_symbol == m_file.get_interface("SomeInterface") +@pytest.mark.xfail(reason="Currently dynamic imports not supported for type parameters") def test_dynamic_import_function_param_type(tmpdir) -> None: # language=typescript content = """ diff --git a/tests/unit/codegen/sdk/typescript/import_resolution/test_is_dynamic.py b/tests/unit/codegen/sdk/typescript/import_resolution/test_is_dynamic.py index 3c221bfe4..2984953d1 100644 --- a/tests/unit/codegen/sdk/typescript/import_resolution/test_is_dynamic.py +++ b/tests/unit/codegen/sdk/typescript/import_resolution/test_is_dynamic.py @@ -1,4 +1,9 @@ +from codegen import Codebase from codegen.sdk.codebase.factory.get_session import get_codebase_session +from codegen.sdk.core.function import Function +from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement +from codegen.sdk.core.statements.if_block_statement import IfBlockStatement +from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement from codegen.shared.enums.programming_language import ProgrammingLanguage @@ -35,6 +40,7 @@ class MyComponent { async decoratedMethod() { const module = await import('./decorated'); } + } """ with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: @@ -53,7 +59,9 @@ def test_ts_import_is_dynamic_in_arrow_function(tmpdir): const MyComponent = () => { const loadModule = async () => { - const module = await import('./lazy'); + const module = await import('./lazy').then((module) => { + return module.default; + }); }; return