diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 00e9ce436..72abe235a 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -8,6 +8,7 @@ from _ast import AsyncFunctionDef, ClassDef, FunctionDef from collections import defaultdict from functools import cache +from itertools import islice from pathlib import Path from typing import TYPE_CHECKING, Any, Optional @@ -15,7 +16,7 @@ import libcst as cst from pydantic.dataclasses import dataclass -from codeflash.api.cfapi import get_blocklisted_functions, make_cfapi_request, is_function_being_optimized_again +from codeflash.api.cfapi import get_blocklisted_functions, is_function_being_optimized_again from codeflash.cli_cmds.console import DEBUG_MODE, console, logger from codeflash.code_utils.code_utils import ( is_class_defined_in_file, @@ -153,38 +154,37 @@ def get_code_context_hash(self) -> str: to uniquely identify the function for optimization tracking. """ try: - with open(self.file_path, 'r', encoding='utf-8') as f: - file_content = f.read() - - # Extract the function's code content - lines = file_content.splitlines() + # Read only the necessary lines if possible, otherwise fallback to full file. if self.starting_line is not None and self.ending_line is not None: - # Use line numbers if available (1-indexed to 0-indexed) - function_content = '\n'.join(lines[self.starting_line - 1:self.ending_line]) + # Efficiently read only relevant function lines + start = self.starting_line - 1 # convert to 0-indexed + end = self.ending_line # exclusive + with open(self.file_path, encoding="utf-8") as f: + function_lines = list(islice(f, start, end)) + function_content = "".join(function_lines).strip() else: # Fallback: use the entire file content if line numbers aren't available - function_content = file_content + with open(self.file_path, encoding="utf-8") as f: + function_content = f.read().strip() - # Create a context string that includes: - # - File path (relative to make it portable) - # - Qualified function name - # - Function code content + # Create a context string that includes filename (for portability), + # qualified function name, and function code content. context_parts = [ str(self.file_path.name), # Just filename for portability self.qualified_name, - function_content.strip() + function_content, ] - - context_string = '\n---\n'.join(context_parts) + context_string = "\n---\n".join(context_parts) # Generate SHA-256 hash - return hashlib.sha256(context_string.encode('utf-8')).hexdigest() + return hashlib.sha256(context_string.encode("utf-8")).hexdigest() - except (OSError, IOError) as e: + except OSError as e: logger.warning(f"Could not read file {self.file_path} for hashing: {e}") # Fallback hash using available metadata fallback_string = f"{self.file_path.name}:{self.qualified_name}" - return hashlib.sha256(fallback_string.encode('utf-8')).hexdigest() + return hashlib.sha256(fallback_string.encode("utf-8")).hexdigest() + def get_functions_to_optimize( optimize_all: str | None, @@ -228,7 +228,7 @@ def get_functions_to_optimize( found_function = None for fn in functions.get(file, []): if only_function_name == fn.function_name and ( - class_name is None or class_name == fn.top_level_parent_name + class_name is None or class_name == fn.top_level_parent_name ): found_function = fn if found_function is None: @@ -307,7 +307,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt def get_all_replay_test_functions( - replay_test: Path, test_cfg: TestConfig, project_root_path: Path + replay_test: Path, test_cfg: TestConfig, project_root_path: Path ) -> dict[Path, list[FunctionToOptimize]]: function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test]) # Get the absolute file paths for each function, excluding class name if present @@ -322,7 +322,7 @@ def get_all_replay_test_functions( class_name = ( module_path_parts[-1] if module_path_parts - and is_class_defined_in_file( + and is_class_defined_in_file( module_path_parts[-1], Path(project_root_path, *module_path_parts[:-1]).with_suffix(".py") ) else None @@ -374,8 +374,7 @@ def ignored_submodule_paths(module_root: str) -> list[str]: class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor): def __init__( - self, file_name: Path, function_or_method_name: str, class_name: str | None = None, - line_no: int | None = None + self, file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None ) -> None: self.file_name = file_name self.class_name = class_name @@ -406,13 +405,13 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name: self.is_top_level = True if any( - isinstance(decorator, ast.Name) and decorator.id == "classmethod" - for decorator in body_node.decorator_list + isinstance(decorator, ast.Name) and decorator.id == "classmethod" + for decorator in body_node.decorator_list ): self.is_classmethod = True elif any( - isinstance(decorator, ast.Name) and decorator.id == "staticmethod" - for decorator in body_node.decorator_list + isinstance(decorator, ast.Name) and decorator.id == "staticmethod" + for decorator in body_node.decorator_list ): self.is_staticmethod = True return @@ -421,13 +420,13 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: # This way, if we don't have the class name, we can still find the static method for body_node in node.body: if ( - isinstance(body_node, ast.FunctionDef) - and body_node.name == self.function_name - and body_node.lineno in {self.line_no, self.line_no + 1} - and any( - isinstance(decorator, ast.Name) and decorator.id == "staticmethod" - for decorator in body_node.decorator_list - ) + isinstance(body_node, ast.FunctionDef) + and body_node.name == self.function_name + and body_node.lineno in {self.line_no, self.line_no + 1} + and any( + isinstance(decorator, ast.Name) and decorator.id == "staticmethod" + for decorator in body_node.decorator_list + ) ): self.is_staticmethod = True self.is_top_level = True @@ -460,10 +459,7 @@ def inspect_top_level_functions_or_methods( def check_optimization_status( - functions_by_file: dict[Path, list[FunctionToOptimize]], - owner: str, - repo: str, - pr_number: int + functions_by_file: dict[Path, list[FunctionToOptimize]], owner: str, repo: str, pr_number: int ) -> tuple[dict[Path, list[FunctionToOptimize]], int]: """Check which functions have already been optimized and filter them out. @@ -480,6 +476,7 @@ def check_optimization_status( Returns: Tuple of (filtered_functions_dict, remaining_count) + """ # Build the code_contexts dictionary for the API call code_contexts = {} @@ -500,7 +497,6 @@ def check_optimization_status( result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts) already_optimized_paths = set(result.get("already_optimized_paths", [])) - # Filter out already optimized functions filtered_functions = defaultdict(list) remaining_count = 0 @@ -556,12 +552,12 @@ def filter_functions( test_functions_removed_count += len(_functions) continue if file_path in ignore_paths or any( - file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths + file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths ): ignore_paths_removed_count += 1 continue if file_path in submodule_paths or any( - file_path.startswith(str(submodule_path) + os.sep) for submodule_path in submodule_paths + file_path.startswith(str(submodule_path) + os.sep) for submodule_path in submodule_paths ): submodule_ignored_paths_count += 1 continue @@ -579,12 +575,14 @@ def filter_functions( if blocklist_funcs: functions_tmp = [] for function in _functions: - if not ( + if ( function.file_path.name in blocklist_funcs and function.qualified_name in blocklist_funcs[function.file_path.name] ): + # This function is in blocklist, we can skip it blocklist_funcs_removed_count += 1 continue + # This function is NOT in blocklist. we can keep it functions_tmp.append(function) _functions = functions_tmp @@ -609,9 +607,7 @@ def filter_functions( owner, repo = get_repo_owner_and_name(repository) pr_number = get_pr_number() if owner and repo and pr_number is not None: - path_based_functions, functions_count = check_optimization_status( - path_based_functions, owner, repo, pr_number - ) + path_based_functions, functions_count = check_optimization_status(path_based_functions, owner, repo, pr_number) initial_count = sum(len(funcs) for funcs in filtered_modified_functions.values()) already_optimized_count = initial_count - functions_count @@ -652,8 +648,8 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list if submodule_paths is None: submodule_paths = ignored_submodule_paths(module_root) return not ( - file_path in submodule_paths - or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths) + file_path in submodule_paths + or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths) ) @@ -662,4 +658,4 @@ def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) def function_is_a_property(function_node: FunctionDef | AsyncFunctionDef) -> bool: - return any(isinstance(node, ast.Name) and node.id == "property" for node in function_node.decorator_list) \ No newline at end of file + return any(isinstance(node, ast.Name) and node.id == "property" for node in function_node.decorator_list)