Skip to content

Commit 236f14c

Browse files
committed
claude WIP
1 parent 4de9323 commit 236f14c

File tree

3 files changed

+632
-9
lines changed

3 files changed

+632
-9
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 182 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# ruff: noqa: SLF001
22
from __future__ import annotations
33

4+
import ast
45
import hashlib
56
import os
67
import pickle
@@ -12,6 +13,9 @@
1213
from pathlib import Path
1314
from typing import TYPE_CHECKING, Callable, Optional
1415

16+
if TYPE_CHECKING:
17+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
18+
1519
import pytest
1620
from pydantic.dataclasses import dataclass
1721
from rich.panel import Panel
@@ -137,20 +141,175 @@ def close(self) -> None:
137141
self.connection.close()
138142

139143

144+
class ImportAnalyzer(ast.NodeVisitor):
145+
"""AST-based analyzer to find all imports in a test file."""
146+
147+
def __init__(self, function_names_to_find: set[str]) -> None:
148+
self.function_names_to_find = function_names_to_find
149+
self.imported_names: set[str] = set()
150+
self.imported_modules: set[str] = set()
151+
self.found_target_functions: set[str] = set()
152+
153+
def visit_Import(self, node: ast.Import) -> None:
154+
"""Handle 'import module' statements."""
155+
for alias in node.names:
156+
module_name = alias.asname if alias.asname else alias.name
157+
self.imported_modules.add(module_name)
158+
self.imported_names.add(module_name)
159+
self.generic_visit(node)
160+
161+
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
162+
"""Handle 'from module import name' statements."""
163+
if node.module:
164+
self.imported_modules.add(node.module)
165+
166+
for alias in node.names:
167+
if alias.name == "*":
168+
# Star imports - we can't know what's imported, so be conservative
169+
self.imported_names.add("*")
170+
else:
171+
imported_name = alias.asname if alias.asname else alias.name
172+
self.imported_names.add(imported_name)
173+
174+
# Check if this import matches any target function
175+
if alias.name in self.function_names_to_find:
176+
self.found_target_functions.add(alias.name)
177+
self.generic_visit(node)
178+
179+
def visit_Call(self, node: ast.Call) -> None:
180+
"""Handle dynamic imports like importlib.import_module() or __import__()."""
181+
if isinstance(node.func, ast.Name) and node.func.id == "__import__" and node.args:
182+
# __import__("module_name")
183+
if isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str):
184+
self.imported_modules.add(node.args[0].value)
185+
elif (isinstance(node.func, ast.Attribute)
186+
and isinstance(node.func.value, ast.Name)
187+
and node.func.value.id == "importlib"
188+
and node.func.attr == "import_module"
189+
and node.args):
190+
# importlib.import_module("module_name")
191+
if isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str):
192+
self.imported_modules.add(node.args[0].value)
193+
self.generic_visit(node)
194+
195+
def visit_Name(self, node: ast.Name) -> None:
196+
"""Check if any name usage matches our target functions."""
197+
if node.id in self.function_names_to_find:
198+
self.found_target_functions.add(node.id)
199+
self.generic_visit(node)
200+
201+
def visit_Attribute(self, node: ast.Attribute) -> None:
202+
"""Handle module.function_name patterns."""
203+
if node.attr in self.function_names_to_find:
204+
self.found_target_functions.add(node.attr)
205+
self.generic_visit(node)
206+
207+
208+
def analyze_imports_in_test_file(test_file_path: Path, target_functions: set[str]) -> tuple[bool, set[str]]:
209+
"""Analyze imports in a test file to determine if it might test any target functions.
210+
211+
Args:
212+
test_file_path: Path to the test file
213+
target_functions: Set of function names we're looking for
214+
215+
Returns:
216+
Tuple of (should_process_with_jedi, found_function_names)
217+
218+
"""
219+
try:
220+
with test_file_path.open("r", encoding="utf-8") as f:
221+
content = f.read()
222+
223+
tree = ast.parse(content, filename=str(test_file_path))
224+
analyzer = ImportAnalyzer(target_functions)
225+
analyzer.visit(tree)
226+
227+
# If we found direct function matches, definitely process
228+
if analyzer.found_target_functions:
229+
return True, analyzer.found_target_functions
230+
231+
# If there are star imports, we need to be conservative
232+
if "*" in analyzer.imported_names:
233+
return True, set()
234+
235+
# Check for direct name matches first (higher priority)
236+
name_matches = analyzer.imported_names & target_functions
237+
if name_matches:
238+
return True, name_matches
239+
240+
# If no direct matches, check if any imported modules could contain our target functions
241+
# This is a heuristic - we look for common patterns
242+
potential_matches = set()
243+
for module in analyzer.imported_modules:
244+
# Check if module name suggests it could contain target functions
245+
for func_name in target_functions:
246+
# Only match if the module name is a prefix of the function qualified name
247+
func_parts = func_name.split(".")
248+
if len(func_parts) > 1 and module == func_parts[0]:
249+
# Module matches the first part of qualified name (e.g., mycode in mycode.target_function)
250+
# But only if we don't have specific import information suggesting otherwise
251+
potential_matches.add(func_name)
252+
elif any(part in module for part in func_name.split("_")) and len(func_name.split("_")) > 1:
253+
# Function name parts match module name (for underscore-separated names)
254+
potential_matches.add(func_name)
255+
256+
# Only use heuristic matches if we haven't found specific function imports that contradict them
257+
return bool(potential_matches), potential_matches
258+
259+
except (SyntaxError, UnicodeDecodeError, OSError) as e:
260+
logger.debug(f"Failed to analyze imports in {test_file_path}: {e}")
261+
# If we can't parse the file, be conservative and process it
262+
return True, set()
263+
264+
265+
def filter_test_files_by_imports(
266+
file_to_test_map: dict[Path, list[TestsInFile]],
267+
target_functions: set[str]
268+
) -> tuple[dict[Path, list[TestsInFile]], dict[Path, set[str]]]:
269+
"""Filter test files based on import analysis to reduce Jedi processing.
270+
271+
Args:
272+
file_to_test_map: Original mapping of test files to test functions
273+
target_functions: Set of function names we're optimizing
274+
275+
Returns:
276+
Tuple of (filtered_file_map, import_analysis_results)
277+
278+
"""
279+
if not target_functions:
280+
# If no target functions specified, process all files
281+
return file_to_test_map, {}
282+
283+
filtered_map = {}
284+
import_results = {}
285+
286+
for test_file, test_functions in file_to_test_map.items():
287+
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
288+
import_results[test_file] = found_functions
289+
290+
if should_process:
291+
filtered_map[test_file] = test_functions
292+
else:
293+
logger.debug(f"Skipping {test_file} - no relevant imports found")
294+
295+
logger.info(f"Import filter: Processing {len(filtered_map)}/{len(file_to_test_map)} test files")
296+
return filtered_map, import_results
297+
298+
140299
def discover_unit_tests(
141-
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
300+
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None, functions_to_optimize: list[FunctionToOptimize] | None = None
142301
) -> dict[str, list[FunctionCalledInTest]]:
143302
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
144303
strategy = framework_strategies.get(cfg.test_framework, None)
145304
if not strategy:
146305
error_message = f"Unsupported test framework: {cfg.test_framework}"
147306
raise ValueError(error_message)
148307

