Skip to content
Draft
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,11 @@ def module_name_from_file_path(file_path: Path, project_root_path: Path, *, trav
raise ValueError(msg) # noqa: B904


def get_qualified_function_path(file_path: Path, project_root_path: Path, qualified_name: str) -> str:
module_path = file_path.relative_to(project_root_path).with_suffix("").as_posix().replace("/", ".")
return f"{module_path}.{qualified_name}"


def file_path_from_module_name(module_name: str, project_root_path: Path) -> Path:
"""Get file path from module path."""
return project_root_path / (module_name.replace(".", os.sep) + ".py")
Expand Down
153 changes: 152 additions & 1 deletion codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,75 @@ class TestFunction:
FUNCTION_NAME_REGEX = re.compile(r"([^.]+)\.([a-zA-Z0-9_]+)$")


def _extract_dotted_call_name(node: ast.expr) -> str | None:
"""Extract full dotted name from function call (e.g., 'src.math.computation.gcd_recursive')."""
parts = []
current = node
while isinstance(current, ast.Attribute):
parts.insert(0, current.attr)
current = current.value
if isinstance(current, ast.Name):
parts.insert(0, current.id)
return ".".join(parts) if parts else None
return None


def _discover_calls_via_ast(
test_file: Path, test_functions: set[TestFunction], target_qualified_names: set[str]
) -> dict[str, list[tuple[TestFunction, CodePosition]]]:
try:
with test_file.open("r", encoding="utf-8") as f:
source = f.read()
tree = ast.parse(source, filename=str(test_file))
except (SyntaxError, FileNotFoundError) as e:
logger.debug(f"AST parsing failed for {test_file}: {e}")
return {}

import_map = {} # alias -> full_qualified_path
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
name = alias.asname or alias.name
import_map[name] = alias.name
elif isinstance(node, ast.ImportFrom) and node.module:
for alias in node.names:
if alias.name != "*":
full_name = f"{node.module}.{alias.name}"
name = alias.asname or alias.name
import_map[name] = full_name

test_funcs_by_name = {tf.function_name: tf for tf in test_functions}

result = defaultdict(list)

for node in ast.walk(tree):
if not isinstance(node, ast.FunctionDef) or node.name not in test_funcs_by_name:
continue

test_func = test_funcs_by_name[node.name]

for child in ast.walk(node):
if not isinstance(child, ast.Call):
continue

call_name = _extract_dotted_call_name(child.func)
if not call_name:
continue

if call_name in target_qualified_names:
result[call_name].append((test_func, CodePosition(line_no=child.lineno, col_no=child.col_offset)))
continue

parts = call_name.split(".", 1)
if parts[0] in import_map:
resolved = f"{import_map[parts[0]]}.{parts[1]}" if len(parts) == 2 else import_map[parts[0]]

if resolved in target_qualified_names:
result[resolved].append((test_func, CodePosition(line_no=child.lineno, col_no=child.col_offset)))

return dict(result)


class TestsCache:
SCHEMA_VERSION = 1 # Increment this when schema changes

Expand Down Expand Up @@ -489,6 +558,7 @@ def discover_tests_pytest(
console.rule()
else:
logger.debug(f"Pytest collection exit code: {exitcode}")

if pytest_rootdir is not None:
cfg.tests_project_rootdir = Path(pytest_rootdir)
file_to_test_map: dict[Path, list[FunctionCalledInTest]] = defaultdict(list)
Expand All @@ -497,6 +567,8 @@ def discover_tests_pytest(
test_type = TestType.REPLAY_TEST
elif "test_concolic_coverage" in test["test_file"]:
test_type = TestType.CONCOLIC_COVERAGE_TEST
elif "test_hypothesis" in test["test_file"]:
test_type = TestType.HYPOTHESIS_TEST
else:
test_type = TestType.EXISTING_UNIT_TEST

Expand All @@ -509,6 +581,7 @@ def discover_tests_pytest(
if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests:
continue
file_to_test_map[test_obj.test_file].append(test_obj)

# Within these test files, find the project functions they are referring to and return their names/locations
return process_test_files(file_to_test_map, cfg, functions_to_optimize)

Expand Down Expand Up @@ -540,6 +613,8 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
test_type = TestType.REPLAY_TEST
elif "test_concolic_coverage" in str(_test_module_path):
test_type = TestType.CONCOLIC_COVERAGE_TEST
elif "test_hypothesis" in str(_test_module_path):
test_type = TestType.HYPOTHESIS_TEST
else:
test_type = TestType.EXISTING_UNIT_TEST
return TestsInFile(
Expand Down Expand Up @@ -588,7 +663,9 @@ def process_test_files(
test_framework = cfg.test_framework

if functions_to_optimize:
target_function_names = {func.qualified_name for func in functions_to_optimize}
target_function_names = {
func.qualified_name_with_modules_from_root(project_root_path) for func in functions_to_optimize
}
file_to_test_map = filter_test_files_by_imports(file_to_test_map, target_function_names)

function_to_test_map = defaultdict(set)
Expand All @@ -598,6 +675,7 @@ def process_test_files(

tests_cache = TestsCache(project_root_path)
logger.info("!lsp|Discovering tests and processing unit tests")

with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
progress,
task_id,
Expand Down Expand Up @@ -698,6 +776,79 @@ def process_test_files(
test_functions_by_name[func.function_name].append(func)

test_function_names_set = set(test_functions_by_name.keys())

is_generated_test_file = (
any(
tf.test_type in (TestType.HYPOTHESIS_TEST, TestType.CONCOLIC_COVERAGE_TEST) for tf in test_functions
)
if test_functions
else any(
func.test_type in (TestType.HYPOTHESIS_TEST, TestType.CONCOLIC_COVERAGE_TEST) for func in functions
)
)

# For generated tests, use AST-based discovery since Jedi often fails
if is_generated_test_file and functions_to_optimize:
logger.debug(f"Using AST-based discovery for generated test file: {test_file.name}")
target_qualified_names = {
func.qualified_name_with_modules_from_root(project_root_path) for func in functions_to_optimize
}

if not test_functions:
logger.debug("Jedi found no functions, building test_functions from collected functions")
test_functions = {
TestFunction(
function_name=func.test_function,
test_class=func.test_class,
parameters=None,
test_type=func.test_type,
)
for func in functions
}

ast_results = _discover_calls_via_ast(test_file, test_functions, target_qualified_names)

for qualified_name, matches in ast_results.items():
for test_func, position in matches:
if test_func.parameters is not None:
if test_framework == "pytest":
scope_test_function = f"{test_func.function_name}[{test_func.parameters}]"
else: # unittest
scope_test_function = f"{test_func.function_name}_{test_func.parameters}"
else:
scope_test_function = test_func.function_name

function_to_test_map[qualified_name].add(
FunctionCalledInTest(
tests_in_file=TestsInFile(
test_file=test_file,
test_class=test_func.test_class,
test_function=scope_test_function,
test_type=test_func.test_type,
),
position=position,
)
)
tests_cache.insert_test(
file_path=str(test_file),
file_hash=file_hash,
qualified_name_with_modules_from_root=qualified_name,
function_name=test_func.function_name,
test_class=test_func.test_class or "",
test_function=scope_test_function,
test_type=test_func.test_type,
line_number=position.line_no,
col_number=position.col_no,
)

if test_func.test_type == TestType.REPLAY_TEST:
num_discovered_replay_tests += 1

num_discovered_tests += 1

progress.advance(task_id)
continue

relevant_names = []

names_with_full_name = [name for name in all_names if name.full_name is not None]
Expand Down
2 changes: 2 additions & 0 deletions codeflash/models/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class TestType(Enum):
REPLAY_TEST = 4
CONCOLIC_COVERAGE_TEST = 5
INIT_STATE_TEST = 6
HYPOTHESIS_TEST = 7

def to_name(self) -> str:
if self is TestType.INIT_STATE_TEST:
Expand All @@ -18,5 +19,6 @@ def to_name(self) -> str:
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
TestType.REPLAY_TEST: "⏪ Replay Tests",
TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
TestType.HYPOTHESIS_TEST: "🔮 Hypothesis Tests",
}
return names[self]
Loading
Loading