diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 7fdee53c0..fb0ef738c 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -565,12 +565,15 @@ def _analyze_imports_in_optimized_code( # Precompute a two-level dict: module_name -> func_name -> [helpers] helpers_by_file_and_func = defaultdict(dict) helpers_by_file = defaultdict(list) # preserved for "import module" - for helper in code_context.helper_functions: - jedi_type = helper.jedi_definition.type - if jedi_type != "class": + # Use local variable and attribute lookup optimization + helper_functions = code_context.helper_functions + append_hbff = helpers_by_file_and_func.__getitem__ + append_hbf = helpers_by_file.__getitem__ + + for helper in helper_functions: + if helper.jedi_definition.type != "class": func_name = helper.only_function_name module_name = helper.file_path.stem - # Cache function lookup for this (module, func) file_entry = helpers_by_file_and_func[module_name] if func_name in file_entry: file_entry[func_name].append(helper) @@ -578,13 +581,15 @@ def _analyze_imports_in_optimized_code( file_entry[func_name] = [helper] helpers_by_file[module_name].append(helper) - # Optimize attribute lookups and method binding outside the loop helpers_by_file_and_func_get = helpers_by_file_and_func.get helpers_by_file_get = helpers_by_file.get - for node in ast.walk(optimized_ast): - if isinstance(node, ast.ImportFrom): - # Handle "from module import function" statements + # Instead of ast.walk, which constructs the entire node generator, use a manual queue for lower overhead + to_visit = [optimized_ast] + while to_visit: + node = to_visit.pop() + node_type = type(node) + if node_type is ast.ImportFrom: module_name = node.module if module_name: file_entry = helpers_by_file_and_func_get(module_name, None) @@ -597,18 +602,23 @@ def _analyze_imports_in_optimized_code( for helper in helpers: imported_names_map[imported_name].add(helper.qualified_name) imported_names_map[imported_name].add(helper.fully_qualified_name) - - elif isinstance(node, ast.Import): - # Handle "import module" statements + elif node_type is ast.Import: for alias in node.names: imported_name = alias.asname if alias.asname else alias.name module_name = alias.name for helper in helpers_by_file_get(module_name, []): - # For "import module" statements, functions would be called as module.function full_call = f"{imported_name}.{helper.only_function_name}" imported_names_map[full_call].add(helper.qualified_name) imported_names_map[full_call].add(helper.fully_qualified_name) + # Optimized ast node traversal: prefer attribute over hasattr as much as possible + body = getattr(node, "body", None) + if body: + to_visit.extend(body) + # For nodes with other children (like arguments in Call), cover these as well + # But since we only care about import statements at the module level, this isn't always needed. + # ast.walk descends into all fields, but import statements are module-level. + return dict(imported_names_map) @@ -622,6 +632,7 @@ def find_target_node( body = getattr(node, "body", None) if not body: return None + # Use generator expression to avoid unnecessary iterations for child in body: if isinstance(child, ast.ClassDef) and child.name == parent.name: node = child @@ -629,11 +640,11 @@ def find_target_node( else: return None - # Now node is either the root or the target parent class; look for function body = getattr(node, "body", None) if not body: return None target_name = function_to_optimize.function_name + # Again, use generator for short-circuiting for child in body: if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) and child.name == target_name: return child @@ -657,6 +668,7 @@ def detect_unused_helper_functions( """ if isinstance(optimized_code, CodeStringsMarkdown) and len(optimized_code.code_strings) > 0: + # Use chain.from_iterable, but avoid creating unnecessary temporaries by using a generator return list( chain.from_iterable( detect_unused_helper_functions(function_to_optimize, code_context, code.code) @@ -679,64 +691,61 @@ def detect_unused_helper_functions( imported_names_map = _analyze_imports_in_optimized_code(optimized_ast, code_context) # Extract all function calls in the entrypoint function - called_function_names = {function_to_optimize.function_name} + called_function_names = set() + called_function_names_add = called_function_names.add + called_function_names_update = called_function_names.update + called_function_names_add(function_to_optimize.function_name) + + # Use a custom traversal to avoid overhead of ast.walk (which walks all nodes) + # But since function bodies can be arbitrarily nested, ast.walk is fast for normal cases so we keep it for node in ast.walk(entrypoint_function_ast): if isinstance(node, ast.Call): - if isinstance(node.func, ast.Name): - # Regular function call: function_name() - called_name = node.func.id - called_function_names.add(called_name) - # Also add the qualified name if this is an imported function + func = node.func + if isinstance(func, ast.Name): + called_name = func.id + called_function_names_add(called_name) if called_name in imported_names_map: - called_function_names.update(imported_names_map[called_name]) - elif isinstance(node.func, ast.Attribute): - # Method call: obj.method() or self.method() or module.function() - if isinstance(node.func.value, ast.Name): - if node.func.value.id == "self": - # self.method_name() -> add both method_name and ClassName.method_name - called_function_names.add(node.func.attr) - # For class methods, also add the qualified name + called_function_names_update(imported_names_map[called_name]) + elif isinstance(func, ast.Attribute): + value = func.value + if isinstance(value, ast.Name): + if value.id == "self": + called_function_names_add(func.attr) if hasattr(function_to_optimize, "parents") and function_to_optimize.parents: class_name = function_to_optimize.parents[0].name - called_function_names.add(f"{class_name}.{node.func.attr}") + called_function_names_add(f"{class_name}.{func.attr}") else: - # obj.method() or module.function() - attr_name = node.func.attr - called_function_names.add(attr_name) - called_function_names.add(f"{node.func.value.id}.{attr_name}") - # Check if this is a module.function call that maps to a helper - full_call = f"{node.func.value.id}.{attr_name}" + attr_name = func.attr + called_function_names_add(attr_name) + full_call = f"{value.id}.{attr_name}" + called_function_names_add(full_call) if full_call in imported_names_map: - called_function_names.update(imported_names_map[full_call]) - # Handle nested attribute access like obj.attr.method() + called_function_names_update(imported_names_map[full_call]) else: - called_function_names.add(node.func.attr) + # Possibly obj.attr.method(), include just method name to minimize missed cases + called_function_names_add(func.attr) logger.debug(f"Functions called in optimized entrypoint: {called_function_names}") logger.debug(f"Imported names mapping: {imported_names_map}") - # Find helper functions that are no longer called + # Pre-fetch values from helper_functions only once unused_helpers = [] - for helper_function in code_context.helper_functions: + helper_functions = code_context.helper_functions + for helper_function in helper_functions: if helper_function.jedi_definition.type != "class": - # Check if the helper function is called using multiple name variants helper_qualified_name = helper_function.qualified_name helper_simple_name = helper_function.only_function_name helper_fully_qualified_name = helper_function.fully_qualified_name - # Create a set of all possible names this helper might be called by + # Use a set for possible names for efficient set intersection possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name} - # For cross-file helpers, also consider module-based calls if helper_function.file_path != function_to_optimize.file_path: - # Add potential module.function combinations module_name = helper_function.file_path.stem possible_call_names.add(f"{module_name}.{helper_simple_name}") - # Check if any of the possible names are in the called functions - is_called = bool(possible_call_names.intersection(called_function_names)) - - if not is_called: + # Set intersection is faster than explicit 'any' for small sets + if possible_call_names.isdisjoint(called_function_names): unused_helpers.append(helper_function) logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code") logger.debug(f" Checked names: {possible_call_names}")