149-
return strategy(cfg, discover_only_these_tests)
308+
return strategy(cfg, discover_only_these_tests, functions_to_optimize)
150309

151310

152311
def discover_tests_pytest(
153-
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
312+
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None, functions_to_optimize: list[FunctionToOptimize] | None = None
154313
) -> dict[Path, list[FunctionCalledInTest]]:
155314
tests_root = cfg.tests_root
156315
project_root = cfg.project_root_path
@@ -220,11 +379,11 @@ def discover_tests_pytest(
220379
continue
221380
file_to_test_map[test_obj.test_file].append(test_obj)
222381
# Within these test files, find the project functions they are referring to and return their names/locations
223-
return process_test_files(file_to_test_map, cfg)
382+
return process_test_files(file_to_test_map, cfg, functions_to_optimize)
224383

225384

226385
def discover_tests_unittest(
227-
cfg: TestConfig, discover_only_these_tests: list[str] | None = None
386+
cfg: TestConfig, discover_only_these_tests: list[str] | None = None, functions_to_optimize: list[FunctionToOptimize] | None = None
228387
) -> dict[Path, list[FunctionCalledInTest]]:
229388
tests_root: Path = cfg.tests_root
230389
loader: unittest.TestLoader = unittest.TestLoader()
@@ -277,7 +436,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
277436
details = get_test_details(test)
278437
if details is not None:
279438
file_to_test_map[str(details.test_file)].append(details)
280-
return process_test_files(file_to_test_map, cfg)
439+
return process_test_files(file_to_test_map, cfg, functions_to_optimize)
281440

282441

283442
def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]:
@@ -289,13 +448,29 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
289448

290449

291450
def process_test_files(
292-
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
451+
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig, functions_to_optimize: list[FunctionToOptimize] | None = None
293452
) -> dict[str, list[FunctionCalledInTest]]:
294453
import jedi
295454

296455
project_root_path = cfg.project_root_path
297456
test_framework = cfg.test_framework
298457

458+
# Apply import filter if functions to optimize are provided
459+
if functions_to_optimize:
460+
# Extract target function names from FunctionToOptimize objects
461+
# Include both qualified names and simple function names for better matching
462+
target_function_names = set()
463+
for func in functions_to_optimize:
464+
target_function_names.add(func.qualified_name_with_modules_from_root(project_root_path))
465+
target_function_names.add(func.function_name) # Add simple name too
466+
# Also add qualified name without module
467+
if func.parents:
468+
target_function_names.add(f"{func.parents[0].name}.{func.function_name}")
469+
470+
logger.debug(f"Target functions for import filtering: {target_function_names}")
471+
file_to_test_map, import_results = filter_test_files_by_imports(file_to_test_map, target_function_names)
472+
logger.debug(f"Import analysis results: {len(import_results)} files analyzed")
473+
299474
function_to_test_map = defaultdict(set)
300475
jedi_project = jedi.Project(path=project_root_path)
301476
goto_cache = {}

codeflash/optimization/optimizer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,13 @@ def run(self) -> None:
162162

163163
console.rule()
164164
start_time = time.time()
165-
function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
165+
# Extract all functions to optimize for import filtering
166+
all_functions_to_optimize = [
167+
func for funcs_list in file_to_funcs_to_optimize.values() for func in funcs_list
168+
]
169+
function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(
170+
self.test_cfg, functions_to_optimize=all_functions_to_optimize
171+
)
166172
num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
167173
console.rule()
168174
logger.info(

0 commit comments

Comments
 (0)