diff --git a/codegen-examples/examples/snapshot_event_handler/pr_tasks.py b/codegen-examples/examples/snapshot_event_handler/pr_tasks.py index 153e303f0..0a412ec30 100644 --- a/codegen-examples/examples/snapshot_event_handler/pr_tasks.py +++ b/codegen-examples/examples/snapshot_event_handler/pr_tasks.py @@ -1,8 +1,6 @@ import logging -from codegen.agents.code_agent import CodeAgent from codegen.extensions.github.types.pull_request import PullRequestLabeledEvent -from codegen.extensions.langchain.tools import GithubCreatePRCommentTool, GithubCreatePRReviewCommentTool, GithubViewPRTool from codegen.sdk.core.codebase import Codebase logging.basicConfig(level=logging.INFO, force=True) @@ -11,25 +9,22 @@ def lint_for_dev_import_violations(codebase: Codebase, event: PullRequestLabeledEvent): # Next.js codemod to detect imports of the react-dev-overlay module in production code - + patch, commit_shas, modified_symbols = codebase.get_modified_symbols_in_pr(event.pull_request.number) modified_files = set(commit_shas.keys()) from codegen.sdk.core.statements.if_block_statement import IfBlockStatement - DIR_NAME = 'packages/next/src/client/components/react-dev-overlay' + DIR_NAME = "packages/next/src/client/components/react-dev-overlay" directory = codebase.get_directory(DIR_NAME) violations = [] - false_operators = ["!=", "!=="] true_operators = ["===", "=="] - - def is_valid_block_expression(if_block: IfBlockStatement) -> bool: """Check if the if block has a valid environment check condition. - + Valid conditions are: - process.env.NODE_ENV !== 'production' - process.env.NODE_ENV != 'production' @@ -38,46 +33,43 @@ def is_valid_block_expression(if_block: IfBlockStatement) -> bool: """ if not if_block.is_if_statement: return False - + condition = if_block.condition # Get the operator without any whitespace operator = condition.operator[-1].source - + # Check for non-production conditions if operator in false_operators and condition.source == f"process.env.NODE_ENV {operator} 'production'": return True - + # Check for explicit development conditions if operator in true_operators and condition.source == f"process.env.NODE_ENV {operator} 'development'": return True - - return False + return False def process_else_block_expression(else_block: IfBlockStatement) -> bool: """Check if the else block is valid by checking its parent if block. - + Valid when the parent if block checks for production environment: - if (process.env.NODE_ENV === 'production') { ... } else { } - if (process.env.NODE_ENV == 'production') { ... } else { } """ if not else_block.is_else_statement: return False - + main_if = else_block._main_if_block if not main_if or not main_if.condition: return False - + condition = main_if.condition operator = condition.operator[-1].source - + # Valid if the main if block checks for production return operator in true_operators and condition.source == f"process.env.NODE_ENV {operator} 'production'" - for file in directory.files(recursive=True): for imp in file.inbound_imports: - if imp.file.filepath not in modified_files: # skip if the import is not in the pull request's modified files continue @@ -85,13 +77,13 @@ def process_else_block_expression(else_block: IfBlockStatement) -> bool: if directory.dirpath in imp.file.filepath: # "✅ Valid import" if the import is within the target directory continue - + parent_if_block = imp.parent_of_type(IfBlockStatement) - + # Check if import is in a valid environment check block if_block_valid = parent_if_block and is_valid_block_expression(parent_if_block) else_block_valid = parent_if_block and process_else_block_expression(parent_if_block) - + # Skip if the import is properly guarded by environment checks if if_block_valid or else_block_valid: # "✅ Valid import" these are guarded by non prod checks @@ -102,7 +94,6 @@ def process_else_block_expression(else_block: IfBlockStatement) -> bool: violations.append(violation) logger.info(f"Found violation: {violation}") - if violations: # Comment on PR with violations review_attention_message = "## Dev Import Violations Found\n\n" @@ -111,4 +102,4 @@ def process_else_block_expression(else_block: IfBlockStatement) -> bool: review_attention_message += "\n\nPlease ensure that development imports are not imported in production code." # Create PR comment with the formatted message - codebase._op.create_pr_comment(event.pull_request.number, review_attention_message) \ No newline at end of file + codebase._op.create_pr_comment(event.pull_request.number, review_attention_message) diff --git a/src/codegen/sdk/core/expressions/name.py b/src/codegen/sdk/core/expressions/name.py index df5ef6872..78554a5b7 100644 --- a/src/codegen/sdk/core/expressions/name.py +++ b/src/codegen/sdk/core/expressions/name.py @@ -50,6 +50,61 @@ def rename_if_matching(self, old: str, new: str): if self.source == old: self.edit(new) + @noapidoc + def _resolve_conditionals(self, conditional_parent: ConditionalBlock, name: str, original_resolved): + """Resolves name references within conditional blocks by traversing the conditional chain. + + This method handles name resolution within conditional blocks (like if/elif/else statements) by: + 1. Finding the appropriate search boundary based on the conditional block's position + 2. Handling "fake" conditionals by traversing up the conditional chain + 3. Yielding resolved names while respecting conditional block boundaries + + Args: + conditional_parent (ConditionalBlock): The parent conditional block containing the name reference + name (str): The name being resolved + original_resolved: The originally resolved symbol that triggered this resolution + + Yields: + Symbol | Import | WildcardImport: Resolved symbols found within the conditional blocks + + Notes: + - A "fake" conditional is one where is_true_conditional() returns False + - The search_limit ensures we don't resolve names that appear after our target + - The method stops when it either: + a) Reaches the top of the conditional chain + b) Returns to the original conditional block + c) Can't find any more resolutions + """ + search_limit = conditional_parent.start_byte_for_condition_block + if search_limit >= original_resolved.start_byte: + search_limit = original_resolved.start_byte - 1 + if not conditional_parent.is_true_conditional(original_resolved): + # If it's a fake conditional we must skip any potential enveloping conditionals + def get_top_of_fake_chain(conditional, resolved, search_limit=0): + if skip_fake := conditional.parent_of_type(ConditionalBlock): + if skip_fake.is_true_conditional(resolved): + return skip_fake.start_byte_for_condition_block + search_limit = skip_fake.start_byte_for_condition_block + return get_top_of_fake_chain(skip_fake, conditional, search_limit) + return search_limit + + if search_limit := get_top_of_fake_chain(conditional_parent, original_resolved): + search_limit = search_limit + else: + return + + original_conditional = conditional_parent + while next_resolved := next(conditional_parent.resolve_name(name, start_byte=search_limit, strict=False), None): + yield next_resolved + next_conditional = next_resolved.parent_of_type(ConditionalBlock) + if not next_conditional or next_conditional == original_conditional: + return + search_limit = next_conditional.start_byte_for_condition_block + if next_conditional and not next_conditional.is_true_conditional(original_resolved): + pass + if search_limit >= next_resolved.start_byte: + search_limit = next_resolved.start_byte - 1 + @noapidoc @reader def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator["Symbol | Import | WildcardImport"]: @@ -60,14 +115,8 @@ def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = return if hasattr(resolved_name, "parent") and (conditional_parent := resolved_name.parent_of_type(ConditionalBlock)): - top_of_conditional = conditional_parent.start_byte if self.parent_of_type(ConditionalBlock) == conditional_parent: # Use in the same block, should only depend on the inside of the block return - for other_conditional in conditional_parent.other_possible_blocks: - if cond_name := next(other_conditional.resolve_name(name, start_byte=other_conditional.end_byte_for_condition_block), None): - if cond_name.start_byte >= other_conditional.start_byte: - yield cond_name - top_of_conditional = min(top_of_conditional, other_conditional.start_byte) - yield from self.resolve_name(name, top_of_conditional, strict=False) + yield from self._resolve_conditionals(conditional_parent=conditional_parent, name=name, original_resolved=resolved_name) diff --git a/src/codegen/sdk/core/interfaces/conditional_block.py b/src/codegen/sdk/core/interfaces/conditional_block.py index 2689badc3..1f9de6e19 100644 --- a/src/codegen/sdk/core/interfaces/conditional_block.py +++ b/src/codegen/sdk/core/interfaces/conditional_block.py @@ -6,7 +6,9 @@ class ConditionalBlock(Statement, ABC): - """An interface for any code block that might not be executed in the code, e.g if block/else block/try block/catch block ect.""" + """An interface for any code block that might not be executed in the code, + e.g if block/else block, try block/catch block ect. + """ @property @abstractmethod @@ -19,3 +21,17 @@ def other_possible_blocks(self) -> Sequence["ConditionalBlock"]: def end_byte_for_condition_block(self) -> int: """Returns the end byte for the specific condition block""" return self.end_byte + + @property + @noapidoc + def start_byte_for_condition_block(self) -> int: + """Returns the start byte for the specific condition block""" + return self.start_byte + + @noapidoc + def is_true_conditional(self, descendant) -> bool: + """Returns if this conditional is truly conditional, + this is necessary as an override for things like finally + statements that share a parent with try blocks + """ + return True diff --git a/src/codegen/sdk/core/interfaces/editable.py b/src/codegen/sdk/core/interfaces/editable.py index 22ae37f51..86e08c844 100644 --- a/src/codegen/sdk/core/interfaces/editable.py +++ b/src/codegen/sdk/core/interfaces/editable.py @@ -1106,6 +1106,15 @@ def parent_of_types(self, types: set[type[T]]) -> T | None: return self.parent.parent_of_types(types) return None + def is_child_of(self, instance: Editable) -> bool: + """Checks if this node is a descendant of the given editable instance in the AST.""" + if not self.parent: + return False + if self.parent is instance: + return True + else: + return self.parent.is_child_of(instance=instance) + @reader def ancestors(self, type: type[T]) -> list[T]: """Find all ancestors of the node of the given type. Does not return itself""" diff --git a/src/codegen/sdk/core/statements/if_block_statement.py b/src/codegen/sdk/core/statements/if_block_statement.py index 5d6a99fe7..98e7fab8d 100644 --- a/src/codegen/sdk/core/statements/if_block_statement.py +++ b/src/codegen/sdk/core/statements/if_block_statement.py @@ -299,3 +299,10 @@ def end_byte_for_condition_block(self) -> int: if self.is_if_statement: return self.consequence_block.end_byte return self.end_byte + + @property + @noapidoc + def start_byte_for_condition_block(self) -> int: + if self.is_if_statement: + return self.consequence_block.start_byte + return self.start_byte diff --git a/src/codegen/sdk/core/statements/try_catch_statement.py b/src/codegen/sdk/core/statements/try_catch_statement.py index 177ddde68..eca344b61 100644 --- a/src/codegen/sdk/core/statements/try_catch_statement.py +++ b/src/codegen/sdk/core/statements/try_catch_statement.py @@ -1,13 +1,13 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar, override from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.interfaces.has_block import HasBlock from codegen.sdk.core.statements.block_statement import BlockStatement from codegen.sdk.core.statements.statement import StatementType -from codegen.shared.decorators.docs import apidoc +from codegen.shared.decorators.docs import apidoc, noapidoc if TYPE_CHECKING: from codegen.sdk.core.detached_symbols.code_block import CodeBlock @@ -27,3 +27,26 @@ class TryCatchStatement(ConditionalBlock, BlockStatement[Parent], HasBlock, ABC, statement_type = StatementType.TRY_CATCH_STATEMENT finalizer: BlockStatement | None = None + + @noapidoc + @override + def is_true_conditional(self, descendant) -> bool: + if descendant.is_child_of(self.finalizer): + return False + return True + + @property + @noapidoc + def end_byte_for_condition_block(self) -> int: + if self.code_block: + return self.code_block.end_byte + else: + return self.end_byte + + @property + @noapidoc + def start_byte_for_condition_block(self) -> int: + if self.code_block: + return self.code_block.start_byte - 1 + else: + return self.start_byte diff --git a/src/codegen/sdk/python/statements/try_catch_statement.py b/src/codegen/sdk/python/statements/try_catch_statement.py index fda319130..9d02300cf 100644 --- a/src/codegen/sdk/python/statements/try_catch_statement.py +++ b/src/codegen/sdk/python/statements/try_catch_statement.py @@ -104,11 +104,3 @@ def nested_code_blocks(self) -> list[PyCodeBlock]: @noapidoc def other_possible_blocks(self) -> Sequence[ConditionalBlock]: return self.except_clauses - - @property - @noapidoc - def end_byte_for_condition_block(self) -> int: - if self.code_block: - return self.code_block.end_byte - else: - return self.end_byte diff --git a/src/codegen/sdk/typescript/statements/try_catch_statement.py b/src/codegen/sdk/typescript/statements/try_catch_statement.py index 315f9f33c..947ed3fbd 100644 --- a/src/codegen/sdk/typescript/statements/try_catch_statement.py +++ b/src/codegen/sdk/typescript/statements/try_catch_statement.py @@ -36,7 +36,7 @@ class TSTryCatchStatement(TryCatchStatement["TSCodeBlock"], TSBlockStatement): def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) if handler_node := self.ts_node.child_by_field_name("handler"): - self.catch = TSCatchStatement(handler_node, file_node_id, ctx, self.code_block) + self.catch = TSCatchStatement(handler_node, file_node_id, ctx, self) if finalizer_node := self.ts_node.child_by_field_name("finalizer"): self.finalizer = TSBlockStatement(finalizer_node, file_node_id, ctx, self.code_block) @@ -102,11 +102,3 @@ def other_possible_blocks(self) -> Sequence[ConditionalBlock]: return [self.catch] else: return [] - - @property - @noapidoc - def end_byte_for_condition_block(self) -> int: - if self.code_block: - return self.code_block.end_byte - else: - return self.end_byte diff --git a/tests/unit/codegen/sdk/python/statements/assignment_statement/__init__.py b/tests/unit/codegen/sdk/python/statements/assignment_statement/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/codegen/sdk/python/statements/attribute/__init__.py b/tests/unit/codegen/sdk/python/statements/attribute/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/codegen/sdk/python/statements/if_block_statement/__init__.py b/tests/unit/codegen/sdk/python/statements/if_block_statement/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/codegen/sdk/python/statements/import_statement/__init__.py b/tests/unit/codegen/sdk/python/statements/import_statement/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py b/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py deleted file mode 100644 index 76bb5d0f4..000000000 --- a/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py +++ /dev/null @@ -1,97 +0,0 @@ -from codegen.sdk.codebase.factory.get_session import get_codebase_session -from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement - - -def test_try_except_statement_parse(tmpdir) -> None: - # language=python - content = """ -try: - print(1/0) -except ZeroDivisionError as e: - print(e) - """ - with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: - file = codebase.get_file("test.py") - statements = file.code_block.statements - try_except = statements[0] - assert try_except.code_block.statements[0].source == "print(1/0)" - except_clause = try_except.except_clauses[0] - assert except_clause.condition == "ZeroDivisionError as e" - assert except_clause.code_block.statements[0].source == "print(e)" - - -def test_try_except_statement_function_calls(tmpdir) -> None: - # language=python - content = """ -try: - risky_operation() -except SomeException as e: - handle_exception() - log_error(e) - """ - with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: - file = codebase.get_file("test.py") - try_except = file.code_block.statements[0] - function_calls = try_except.except_clauses[0].function_calls - assert len(function_calls) == 2 - assert function_calls[0].source == "handle_exception()" - assert function_calls[1].source == "log_error(e)" - - -def test_try_except_statement_dependencies(tmpdir) -> None: - # language=python - content = """ -risky_var = 'risky' -def risky(): - try: - print(risky_var) - except NameError as e: - print("Variable not defined:", e) - """ - with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: - file = codebase.get_file("test.py") - risky_function = file.get_function("risky") - dependencies = risky_function.dependencies - assert len(dependencies) == 1 - assert dependencies[0] == file.get_global_var("risky_var") - - -def test_try_except_statement_is_wrapped_in(tmpdir) -> None: - # language=python - content = """ -risky_var = 'risky' -def risky(): - call() - try: - call() - if a: - call() - except NameError as e: - pass - """ - with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: - file = codebase.get_file("test.py") - risky_function = file.get_function("risky") - assert not file.function_calls[0].is_wrapped_in(TryCatchStatement) - assert file.function_calls[1].is_wrapped_in(TryCatchStatement) - assert file.function_calls[2].is_wrapped_in(TryCatchStatement) - - -def test_try_except_reassigment_handling(tmpdir) -> None: - content = """ - try: - PYSPARK = True # This gets removed even though there is a later use - except ImportError: - PYSPARK = False - - print(PYSPARK) - """ - - with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: - file = codebase.get_file("test.py") - symbo = file.get_symbol("PYSPARK") - funct_call = file.function_calls[0] - pyspark_arg = funct_call.args.children[0] - for symb in file.symbols: - usage = symb.usages[0] - assert usage.match == pyspark_arg diff --git a/tests/unit/codegen/sdk/python/statements/assignment_statement/test_assignment_statement_remove.py b/tests/unit/codegen/sdk/python/statements/test_assignment_statement_remove.py similarity index 100% rename from tests/unit/codegen/sdk/python/statements/assignment_statement/test_assignment_statement_remove.py rename to tests/unit/codegen/sdk/python/statements/test_assignment_statement_remove.py diff --git a/tests/unit/codegen/sdk/python/statements/assignment_statement/test_assignment_statement_rename.py b/tests/unit/codegen/sdk/python/statements/test_assignment_statement_rename.py similarity index 100% rename from tests/unit/codegen/sdk/python/statements/assignment_statement/test_assignment_statement_rename.py rename to tests/unit/codegen/sdk/python/statements/test_assignment_statement_rename.py diff --git a/tests/unit/codegen/sdk/python/statements/assignment_statement/test_assignment_statement_set_assignment_value.py b/tests/unit/codegen/sdk/python/statements/test_assignment_statement_set_assignment_value.py similarity index 100% rename from tests/unit/codegen/sdk/python/statements/assignment_statement/test_assignment_statement_set_assignment_value.py rename to tests/unit/codegen/sdk/python/statements/test_assignment_statement_set_assignment_value.py diff --git a/tests/unit/codegen/sdk/python/statements/attribute/test_attribute_assignment_value.py b/tests/unit/codegen/sdk/python/statements/test_attribute_assignment_value.py similarity index 100% rename from tests/unit/codegen/sdk/python/statements/attribute/test_attribute_assignment_value.py rename to tests/unit/codegen/sdk/python/statements/test_attribute_assignment_value.py diff --git a/tests/unit/codegen/sdk/python/statements/attribute/test_attribute_get_usages.py b/tests/unit/codegen/sdk/python/statements/test_attribute_get_usages.py similarity index 100% rename from tests/unit/codegen/sdk/python/statements/attribute/test_attribute_get_usages.py rename to tests/unit/codegen/sdk/python/statements/test_attribute_get_usages.py diff --git a/tests/unit/codegen/sdk/python/statements/attribute/test_attribute_properties.py b/tests/unit/codegen/sdk/python/statements/test_attribute_properties.py similarity index 100% rename from tests/unit/codegen/sdk/python/statements/attribute/test_attribute_properties.py rename to tests/unit/codegen/sdk/python/statements/test_attribute_properties.py diff --git a/tests/unit/codegen/sdk/python/statements/attribute/test_attribute_remove.py b/tests/unit/codegen/sdk/python/statements/test_attribute_remove.py similarity index 100% rename from tests/unit/codegen/sdk/python/statements/attribute/test_attribute_remove.py rename to tests/unit/codegen/sdk/python/statements/test_attribute_remove.py diff --git a/tests/unit/codegen/sdk/python/statements/attribute/test_attribute_rename.py b/tests/unit/codegen/sdk/python/statements/test_attribute_rename.py similarity index 100% rename from tests/unit/codegen/sdk/python/statements/attribute/test_attribute_rename.py rename to tests/unit/codegen/sdk/python/statements/test_attribute_rename.py diff --git a/tests/unit/codegen/sdk/python/statements/attribute/test_attribute_set_type_annotation.py b/tests/unit/codegen/sdk/python/statements/test_attribute_set_type_annotation.py similarity index 100% rename from tests/unit/codegen/sdk/python/statements/attribute/test_attribute_set_type_annotation.py rename to tests/unit/codegen/sdk/python/statements/test_attribute_set_type_annotation.py diff --git a/tests/unit/codegen/sdk/python/statements/for_loop_statement/test_for_loop_statement.py b/tests/unit/codegen/sdk/python/statements/test_for_loop_statement.py similarity index 100% rename from tests/unit/codegen/sdk/python/statements/for_loop_statement/test_for_loop_statement.py rename to tests/unit/codegen/sdk/python/statements/test_for_loop_statement.py diff --git a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_reduce_block.py b/tests/unit/codegen/sdk/python/statements/test_if_block_reduce_block.py similarity index 100% rename from tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_reduce_block.py rename to tests/unit/codegen/sdk/python/statements/test_if_block_reduce_block.py diff --git a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py b/tests/unit/codegen/sdk/python/statements/test_if_block_statement_properties.py similarity index 92% rename from tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py rename to tests/unit/codegen/sdk/python/statements/test_if_block_statement_properties.py index 6e38b9a45..025486164 100644 --- a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py +++ b/tests/unit/codegen/sdk/python/statements/test_if_block_statement_properties.py @@ -305,3 +305,34 @@ def foo(): func = file.get_function("foo") for assign in func.valid_symbol_names[:-1]: assign.usages[0] == pyspark_arg + + +def test_if_else_reassigment_handling_double_nested(tmpdir) -> None: + content = """ + if False: + PYSPARK = "TEST1" + elif True: + PYSPARK = "TEST2" + + if True: + PYSPARK = True + elif None: + if True: + PYSPARK = True + elif None: + if True: + PYSPARK = True + elif None: + PYSPARK = False + + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg diff --git a/tests/unit/codegen/sdk/python/statements/import_statement/test_import_statement.py b/tests/unit/codegen/sdk/python/statements/test_import_statement.py similarity index 100% rename from tests/unit/codegen/sdk/python/statements/import_statement/test_import_statement.py rename to tests/unit/codegen/sdk/python/statements/test_import_statement.py diff --git a/tests/unit/codegen/sdk/python/statements/test_match_statement.py b/tests/unit/codegen/sdk/python/statements/test_match_statement.py new file mode 100644 index 000000000..e3206d177 --- /dev/null +++ b/tests/unit/codegen/sdk/python/statements/test_match_statement.py @@ -0,0 +1,213 @@ +from codegen.sdk.codebase.factory.get_session import get_codebase_session + + +def test_match_switch_statement_parse(tmpdir) -> None: + # language=python + content = """ +match 1/0: + case ZeroDivisionError as e: + print(e) + print(1) + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + statements = file.code_block.statements + match_stmt = statements[0].cases[0] + assert match_stmt.code_block.statements[0].source == "print(e)" + assert match_stmt.code_block.statements[1].source == "print(1)" + + +def test_match_switch_statement_function_calls(tmpdir) -> None: + # language=python + content = """ +match risky_operation(): + case SomeException as e: + handle_exception() + log_error(e) + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + match_stmt = file.code_block.statements[0] + case_clause = match_stmt.cases[0] + statements = case_clause.code_block.statements + assert len(statements) == 2 + assert statements[0].source == "handle_exception()" + assert statements[1].source == "log_error(e)" + + +def test_match_switch_statement_dependencies(tmpdir) -> None: + # language=python + content = """ +risky_var = 'risky' +def risky(): + match risky_var: + case NameError as e: + print("Variable not defined:", e) + case _: + print(risky_var) + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + risky_function = file.get_function("risky") + dependencies = risky_function.dependencies + assert len(dependencies) == 1 + global_var = file.get_global_var("risky_var") + assert dependencies[0] == global_var + + +def test_match_reassigment_handling(tmpdir) -> None: + content = """ +filter = 1 +match filter: + case 1: + PYSPARK=True + case 2: + PYSPARK=False + case _: + PYSPARK=None + +print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols[1:]: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_match_reassigment_handling_function(tmpdir) -> None: + content = """ +action = "create" +match action: + case "create": + def process(): + print("creating") + case "update": + def process(): + print("updating") + case _: + def process(): + print("unknown action") + +process() + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + process = file.get_function("process") + funct_call = file.function_calls[3] # Skip the print calls + for func in file.functions: + usage = func.usages[0] + assert usage.match == funct_call + + +def test_match_reassigment_handling_inside_func(tmpdir) -> None: + content = """ +def get_message(status): + result = None + match status: + case "success": + result = "Operation successful" + case "error": + result = "An error occurred" + case _: + result = "Unknown status" + return result + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + get_message = file.get_function("get_message") + return_stmt = get_message.code_block.statements[-1] + result_var = return_stmt.value + for symb in file.symbols(True): + if symb.name == "result": + assert len(symb.usages) > 0 + assert any(usage.match == result_var for usage in symb.usages) + + +def test_match_reassigment_handling_nested(tmpdir) -> None: + content = """ +outer = "first" +match outer: + case "first": + RESULT = "outer first" + inner = "second" + match inner: + case "second": + RESULT = "inner second" + case _: + RESULT = "inner default" + case _: + RESULT = "outer default" + +print(RESULT) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + funct_call = file.function_calls[0] + result_arg = funct_call.args.children[0] + for symb in file.symbols: + if symb.name == "RESULT": + usage = symb.usages[0] + assert usage.match == result_arg + + +def test_match_multiple_reassigment(tmpdir) -> None: + content = """ +first = "a" +match first: + case "a": + VALUE = "first a" + case _: + VALUE = "first default" + +second = "b" +match second: + case "b": + VALUE = "second b" + case _: + VALUE = "second default" + +print(VALUE) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + funct_call = file.function_calls[0] + value_arg = funct_call.args.children[0] + for symb in file.symbols: + if symb.name == "VALUE": + usage = symb.usages[0] + assert usage.match == value_arg + + +def test_match_complex_pattern_reassigment(tmpdir) -> None: + content = """ +data = {"type": "user", "name": "John", "age": 30} +match data: + case {"type": "user", "name": name, "age": age} if age > 18: + STATUS = "adult user" + case {"type": "user", "name": name}: + STATUS = "user with unknown age" + case {"type": "admin"}: + STATUS = "admin" + case _: + STATUS = "unknown" + +print(STATUS) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + funct_call = file.function_calls[0] + status_arg = funct_call.args.children[0] + for symb in file.symbols: + if symb.name == "STATUS": + usage = symb.usages[0] + assert usage.match == status_arg diff --git a/tests/unit/codegen/sdk/python/statements/test_try_catch_statement.py b/tests/unit/codegen/sdk/python/statements/test_try_catch_statement.py new file mode 100644 index 000000000..36bc1bf1c --- /dev/null +++ b/tests/unit/codegen/sdk/python/statements/test_try_catch_statement.py @@ -0,0 +1,321 @@ +from codegen.sdk.codebase.factory.get_session import get_codebase_session +from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement + + +def test_try_except_statement_parse(tmpdir) -> None: + # language=python + content = """ +try: + print(1/0) +except ZeroDivisionError as e: + print(e) + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + statements = file.code_block.statements + try_except = statements[0] + assert try_except.code_block.statements[0].source == "print(1/0)" + except_clause = try_except.except_clauses[0] + assert except_clause.condition == "ZeroDivisionError as e" + assert except_clause.code_block.statements[0].source == "print(e)" + + +def test_try_except_statement_function_calls(tmpdir) -> None: + # language=python + content = """ +try: + risky_operation() +except SomeException as e: + handle_exception() + log_error(e) + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + try_except = file.code_block.statements[0] + function_calls = try_except.except_clauses[0].function_calls + assert len(function_calls) == 2 + assert function_calls[0].source == "handle_exception()" + assert function_calls[1].source == "log_error(e)" + + +def test_try_except_statement_dependencies(tmpdir) -> None: + # language=python + content = """ +risky_var = 'risky' +def risky(): + try: + print(risky_var) + except NameError as e: + print("Variable not defined:", e) + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + risky_function = file.get_function("risky") + dependencies = risky_function.dependencies + assert len(dependencies) == 1 + assert dependencies[0] == file.get_global_var("risky_var") + + +def test_try_except_statement_is_wrapped_in(tmpdir) -> None: + # language=python + content = """ +risky_var = 'risky' +def risky(): + call() + try: + call() + if a: + call() + except NameError as e: + pass + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + risky_function = file.get_function("risky") + assert not file.function_calls[0].is_wrapped_in(TryCatchStatement) + assert file.function_calls[1].is_wrapped_in(TryCatchStatement) + assert file.function_calls[2].is_wrapped_in(TryCatchStatement) + + +def test_try_except_reassigment_handling(tmpdir) -> None: + content = """ + try: + PYSPARK = True # This gets removed even though there is a later use + except ImportError: + PYSPARK = False + + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_try_except_reassigment_handling_function(tmpdir) -> None: + content = """ + try: + def process(): + print('try') + except ImportError: + def process(): + print('except') + finally: + def process(): + print('finally') + + process() + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + process = file.get_function("process") + funct_call = file.function_calls[3] # Skip the print calls + for idx, func in enumerate(file.functions): + if idx == 2: + usage = func.usages[0] + assert usage.match == funct_call + else: + assert not func.usages + + +def test_try_except_reassigment_handling_inside_func(tmpdir) -> None: + content = """ + def get_result(): + result = None + try: + result = "success" + except Exception: + result = "error" + finally: + result = "done" + return result + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + get_result = file.get_function("get_result") + return_stmt = get_result.code_block.statements[-1] + result_var = return_stmt.value + for idx, symb in enumerate(file.symbols(True)): + if symb.name == "result": + if idx == 4: + # The only usage is in the finally block + assert len(symb.usages) > 0 + assert any(usage.match == result_var for usage in symb.usages) + else: + assert len(symb.usages) == 0 + + +def test_try_except_reassigment_handling_nested(tmpdir) -> None: + content = """ + try: + RESULT = "outer try" + try: + RESULT = "inner try" + except Exception as e: + RESULT = "inner except" + except Exception as e: + RESULT = "outer except" + + print(RESULT) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + funct_call = file.function_calls[0] + result_arg = funct_call.args.children[0] + for symb in file.symbols: + if symb.name == "RESULT": + usage = symb.usages[0] + assert usage.match == result_arg + + +def test_try_except_reassigment_with_finally(tmpdir) -> None: + content = """ + try: + STATUS = "trying" + except Exception: + STATUS = "error" + finally: + STATUS = "done" + + print(STATUS) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + funct_call = file.function_calls[0] + status_arg = funct_call.args.children[0] + for idx, symb in enumerate(file.symbols(True)): + if symb.name == "STATUS": + if idx == 2: + # The only usage is in the finally block + assert len(symb.usages) > 0 + assert any(usage.match == status_arg for usage in symb.usages) + else: + assert len(symb.usages) == 0 + + +def test_try_except_reassigment_with_finally_nested(tmpdir) -> None: + content = """ + try: + STATUS = "trying" + except Exception: + STATUS = "error" + try: + STATUS = "trying" + except Exception: + STATUS = "error" + finally: + STATUS = "done" + + print(STATUS) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + funct_call = file.function_calls[0] + status_arg = funct_call.args.children[0] + for idx, symb in enumerate(file.symbols(True)): + if symb.name == "STATUS": + if idx == 0 or idx == 4: + # The only usage is in the finally block + assert len(symb.usages) > 0 + assert any(usage.match == status_arg for usage in symb.usages) + else: + assert len(symb.usages) == 0 + + +def test_try_except_reassigment_with_finally_nested_deeper(tmpdir) -> None: + content = """ + try: + STATUS = "trying" + except Exception: + STATUS = "error" + try: + STATUS = "trying_lvl2" + except Exception: + STATUS = "error_lvl2" + finally: + try: + STATUS = "trying_lvl3" + except Exception: + STATUS = "error_lvl3" + finally: + STATUS = "done_lvl3" + + print(STATUS) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + funct_call = file.function_calls[0] + status_arg = funct_call.args.children[0] + for idx, symb in enumerate(file.symbols(True)): + if symb.name == "STATUS": + if idx == 0 or idx == 6: + # The only usage is in the finally block + assert len(symb.usages) > 0 + assert any(usage.match == status_arg for usage in symb.usages) + else: + assert len(symb.usages) == 0 + + +def test_try_except_reassigment_with_finally_secondary_nested_deeper(tmpdir) -> None: + content = """ + try: + STATUS = "trying" + except Exception: + STATUS = "error" + try: + STATUS = "trying_lvl2" + except Exception: + STATUS = "error_lvl2" + finally: + try: + STATUS = "trying_lvl3" + except Exception: + STATUS = "error_lvl3" + + print(STATUS) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + funct_call = file.function_calls[0] + status_arg = funct_call.args.children[0] + for idx, symb in enumerate(file.symbols(True)): + if symb.name == "STATUS": + assert len(symb.usages) > 0 + assert any(usage.match == status_arg for usage in symb.usages) + + +def test_try_except_multiple_reassigment(tmpdir) -> None: + content = """ + try: + VALUE = "first try" + except Exception: + VALUE = "first except" + + try: + VALUE = "second try" + except Exception: + VALUE = "second except" + + print(VALUE) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + funct_call = file.function_calls[0] + value_arg = funct_call.args.children[0] + for symb in file.symbols: + if symb.name == "VALUE": + usage = symb.usages[0] + assert usage.match == value_arg diff --git a/tests/unit/codegen/sdk/python/statements/while_statement/test_while_statement.py b/tests/unit/codegen/sdk/python/statements/test_while_statement.py similarity index 100% rename from tests/unit/codegen/sdk/python/statements/while_statement/test_while_statement.py rename to tests/unit/codegen/sdk/python/statements/test_while_statement.py diff --git a/tests/unit/codegen/sdk/python/statements/with_statement/test_with_statement_properties.py b/tests/unit/codegen/sdk/python/statements/test_with_statement_properties.py similarity index 100% rename from tests/unit/codegen/sdk/python/statements/with_statement/test_with_statement_properties.py rename to tests/unit/codegen/sdk/python/statements/test_with_statement_properties.py diff --git a/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py b/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py deleted file mode 100644 index be972cffd..000000000 --- a/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py +++ /dev/null @@ -1,79 +0,0 @@ -from codegen.sdk.codebase.factory.get_session import get_codebase_session - - -def test_match_switch_statement_parse(tmpdir) -> None: - # language=python - content = """ -match 1/0: - case ZeroDivisionError as e: - print(e) - print(1) - """ - with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: - file = codebase.get_file("test.py") - statements = file.code_block.statements - match_stmt = statements[0].cases[0] - assert match_stmt.code_block.statements[0].source == "print(e)" - assert match_stmt.code_block.statements[1].source == "print(1)" - - -def test_match_switch_statement_function_calls(tmpdir) -> None: - # language=python - content = """ -match risky_operation(): - case SomeException as e: - handle_exception() - log_error(e) - """ - with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: - file = codebase.get_file("test.py") - match_stmt = file.code_block.statements[0] - case_clause = match_stmt.cases[0] - statements = case_clause.code_block.statements - assert len(statements) == 2 - assert statements[0].source == "handle_exception()" - assert statements[1].source == "log_error(e)" - - -def test_match_switch_statement_dependencies(tmpdir) -> None: - # language=python - content = """ -risky_var = 'risky' -def risky(): - match risky_var: - case NameError as e: - print("Variable not defined:", e) - case _: - print(risky_var) - """ - with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: - file = codebase.get_file("test.py") - risky_function = file.get_function("risky") - dependencies = risky_function.dependencies - assert len(dependencies) == 1 - global_var = file.get_global_var("risky_var") - assert dependencies[0] == global_var - - -def test_match_reassigment_handling(tmpdir) -> None: - content = """ -filter = 1 -match filter: - case 1: - PYSPARK=True - case 2: - PYSPARK=False - case _: - PYSPARK=None - -print(PYSPARK) - """ - - with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: - file = codebase.get_file("test.py") - symbo = file.get_symbol("PYSPARK") - funct_call = file.function_calls[0] - pyspark_arg = funct_call.args.children[0] - for symb in file.symbols[1:]: - usage = symb.usages[0] - assert usage.match == pyspark_arg diff --git a/tests/unit/codegen/sdk/python/statements/with_statement/__init__.py b/tests/unit/codegen/sdk/python/statements/with_statement/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/codegen/sdk/typescript/statements/if_block_statement/test_if_block_statement_properties.py b/tests/unit/codegen/sdk/typescript/statements/if_block_statement/test_if_block_statement_properties.py index 216741d61..2a09ff205 100644 --- a/tests/unit/codegen/sdk/typescript/statements/if_block_statement/test_if_block_statement_properties.py +++ b/tests/unit/codegen/sdk/typescript/statements/if_block_statement/test_if_block_statement_properties.py @@ -140,3 +140,193 @@ def test_get_alternative_if_blocks_from_codeblock(tmpdir) -> None: assert len(alt_blocks[2].alternative_blocks) == 0 assert len(alt_blocks[2].elif_statements) == 0 assert alt_blocks[2].else_statement is None + + +def test_if_else_reassignment_handling(tmpdir) -> None: + # language=typescript + content = """ +if (true) { + PYSPARK = true; +} else if (false) { + PYSPARK = false; +} else { + PYSPARK = null; +} + +console.log(PYSPARK); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + symbol = file.get_symbol("PYSPARK") + func_call = file.function_calls[0] + pyspark_arg = func_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_if_else_reassignment_handling_function(tmpdir) -> None: + # language=typescript + content = """ +if (true) { + function foo() { + console.log('t'); + } +} else if (false) { + function foo() { + console.log('t'); + } +} else { + function foo() { + console.log('t'); + } +} +foo(); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + foo = file.get_function("foo") + func_call = file.function_calls[3] + for func in file.functions: + usage = func.usages[0] + assert usage.match == func_call + + +def test_if_else_reassignment_handling_inside_func(tmpdir) -> None: + # language=typescript + content = """ +function foo(a) { + a = 1; + if (xyz) { + b = 1; + } else { + b = 2; + } + f(a); // a resolves to 1 name + f(b); // b resolves to 2 possible names +} + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + foo = file.get_function("foo") + assert foo + assert len(foo.parameters[0].usages) == 0 + func_call_a = foo.function_calls[0].args[0] + func_call_b = foo.function_calls[1] + for symbol in file.symbols(True): + if symbol.name == "a": + assert len(symbol.usages) == 1 + symbol.usages[0].match == func_call_a + elif symbol.name == "b": + assert len(symbol.usages) == 1 + symbol.usages[0].match == func_call_b + + +def test_if_else_reassignment_handling_partial_if(tmpdir) -> None: + # language=typescript + content = """ +PYSPARK = "TEST"; +if (true) { + PYSPARK = true; +} else if (null) { + PYSPARK = false; +} + +console.log(PYSPARK); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + symbol = file.get_symbol("PYSPARK") + func_call = file.function_calls[0] + pyspark_arg = func_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_if_else_reassignment_handling_solo_if(tmpdir) -> None: + # language=typescript + content = """ +PYSPARK = "TEST"; +if (true) { + PYSPARK = true; +} +console.log(PYSPARK); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + symbol = file.get_symbol("PYSPARK") + func_call = file.function_calls[0] + pyspark_arg = func_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_if_else_reassignment_handling_double(tmpdir) -> None: + # language=typescript + content = """ +if (false) { + PYSPARK = "TEST1"; +} else if (true) { + PYSPARK = "TEST2"; +} + +if (true) { + PYSPARK = true; +} else if (null) { + PYSPARK = false; +} + +console.log(PYSPARK); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + symbol = file.get_symbol("PYSPARK") + func_call = file.function_calls[0] + pyspark_arg = func_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_if_else_reassignment_handling_nested_usage(tmpdir) -> None: + # language=typescript + content = """ +if (true) { + PYSPARK = true; +} else if (null) { + PYSPARK = false; + console.log(PYSPARK); +} + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + func_call = file.function_calls[0] + pyspark_arg = func_call.args.children[0] + first = file.symbols[0] + second = file.symbols[1] + assert len(first.usages) == 0 + assert second.usages[0].match == pyspark_arg + + +def test_if_else_reassignment_inside_func_with_external_element(tmpdir) -> None: + # language=typescript + content = """ +PYSPARK = "0"; +function foo() { + if (true) { + PYSPARK = true; + } else { + PYSPARK = false; + } + console.log(PYSPARK); +} + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + func_call = file.function_calls[0] + pyspark_arg = func_call.args.children[0] + func = file.get_function("foo") + for assign in func.valid_symbol_names[:-1]: + assign.usages[0] == pyspark_arg diff --git a/tests/unit/codegen/sdk/typescript/statements/switch_statement/test_switch_statement.py b/tests/unit/codegen/sdk/typescript/statements/switch_statement/test_switch_statement.py index 6ac3a94af..5666524bd 100644 --- a/tests/unit/codegen/sdk/typescript/statements/switch_statement/test_switch_statement.py +++ b/tests/unit/codegen/sdk/typescript/statements/switch_statement/test_switch_statement.py @@ -89,3 +89,186 @@ def test_switch_statement_dependencies(tmpdir) -> None: selectFruit = file.get_function("selectFruit") assert len(selectFruit.dependencies) == 1 assert selectFruit.dependencies[0] == file.get_global_var("fruit") + + +def test_switch_reassignment_handling(tmpdir) -> None: + # language=typescript + content = """ +const filter = 1; +switch (filter) { + case 1: + PYSPARK = true; + break; + case 2: + PYSPARK = false; + break; + default: + PYSPARK = null; + break; +} + +console.log(PYSPARK); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + symbol = file.get_symbol("PYSPARK") + func_call = file.function_calls[0] + pyspark_arg = func_call.args.children[0] + for symb in file.symbols[1:]: # Skip the first symbol which is 'filter' + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_switch_reassignment_handling_function(tmpdir) -> None: + # language=typescript + content = """ +const type = "handler"; +switch (type) { + case "handler": + function process(){ return "handler"; } + break; + case "processor": + function process(){ return "processor"; } + break; + default: + function process(){ return "default"; } + break; +} +process(); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + process = file.get_function("process") + func_call = file.function_calls[0] + for func in file.functions: + if func.name == "process": + usage = func.usages[0] + assert usage.match == func_call + + +def test_switch_reassignment_handling_inside_func(tmpdir) -> None: + # language=typescript + content = """ +function getStatus(code) { + let status; + switch (code) { + case 200: + status = "OK"; + break; + case 404: + status = "Not Found"; + break; + default: + status = "Unknown"; + break; + } + return status; +} + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + get_status = file.get_function("getStatus") + return_stmt = get_status.code_block.statements[-1] + status_var = return_stmt.value + for symb in file.symbols(True): + if symb.name == "status": + assert len(symb.usages) > 0 + assert any(usage.match == status_var for usage in symb.usages) + + +def test_switch_reassignment_handling_nested(tmpdir) -> None: + # language=typescript + content = """ +const outer = 1; +switch (outer) { + case 1: + RESULT = "outer 1"; + const inner = 2; + switch (inner) { + case 2: + RESULT = "inner 2"; + break; + default: + RESULT = "inner default"; + break; + } + break; + default: + RESULT = "outer default"; + break; +} +console.log(RESULT); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + func_call = file.function_calls[0] + result_arg = func_call.args.children[0] + for symb in file.symbols: + if symb.name == "RESULT": + usage = symb.usages[0] + assert usage.match == result_arg + + +def test_switch_multiple_reassignment(tmpdir) -> None: + # language=typescript + content = """ +const first = 1; +switch (first) { + case 1: + VALUE = "first 1"; + break; + default: + VALUE = "first default"; + break; +} + +const second = 2; +switch (second) { + case 2: + VALUE = "second 2"; + break; + default: + VALUE = "second default"; + break; +} + +console.log(VALUE); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + func_call = file.function_calls[0] + value_arg = func_call.args.children[0] + for symb in file.symbols: + if symb.name == "VALUE": + usage = symb.usages[0] + assert usage.match == value_arg + + +def test_switch_fallthrough_reassignment(tmpdir) -> None: + # language=typescript + content = """ +const code = 1; +switch (code) { + case 1: + STATUS = "Processing"; + // Fallthrough intentional + case 2: + STATUS = "Almost done"; + // Fallthrough intentional + case 3: + STATUS = "Complete"; + break; + default: + STATUS = "Unknown"; + break; +} +console.log(STATUS); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + func_call = file.function_calls[0] + status_arg = func_call.args.children[0] + for symb in file.symbols: + if symb.name == "STATUS": + usage = symb.usages[0] + assert usage.match == status_arg diff --git a/tests/unit/codegen/sdk/typescript/statements/try_catch_statement/test_try_catch_statement.py b/tests/unit/codegen/sdk/typescript/statements/try_catch_statement/test_try_catch_statement.py index b7ddfa236..df3ec0212 100644 --- a/tests/unit/codegen/sdk/typescript/statements/try_catch_statement/test_try_catch_statement.py +++ b/tests/unit/codegen/sdk/typescript/statements/try_catch_statement/test_try_catch_statement.py @@ -77,3 +77,212 @@ def test_try_catch_statement_dependencies(tmpdir) -> None: example = file.get_function("example") assert len(example.dependencies) == 1 assert example.dependencies[0] == file.get_global_var("globalVar") + + +def test_try_catch_statement_dependencies_external(tmpdir) -> None: + # language=typescript + content = """ +let test: Record | undefined; +try { + test = JSON.parse(test); + if (Object.keys(test).length === 0) { + test = undefined; + } +} catch (e) { + console.error("Error parsing test", e); + test = undefined; +} +let use = test + + + + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + symbol = file.get_symbol("test") + func_call = file.function_calls[0] + assign_val = file.symbols[-1].value + for symb in file.symbols[:-1]: + assert any(usage.match == assign_val for usage in symb.usages) + + +def test_try_catch_reassignment_handling(tmpdir) -> None: + # language=typescript + content = """ +try { + PYSPARK = true; +} catch (error) { + PYSPARK = false; +} + +console.log(PYSPARK); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + symbol = file.get_symbol("PYSPARK") + func_call = file.function_calls[0] + pyspark_arg = func_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_try_catch_reassignment_handling_function(tmpdir) -> None: + # language=typescript + content = """ +try { + function foo() { + console.log('try'); + } +} catch (error) { + function foo() { + console.log('catch'); + } +} +foo(); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + foo = file.get_function("foo") + func_call = file.function_calls[2] + for func in file.functions: + assert func.usages + usage = func.usages[0] + assert usage.match == func_call + + +def test_try_catch_reassignment_handling_function_finally(tmpdir) -> None: + # language=typescript + content = """ +try { + function foo() { + console.log('try'); + } +} catch (error) { + function foo() { + console.log('catch'); + } +} finally { + function foo() { + console.log('finally'); + } +} +foo(); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + foo = file.get_function("foo") + func_call = file.function_calls[3] + for idx, func in enumerate(file.functions): + if idx == 2: + assert func.usages + usage = func.usages[0] + assert usage.match == func_call + else: + assert len(func.usages) == 0 + + +def test_try_catch_reassignment_handling_nested(tmpdir) -> None: + # language=typescript + content = """ +try { + PYSPARK = true; + try { + PYSPARK = "nested"; + } catch (innerError) { + PYSPARK = "inner catch"; + } +} catch (error) { + PYSPARK = false; +} + +console.log(PYSPARK); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + func_call = file.function_calls[0] + pyspark_arg = func_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_try_catch_reassignment_handling_inside_func(tmpdir) -> None: + # language=typescript + content = """ +function process() { + let result; + try { + result = "success"; + } catch (error) { + result = "error"; + } finally { + result = "done"; + } + return result; +} + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + process_func = file.get_function("process") + return_stmt = process_func.code_block.statements[-1] + result_var = return_stmt.value + for idx, symb in enumerate(file.symbols(True)): + if symb.name == "result": + if idx == 4: + # Only finally triggers + assert len(symb.usages) > 0 + assert any(usage.match == result_var for usage in symb.usages) + else: + assert len(symb.usages) == 0 + + +def test_try_catch_reassignment_with_finally(tmpdir) -> None: + # language=typescript + content = """ +try { + PYSPARK = true; +} catch (error) { + PYSPARK = false; +} finally { + PYSPARK = "finally"; +} + +console.log(PYSPARK); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + func_call = file.function_calls[0] + pyspark_arg = func_call.args.children[0] + for idx, symb in enumerate(file.symbols): + if idx == 2: + usage = symb.usages[0] + assert usage.match == pyspark_arg + else: + assert not symb.usages + + +def test_try_catch_multiple_reassignment(tmpdir) -> None: + # language=typescript + content = """ +try { + PYSPARK = true; +} catch (error) { + PYSPARK = false; +} + +try { + PYSPARK = "second try"; +} catch (error) { + PYSPARK = "second catch"; +} + +console.log(PYSPARK); + """ + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file("test.ts") + func_call = file.function_calls[0] + pyspark_arg = func_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg