diff --git a/codeflash/code_utils/deduplicate_code.py b/codeflash/code_utils/deduplicate_code.py
new file mode 100644
--- /dev/null
+++ b/codeflash/code_utils/deduplicate_code.py
@@ -0,0 +1,275 @@
+import ast
+import builtins
+import hashlib
+from typing import Dict, Set
+
+
+class VariableNormalizer(ast.NodeTransformer):
+    """Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
+
+    Preserves function names, class names, parameters, built-ins, and imported names.
+    """
+
+    def __init__(self):
+        self.var_counter = 0
+        self.var_mapping: Dict[str, str] = {}
+        self.scope_stack = []
+        # `__builtins__` is a plain dict inside imported modules (CPython detail), so
+        # dir(__builtins__) would list dict methods; the `builtins` module is reliable.
+        self.builtins = set(dir(builtins))
+        self.imports: Set[str] = set()
+        self.global_vars: Set[str] = set()
+        self.nonlocal_vars: Set[str] = set()
+        self.parameters: Set[str] = set()  # Track function parameters
+
+    def enter_scope(self):
+        """Enter a new scope (function/class), saving the state restored by exit_scope"""
+        self.scope_stack.append(
+            {
+                "var_mapping": dict(self.var_mapping),
+                "var_counter": self.var_counter,
+                "parameters": set(self.parameters),
+                "global_vars": set(self.global_vars),
+                "nonlocal_vars": set(self.nonlocal_vars),
+            }
+        )
+
+    def exit_scope(self):
+        """Exit current scope and restore parent scope"""
+        if self.scope_stack:
+            scope = self.scope_stack.pop()
+            self.var_mapping = scope["var_mapping"]
+            self.var_counter = scope["var_counter"]
+            self.parameters = scope["parameters"]
+            self.global_vars = scope["global_vars"]
+            self.nonlocal_vars = scope["nonlocal_vars"]
+
+    def is_protected(self, name: str) -> bool:
+        """Return True for names that must never be renamed"""
+        return (
+            name in self.builtins
+            or name in self.imports
+            or name in self.global_vars
+            or name in self.nonlocal_vars
+            or name in self.parameters
+        )
+
+    def get_normalized_name(self, name: str) -> str:
+        """Get or create normalized name for a variable"""
+        # Don't normalize if it's a builtin, import, global, nonlocal, or parameter
+        if self.is_protected(name):
+            return name
+
+        # Only normalize local variables
+        if name not in self.var_mapping:
+            self.var_mapping[name] = f"var_{self.var_counter}"
+            self.var_counter += 1
+        return self.var_mapping[name]
+
+    def visit_Import(self, node):
+        """Track imported names"""
+        for alias in node.names:
+            name = alias.asname if alias.asname else alias.name
+            self.imports.add(name.split(".")[0])
+        return node
+
+    def visit_ImportFrom(self, node):
+        """Track imported names from modules"""
+        for alias in node.names:
+            name = alias.asname if alias.asname else alias.name
+            self.imports.add(name)
+        return node
+
+    def visit_Global(self, node):
+        """Track global variable declarations"""
+        self.global_vars.update(node.names)
+        return node
+
+    def visit_Nonlocal(self, node):
+        """Track nonlocal variable declarations"""
+        self.nonlocal_vars.update(node.names)
+        return node
+
+    def visit_FunctionDef(self, node):
+        """Process function but keep function name and parameters unchanged"""
+        self.enter_scope()
+
+        # Track all parameters (don't modify them), including positional-only ones
+        args = node.args
+        for arg in (*args.posonlyargs, *args.args, *args.kwonlyargs):
+            self.parameters.add(arg.arg)
+        if args.vararg:
+            self.parameters.add(args.vararg.arg)
+        if args.kwarg:
+            self.parameters.add(args.kwarg.arg)
+
+        # Visit function body
+        node = self.generic_visit(node)
+        self.exit_scope()
+        return node
+
+    def visit_AsyncFunctionDef(self, node):
+        """Handle async functions same as regular functions"""
+        return self.visit_FunctionDef(node)
+
+    def visit_ClassDef(self, node):
+        """Process class but keep class name unchanged"""
+        self.enter_scope()
+        node = self.generic_visit(node)
+        self.exit_scope()
+        return node
+
+    def visit_Name(self, node):
+        """Normalize variable names in Name nodes"""
+        if isinstance(node.ctx, (ast.Store, ast.Del)):
+            # For assignments and deletions, get_normalized_name leaves protected names alone
+            node.id = self.get_normalized_name(node.id)
+        elif node.id in self.var_mapping and not self.is_protected(node.id):
+            # For loading, use existing mapping if available. The protection check keeps a
+            # parameter from being rewritten to an outer variable that shares its name.
+            node.id = self.var_mapping[node.id]
+        return node
+
+    def visit_ExceptHandler(self, node):
+        """Normalize exception variable names"""
+        if node.name:
+            node.name = self.get_normalized_name(node.name)
+        return self.generic_visit(node)
+
+    def visit_comprehension(self, node):
+        """Normalize comprehension target variables"""
+        # Scope is saved/restored by the enclosing comprehension expression visitor
+        return self.generic_visit(node)
+
+    def _visit_comprehension_expr(self, node):
+        """Normalize a list/set/dict comprehension or generator expression"""
+        # Create new scope for comprehension
+        old_mapping = dict(self.var_mapping)
+        old_counter = self.var_counter
+
+        # Generators bind the targets, so visit them before the element expression;
+        # generic_visit would reach the element first and leave its names unmapped.
+        node.generators = [self.visit(generator) for generator in node.generators]
+        if isinstance(node, ast.DictComp):
+            node.key = self.visit(node.key)
+            node.value = self.visit(node.value)
+        else:
+            node.elt = self.visit(node.elt)
+
+        # Restore scope
+        self.var_mapping = old_mapping
+        self.var_counter = old_counter
+        return node
+
+    visit_ListComp = _visit_comprehension_expr
+    visit_SetComp = _visit_comprehension_expr
+    visit_GeneratorExp = _visit_comprehension_expr
+    visit_DictComp = _visit_comprehension_expr
+
+    def visit_For(self, node):
+        """Handle for loop target variables"""
+        # The target in a for loop is a local variable that should be normalized
+        return self.generic_visit(node)
+
+    def visit_With(self, node):
+        """Handle with statement as variables"""
+        return self.generic_visit(node)
+
+
+def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str:
+    """Normalize Python code by parsing, cleaning, and normalizing only variable names.
+
+    Function names, class names, and parameters are preserved.
+
+    Args:
+        code: Python source code as string
+        remove_docstrings: Whether to remove docstrings
+        return_ast_dump: Return ast.dump() of the normalized tree instead of unparsed source
+
+    Returns:
+        Normalized code as string (or the AST dump when return_ast_dump is True)
+
+    Raises:
+        ValueError: If the code is not valid Python syntax
+
+    """
+    try:
+        # Parse the code
+        tree = ast.parse(code)
+
+        # Remove docstrings if requested
+        if remove_docstrings:
+            remove_docstrings_from_ast(tree)
+
+        # Normalize variable names
+        normalizer = VariableNormalizer()
+        normalized_tree = normalizer.visit(tree)
+        if return_ast_dump:
+            # This is faster than unparsing etc
+            return ast.dump(normalized_tree, annotate_fields=False, include_attributes=False)
+
+        # Fix missing locations in the AST
+        ast.fix_missing_locations(normalized_tree)
+
+        # Unparse back to code
+        return ast.unparse(normalized_tree)
+    except SyntaxError as e:
+        msg = f"Invalid Python syntax: {e}"
+        raise ValueError(msg) from e
+
+
+def remove_docstrings_from_ast(node):
+    """Remove docstrings from AST nodes."""
+    # Only FunctionDef, AsyncFunctionDef, ClassDef, and Module can contain docstrings in their body[0]
+    node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
+    # Use our own stack-based DFS instead of ast.walk so expression subtrees can be skipped
+    stack = [node]
+    while stack:
+        current_node = stack.pop()
+        if isinstance(current_node, node_types):
+            # Remove docstring if it's the first stmt in body
+            body = current_node.body
+            if (
+                body
+                and isinstance(body[0], ast.Expr)
+                and isinstance(body[0].value, ast.Constant)
+                and isinstance(body[0].value.value, str)
+            ):
+                # Keep the body non-empty so the tree still unparses to valid code
+                current_node.body = body[1:] or [ast.Pass()]
+        # Definitions can be nested inside any compound statement (if/try/with/for/...), so
+        # descend through every non-expression child; expressions cannot contain definitions.
+        stack.extend(child for child in ast.iter_child_nodes(current_node) if not isinstance(child, ast.expr))
+
+
+def get_code_fingerprint(code: str) -> str:
+    """Generate a fingerprint for normalized code.
+
+    Args:
+        code: Python source code
+
+    Returns:
+        SHA-256 hash of normalized code
+
+    """
+    normalized = normalize_code(code)
+    return hashlib.sha256(normalized.encode()).hexdigest()
+
+
+def are_codes_duplicate(code1: str, code2: str) -> bool:
+    """Check if two code segments are duplicates after normalization.
+
+    Args:
+        code1: First code segment
+        code2: Second code segment
+
+    Returns:
+        True if codes are structurally identical (ignoring local variable names)
+
+    """
+    try:
+        normalized1 = normalize_code(code1, return_ast_dump=True)
+        normalized2 = normalize_code(code2, return_ast_dump=True)
+        return normalized1 == normalized2
+    except Exception:
+        return False
diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index 8417148ef..e91bba3c6 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -558,7 +558,7 @@ def unique_invocation_loop_id(self) -> str:
         return f"{self.loop_index}:{self.id.id()}"
 
 
-class TestResults(BaseModel):  # noqa: PLW1641
+class TestResults(BaseModel):
     # don't modify these directly, use the add method
     # also we don't support deletion of test results elements - caution is advised
     test_results: list[FunctionTestInvocation] = []
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index ba65d3644..30d6d4022 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -48,6 +48,7 @@
     REPEAT_OPTIMIZATION_PROBABILITY,
     TOTAL_LOOPING_TIME,
 )
+from codeflash.code_utils.deduplicate_code import normalize_code
 from codeflash.code_utils.edit_generated_tests import (
     add_runtime_comments_to_generated_tests,
     remove_functions_from_generated_tests,
@@ -519,7 +520,7 @@ def determine_best_candidate(
                     )
                     continue
                 # check if this code has been evaluated before by checking the ast normalized code string
-                normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
+                normalized_code = normalize_code(candidate.source_code.flat.strip())
                 if normalized_code in ast_code_to_id:
                     logger.info(
                         "Current candidate has been encountered before in testing, Skipping optimization candidate."
@@ -669,7 +670,7 @@ def determine_best_candidate(
         diff_strs = []
         runtimes_list = []
         for valid_opt in valid_optimizations:
-            valid_opt_normalized_code = ast.unparse(ast.parse(valid_opt.candidate.source_code.flat.strip()))
+            valid_opt_normalized_code = normalize_code(valid_opt.candidate.source_code.flat.strip())
             new_candidate_with_shorter_code = OptimizedCandidate(
                 source_code=ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"],
                 optimization_id=valid_opt.candidate.optimization_id,
diff --git a/tests/test_code_deduplication.py b/tests/test_code_deduplication.py
new file mode 100644
--- /dev/null
+++ b/tests/test_code_deduplication.py
@@ -0,0 +1,183 @@
+from codeflash.code_utils.deduplicate_code import are_codes_duplicate, normalize_code
+
+
+def test_deduplicate1():
+    # Example usage and tests
+    # Example 1: Same function and parameter names, different local variables and comments (should match)
+    code1 = """
+def compute_sum(numbers):
+    '''Calculate sum of numbers'''
+    total = 0
+    for num in numbers:
+        total += num
+    return total
+"""
+
+    code2 = """
+def compute_sum(numbers):
+    # This computes the sum
+    result = 0
+    for value in numbers:
+        result += value
+    return result
+"""
+
+    assert normalize_code(code1) == normalize_code(code2)
+    assert are_codes_duplicate(code1, code2)
+
+    # Example 2: Same function and parameter names, different local variables (should match)
+    code3 = """
+def calculate_sum(numbers):
+    accumulator = 0
+    for item in numbers:
+        accumulator += item
+    return accumulator
+"""
+
+    code4 = """
+def calculate_sum(numbers):
+    total = 0
+    for num in numbers:
+        total += num
+    return total
+"""
+
+    assert normalize_code(code3) == normalize_code(code4)
+    assert are_codes_duplicate(code3, code4)
+
+    # Example 3: Nested functions and classes (preserving names)
+    code5 = """
+class DataProcessor:
+    def __init__(self, data):
+        self.data = data
+
+    def process(self):
+        def helper(item):
+            temp = item * 2
+            return temp
+
+        results = []
+        for element in self.data:
+            results.append(helper(element))
+        return results
+"""
+
+    code6 = """
+class DataProcessor:
+    def __init__(self, data):
+        self.data = data
+
+    def process(self):
+        def helper(item):
+            x = item * 2
+            return x
+
+        output = []
+        for thing in self.data:
+            output.append(helper(thing))
+        return output
+"""
+
+    assert normalize_code(code5) == normalize_code(code6)
+
+    # Example 4: With imports and built-ins (these should be preserved)
+    code7 = """
+import math
+
+def calculate_circle_area(radius):
+    pi_value = math.pi
+    area = pi_value * radius ** 2
+    return area
+"""
+
+    code8 = """
+import math
+
+def calculate_circle_area(radius):
+    constant = math.pi
+    result = constant * radius ** 2
+    return result
+"""
+    code85 = """
+import math
+
+def calculate_circle_area(radius):
+    constant = math.pi
+    result = constant *2 * radius ** 2
+    return result
+"""
+
+    assert normalize_code(code7) == normalize_code(code8)
+    assert normalize_code(code8) != normalize_code(code85)
+
+    # Example 5: Exception handling
+    code9 = """
+def safe_divide(a, b):
+    try:
+        result = a / b
+        return result
+    except ZeroDivisionError as e:
+        error_msg = str(e)
+        return None
+"""
+
+    code10 = """
+def safe_divide(a, b):
+    try:
+        output = a / b
+        return output
+    except ZeroDivisionError as exc:
+        message = str(exc)
+        return None
+"""
+    assert normalize_code(code9) == normalize_code(code10)
+
+    assert normalize_code(code9) != normalize_code(code8)
+
+
+def test_deduplicate_scoping():
+    # Comprehension targets are normalized together with the element expression
+    code1 = """
+def squares(numbers):
+    return [n * n for n in numbers if n > 0]
+"""
+
+    code2 = """
+def squares(numbers):
+    return [value * value for value in numbers if value > 0]
+"""
+
+    assert normalize_code(code1) == normalize_code(code2)
+    assert are_codes_duplicate(code1, code2)
+
+    # Locals named like dict methods (e.g. keys) are plain locals, not built-ins
+    code3 = """
+def first_key(mapping):
+    keys = sorted(mapping)
+    return keys[0]
+"""
+
+    code4 = """
+def first_key(mapping):
+    ordered = sorted(mapping)
+    return ordered[0]
+"""
+
+    assert normalize_code(code3) == normalize_code(code4)
+
+    # A parameter shadowing a module-level variable must not be confused with it
+    code5 = """
+x = 1
+
+def pick(x):
+    return x
+"""
+
+    code6 = """
+y = 1
+
+def pick(x):
+    return y
+"""
+
+    assert normalize_code(code5) != normalize_code(code6)