diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index bcd310b3c..a15bbbd02 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -551,53 +551,53 @@ def _analyze_imports_in_optimized_code( """ imported_names_map = defaultdict(set) - # Precompute a two-level dict: module_name -> func_name -> [helpers] - helpers_by_file_and_func = defaultdict(dict) - helpers_by_file = defaultdict(list) # preserved for "import module" - helpers_append = helpers_by_file_and_func.setdefault + # Prepare one-pass lookup: module_name -> func_name -> [helpers], and module_name -> [helpers] + helpers_by_file_and_func = {} + helpers_by_file = {} for helper in code_context.helper_functions: jedi_type = helper.jedi_definition.type - if jedi_type != "class": - func_name = helper.only_function_name - module_name = helper.file_path.stem - # Cache function lookup for this (module, func) - file_entry = helpers_by_file_and_func[module_name] - if func_name in file_entry: - file_entry[func_name].append(helper) - else: - file_entry[func_name] = [helper] - helpers_by_file[module_name].append(helper) - - # Optimize attribute lookups and method binding outside the loop + if jedi_type == "class": + continue + func_name = helper.only_function_name + module_name = helper.file_path.stem + file_entry = helpers_by_file_and_func.setdefault(module_name, {}) + file_entry.setdefault(func_name, []).append(helper) + helpers_by_file.setdefault(module_name, []).append(helper) + + # Optimize lookups: create shortcut functions helpers_by_file_and_func_get = helpers_by_file_and_func.get helpers_by_file_get = helpers_by_file.get - for node in ast.walk(optimized_ast): + # Only walk once for imports, use a generator for both Import and ImportFrom + nodes = [n for n in ast.iter_child_nodes(optimized_ast) if isinstance(n, (ast.Import, ast.ImportFrom))] + for node in nodes: if isinstance(node, ast.ImportFrom): - # Handle "from module import function" statements module_name = node.module if module_name: - file_entry = helpers_by_file_and_func_get(module_name, None) + file_entry = helpers_by_file_and_func_get(module_name) if file_entry: for alias in node.names: - imported_name = alias.asname if alias.asname else alias.name + imported_name = alias.asname or alias.name original_name = alias.name - helpers = file_entry.get(original_name, None) + helpers = file_entry.get(original_name) if helpers: + # Only add each possible helper name once + imported_set = imported_names_map[imported_name] for helper in helpers: - imported_names_map[imported_name].add(helper.qualified_name) - imported_names_map[imported_name].add(helper.fully_qualified_name) - + imported_set.add(helper.qualified_name) + imported_set.add(helper.fully_qualified_name) elif isinstance(node, ast.Import): - # Handle "import module" statements for alias in node.names: - imported_name = alias.asname if alias.asname else alias.name + imported_name = alias.asname or alias.name module_name = alias.name - for helper in helpers_by_file_get(module_name, []): - # For "import module" statements, functions would be called as module.function - full_call = f"{imported_name}.{helper.only_function_name}" - imported_names_map[full_call].add(helper.qualified_name) - imported_names_map[full_call].add(helper.fully_qualified_name) + helpers_list = helpers_by_file_get(module_name) + if helpers_list: + for helper in helpers_list: + # "import module": functions called as module.function + full_call = f"{imported_name}.{helper.only_function_name}" + callset = imported_names_map[full_call] + callset.add(helper.qualified_name) + callset.add(helper.fully_qualified_name) return dict(imported_names_map) @@ -616,88 +616,88 @@ def detect_unused_helper_functions( """ try: - # Parse the optimized code to analyze function calls and imports optimized_ast = ast.parse(optimized_code) - # Find the optimized entrypoint function - entrypoint_function_ast = None - for node in ast.walk(optimized_ast): - if isinstance(node, ast.FunctionDef) and node.name == function_to_optimize.function_name: - entrypoint_function_ast = node - break + # Find the optimized entrypoint function early (using generator for early break) + entrypoint_function_name = function_to_optimize.function_name + entrypoint_function_ast = next( + ( + node + for node in ast.walk(optimized_ast) + if isinstance(node, ast.FunctionDef) and node.name == entrypoint_function_name + ), + None, + ) if not entrypoint_function_ast: - logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code") + logger.debug(f"Could not find entrypoint function {entrypoint_function_name} in optimized code") return [] - # First, analyze imports to build a mapping of imported names to their original qualified names imported_names_map = _analyze_imports_in_optimized_code(optimized_ast, code_context) - # Extract all function calls in the entrypoint function + # Extract all called function names in entrypoint AST, collecting variants in one pass called_function_names = set() + parents = getattr(function_to_optimize, "parents", None) + class_name = parents[0].name if parents else None + for node in ast.walk(entrypoint_function_ast): if isinstance(node, ast.Call): - if isinstance(node.func, ast.Name): - # Regular function call: function_name() - called_name = node.func.id + func = node.func + if isinstance(func, ast.Name): + called_name = func.id called_function_names.add(called_name) - # Also add the qualified name if this is an imported function if called_name in imported_names_map: called_function_names.update(imported_names_map[called_name]) - elif isinstance(node.func, ast.Attribute): - # Method call: obj.method() or self.method() or module.function() - if isinstance(node.func.value, ast.Name): - if node.func.value.id == "self": - # self.method_name() -> add both method_name and ClassName.method_name - called_function_names.add(node.func.attr) - # For class methods, also add the qualified name - if hasattr(function_to_optimize, "parents") and function_to_optimize.parents: - class_name = function_to_optimize.parents[0].name - called_function_names.add(f"{class_name}.{node.func.attr}") + elif isinstance(func, ast.Attribute): + val = func.value + attr_name = func.attr + # Method call: self.method() or module.function() or obj.method() + if isinstance(val, ast.Name): + val_id = val.id + if val_id == "self": + called_function_names.add(attr_name) + if class_name: + called_function_names.add(f"{class_name}.{attr_name}") else: - # obj.method() or module.function() - attr_name = node.func.attr called_function_names.add(attr_name) - called_function_names.add(f"{node.func.value.id}.{attr_name}") - # Check if this is a module.function call that maps to a helper - full_call = f"{node.func.value.id}.{attr_name}" + full_call = f"{val_id}.{attr_name}" + called_function_names.add(full_call) if full_call in imported_names_map: called_function_names.update(imported_names_map[full_call]) - # Handle nested attribute access like obj.attr.method() else: - called_function_names.add(node.func.attr) + # obj.attr.method() + called_function_names.add(attr_name) logger.debug(f"Functions called in optimized entrypoint: {called_function_names}") logger.debug(f"Imported names mapping: {imported_names_map}") - # Find helper functions that are no longer called + # Precompute entrypoint's file_path for fast comparison + entrypoint_file_path = function_to_optimize.file_path + + # Prefetch attributes to reduce lookup cost inside loop unused_helpers = [] for helper_function in code_context.helper_functions: - if helper_function.jedi_definition.type != "class": - # Check if the helper function is called using multiple name variants - helper_qualified_name = helper_function.qualified_name - helper_simple_name = helper_function.only_function_name - helper_fully_qualified_name = helper_function.fully_qualified_name - - # Create a set of all possible names this helper might be called by - possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name} - - # For cross-file helpers, also consider module-based calls - if helper_function.file_path != function_to_optimize.file_path: - # Add potential module.function combinations - module_name = helper_function.file_path.stem - possible_call_names.add(f"{module_name}.{helper_simple_name}") - - # Check if any of the possible names are in the called functions - is_called = bool(possible_call_names.intersection(called_function_names)) - - if not is_called: - unused_helpers.append(helper_function) - logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code") - logger.debug(f" Checked names: {possible_call_names}") - else: - logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code") - logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}") + jedi_type = helper_function.jedi_definition.type + if jedi_type == "class": + continue + + helper_qualified_name = helper_function.qualified_name + helper_simple_name = helper_function.only_function_name + helper_fully_qualified_name = helper_function.fully_qualified_name + possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name} + + # For cross-file helpers, add module.function variant + if helper_function.file_path != entrypoint_file_path: + module_name = helper_function.file_path.stem + possible_call_names.add(f"{module_name}.{helper_simple_name}") + + if not possible_call_names & called_function_names: + unused_helpers.append(helper_function) + logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code") + logger.debug(f" Checked names: {possible_call_names}") + else: + logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code") + logger.debug(f" Called via: {possible_call_names & called_function_names}") return unused_helpers