Skip to content

Commit ecbceec

Browse files
author
Codeflash Bot
committed
tests modified now
1 parent ffff5f1 commit ecbceec

File tree

3 files changed

+174
-53
lines changed

3 files changed

+174
-53
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -738,11 +738,14 @@ def inject_profiling_into_existing_test(
738738
ast.Import(names=[ast.alias(name="time")]),
739739
ast.Import(names=[ast.alias(name="gc")]),
740740
ast.Import(names=[ast.alias(name="os")]),
741-
ast.Import(names=[ast.alias(name="inspect")]),
742741
]
743742
if mode == TestingMode.BEHAVIOR:
744743
new_imports.extend(
745-
[ast.Import(names=[ast.alias(name="sqlite3")]), ast.Import(names=[ast.alias(name="dill", asname="pickle")])]
744+
[
745+
ast.Import(names=[ast.alias(name="inspect")]),
746+
ast.Import(names=[ast.alias(name="sqlite3")]),
747+
ast.Import(names=[ast.alias(name="dill", asname="pickle")]),
748+
]
746749
)
747750
if test_framework == "unittest" and platform.system() != "Windows":
748751
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))

codeflash/discovery/discover_unit_tests.py

Lines changed: 74 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -212,15 +212,25 @@ def __init__(self, function_names_to_find: set[str]) -> None:
212212
self.wildcard_modules: set[str] = set()
213213
# Track aliases: alias_name -> original_name
214214
self.alias_mapping: dict[str, str] = {}
215+
# Track instances: variable_name -> class_name
216+
self.instance_mapping: dict[str, str] = {}
215217

216218
# Precompute function_names for prefix search
217219
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
218220
self._exact_names = function_names_to_find
219221
self._prefix_roots: dict[str, list[str]] = {}
222+
# Precompute sets for faster lookup during visit_Attribute()
223+
self._dot_names: set[str] = set()
224+
self._dot_methods: dict[str, set[str]] = {}
225+
self._class_method_to_target: dict[tuple[str, str], str] = {}
220226
for name in function_names_to_find:
221227
if "." in name:
222-
root = name.split(".", 1)[0]
223-
self._prefix_roots.setdefault(root, []).append(name)
228+
root, method = name.rsplit(".", 1)
229+
self._dot_names.add(name)
230+
self._dot_methods.setdefault(method, set()).add(root)
231+
self._class_method_to_target[(root, method)] = name
232+
root_prefix = name.split(".", 1)[0]
233+
self._prefix_roots.setdefault(root_prefix, []).append(name)
224234

225235
def visit_Import(self, node: ast.Import) -> None:
226236
"""Handle 'import module' statements."""
@@ -247,6 +257,24 @@ def visit_Import(self, node: ast.Import) -> None:
247257
self.found_qualified_name = target_func
248258
return
249259

260+
def visit_Assign(self, node: ast.Assign) -> None:
261+
"""Track variable assignments, especially class instantiations."""
262+
if self.found_any_target_function:
263+
return
264+
265+
# Check if the assignment is a class instantiation
266+
if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name):
267+
class_name = node.value.func.id
268+
if class_name in self.imported_modules:
269+
# Track all target variables as instances of the imported class
270+
for target in node.targets:
271+
if isinstance(target, ast.Name):
272+
# Map the variable to the actual class name (handling aliases)
273+
original_class = self.alias_mapping.get(class_name, class_name)
274+
self.instance_mapping[target.id] = original_class
275+
276+
self.generic_visit(node)
277+
250278
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
251279
"""Handle 'from module import name' statements."""
252280
if self.found_any_target_function:
@@ -287,6 +315,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
287315
self.found_qualified_name = qname
288316
return
289317

318+
# Check if any target function is a method of the imported class/module
319+
# Be conservative except when an alias is used (which requires exact method matching)
320+
for target_func in fnames:
321+
if "." in target_func:
322+
class_name, method_name = target_func.split(".", 1)
323+
if aname == class_name and not alias.asname:
324+
# If an alias is used, don't match conservatively
325+
# The actual method usage should be detected in visit_Attribute
326+
self.found_any_target_function = True
327+
self.found_qualified_name = target_func
328+
return
329+
290330
prefix = qname + "."
291331
# Only bother if one of the targets startswith the prefix-root
292332
candidates = proots.get(qname, ())
@@ -301,33 +341,45 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
301341
if self.found_any_target_function:
302342
return
303343

344+
# Check if this is accessing a target function through an imported module
345+
346+
node_value = node.value
347+
node_attr = node.attr
348+
304349
# Check if this is accessing a target function through an imported module
305350
if (
306-
isinstance(node.value, ast.Name)
307-
and node.value.id in self.imported_modules
308-
and node.attr in self.function_names_to_find
351+
isinstance(node_value, ast.Name)
352+
and node_value.id in self.imported_modules
353+
and node_attr in self.function_names_to_find
309354
):
310355
self.found_any_target_function = True
311-
self.found_qualified_name = node.attr
356+
self.found_qualified_name = node_attr
312357
return
313358

314-
if isinstance(node.value, ast.Name) and node.value.id in self.imported_modules:
315-
for target_func in self.function_names_to_find:
316-
if "." in target_func:
317-
class_name, method_name = target_func.rsplit(".", 1)
318-
if node.attr == method_name:
319-
imported_name = node.value.id
320-
original_name = self.alias_mapping.get(imported_name, imported_name)
321-
if original_name == class_name:
322-
self.found_any_target_function = True
323-
self.found_qualified_name = target_func
324-
return
325-
326-
# Check if this is accessing a target function through a dynamically imported module
327-
# Only if we've detected dynamic imports are being used
328-
if self.has_dynamic_imports and node.attr in self.function_names_to_find:
359+
# Check for methods via imported modules using precomputed _dot_methods and _class_method_to_target
360+
if isinstance(node_value, ast.Name) and node_value.id in self.imported_modules:
361+
roots_possible = self._dot_methods.get(node_attr)
362+
if roots_possible:
363+
imported_name = node_value.id
364+
original_name = self.alias_mapping.get(imported_name, imported_name)
365+
if original_name in roots_possible:
366+
self.found_any_target_function = True
367+
self.found_qualified_name = self._class_method_to_target[(original_name, node_attr)]
368+
return
369+
370+
# Check if this is accessing a method on an instance variable
371+
if isinstance(node_value, ast.Name) and node_value.id in self.instance_mapping:
372+
class_name = self.instance_mapping[node_value.id]
373+
roots_possible = self._dot_methods.get(node_attr)
374+
if roots_possible and class_name in roots_possible:
375+
self.found_any_target_function = True
376+
self.found_qualified_name = self._class_method_to_target[(class_name, node_attr)]
377+
return
378+
379+
# Check for dynamic import match
380+
if self.has_dynamic_imports and node_attr in self.function_names_to_find:
329381
self.found_any_target_function = True
330-
self.found_qualified_name = node.attr
382+
self.found_qualified_name = node_attr
331383
return
332384

333385
self.generic_visit(node)

0 commit comments

Comments
 (0)