Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 123 additions & 61 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,25 @@ class TestFunction:
test_type: TestType


# Pre-compiled once at import time so the discovery loops below do not have to
# go through the ``re`` module's pattern lookup on every iteration.

# Captures the body of pytest's "===== ERRORS =====" section: everything after
# the banner line up to the next run of three or more "=" (or end of output).
ERROR_PATTERN = re.compile(r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)")
# Splits a parametrized pytest id on its brackets:
# "test_x[a-b]" -> ["test_x", "a-b", ""].
PYTEST_PARAMETERIZED_TEST_NAME_REGEX = re.compile(r"[\[\]]")
# Matches names shaped like "test_<name>_<digits>[_<description>...]".
UNITTEST_PARAMETERIZED_TEST_NAME_REGEX = re.compile(r"^test_\w+_\d+(?:_\w+)*")
# Removes the trailing "_<digits>[_<description>...]" to recover the base name.
UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX = re.compile(r"_\d+(?:_\w+)*$")
# Splits the tail of a dotted full name into (enclosing scope, final name).
# ``\w`` is used instead of ``[a-zA-Z0-9_]`` because Python identifiers may
# contain non-ASCII letters; the ASCII-only class silently failed to match
# such names, so references made through them were skipped.
FUNCTION_NAME_REGEX = re.compile(r"([^.]+)\.(\w+)$")


def discover_unit_tests(
    cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
) -> dict[str, list[FunctionCalledInTest]]:
    """Run test discovery with the routine registered for ``cfg.test_framework``.

    Args:
        cfg: Test configuration; ``cfg.test_framework`` selects the routine.
        discover_only_these_tests: Optional list of test files, forwarded
            unchanged to the selected routine.

    Returns:
        Whatever the selected discovery routine returns.

    Raises:
        ValueError: If ``cfg.test_framework`` is neither "pytest" nor
            "unittest".
    """
    discoverers: dict[str, Callable] = {
        "pytest": discover_tests_pytest,
        "unittest": discover_tests_unittest,
    }
    discover = discoverers.get(cfg.test_framework)
    if discover:
        return discover(cfg, discover_only_these_tests)

    # No routine is registered for the configured framework.
    error_message = f"Unsupported test framework: {cfg.test_framework}"
    raise ValueError(error_message)


Expand Down Expand Up @@ -73,16 +84,17 @@ def discover_tests_pytest(
if exitcode != 0:
if exitcode == 2 and "ERROR collecting" in result.stdout:
# Pattern matches "===== ERRORS =====" (any number of =) and captures everything after
error_pattern = r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)"
match = re.search(error_pattern, result.stdout)
match = ERROR_PATTERN.search(result.stdout)
error_section = match.group(1) if match else result.stdout

logger.warning(
f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}\n {error_section}"
)

elif 0 <= exitcode <= 5:
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}")
logger.warning(
f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}"
)
else:
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}")
console.rule()
Expand All @@ -105,7 +117,10 @@ def discover_tests_pytest(
test_function=test["test_function"],
test_type=test_type,
)
if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests:
if (
discover_only_these_tests
and test_obj.test_file not in discover_only_these_tests
):
continue
file_to_test_map[test_obj.test_file].append(test_obj)
# Within these test files, find the project functions they are referring to and return their names/locations
Expand All @@ -130,7 +145,8 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
_test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py")
_test_module_path = tests_root / _test_module_path
if not _test_module_path.exists() or (
discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests
discover_only_these_tests
and str(_test_module_path) not in discover_only_these_tests
):
return None
if "__replay_test" in str(_test_module_path):
Expand All @@ -157,7 +173,9 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
if not hasattr(test, "_testMethodName") and hasattr(test, "_tests"):
for test_2 in test._tests:
if not hasattr(test_2, "_testMethodName"):
logger.warning(f"Didn't find tests for {test_2}") # it goes deeper?
logger.warning(
f"Didn't find tests for {test_2}"
) # it goes deeper?
continue
details = get_test_details(test_2)
if details is not None:
Expand All @@ -182,8 +200,9 @@ def process_test_files(
) -> dict[str, list[FunctionCalledInTest]]:
project_root_path = cfg.project_root_path
test_framework = cfg.test_framework
function_to_test_map = defaultdict(list)
function_to_test_map = defaultdict(set)
jedi_project = jedi.Project(path=project_root_path)
goto_cache = {}

for test_file, functions in file_to_test_map.items():
try:
Expand All @@ -194,29 +213,51 @@ def process_test_files(
all_defs = script.get_names(all_scopes=True, definitions=True)
all_names_top = script.get_names(all_scopes=True)

top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
top_level_functions = {
name.name: name for name in all_names_top if name.type == "function"
}
top_level_classes = {
name.name: name for name in all_names_top if name.type == "class"
}
except Exception as e:
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
continue

if test_framework == "pytest":
for function in functions:
if "[" in function.test_function:
function_name = re.split(r"[\[\]]", function.test_function)[0]
parameters = re.split(r"[\[\]]", function.test_function)[1]
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(
function.test_function
)[0]
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(
function.test_function
)[1]
if function_name in top_level_functions:
test_functions.add(
TestFunction(function_name, function.test_class, parameters, function.test_type)
TestFunction(
function_name,
function.test_class,
parameters,
function.test_type,
)
)
elif function.test_function in top_level_functions:
test_functions.add(
TestFunction(function.test_function, function.test_class, None, function.test_type)
TestFunction(
function.test_function,
function.test_class,
None,
function.test_type,
)
)
elif re.match(r"^test_\w+_\d+(?:_\w+)*", function.test_function):
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(
function.test_function
):
# Try to match parameterized unittest functions here, although we can't get the parameters.
# Extract base name by removing the numbered suffix and any additional descriptions
base_name = re.sub(r"_\d+(?:_\w+)*$", "", function.test_function)
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub(
"", function.test_function
)
if base_name in top_level_functions:
test_functions.add(
TestFunction(
Expand All @@ -229,7 +270,7 @@ def process_test_files(

elif test_framework == "unittest":
functions_to_search = [elem.test_function for elem in functions]
test_suites = [elem.test_class for elem in functions]
test_suites = {elem.test_class for elem in functions}

matching_names = test_suites & top_level_classes.keys()
for matched_name in matching_names:
Expand All @@ -240,7 +281,9 @@ def process_test_files(
and f".{matched_name}." in def_name.full_name
):
for function in functions_to_search:
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
(is_parameterized, new_function, parameters) = (
discover_parameters_unittest(function)
)

if is_parameterized and new_function == def_name.name:
test_functions.add(
Expand All @@ -264,53 +307,72 @@ def process_test_files(
test_functions_list = list(test_functions)
test_functions_raw = [elem.function_name for elem in test_functions_list]

test_functions_by_name = defaultdict(list)
for i, func_name in enumerate(test_functions_raw):
test_functions_by_name[func_name].append(i)

for name in all_names:
if name.full_name is None:
continue
m = re.search(r"([^.]+)\." + f"{name.name}$", name.full_name)
m = FUNCTION_NAME_REGEX.search(name.full_name)
if not m:
continue

scope = m.group(1)
indices = [i for i, x in enumerate(test_functions_raw) if x == scope]
for index in indices:
scope_test_function = test_functions_list[index].function_name
scope_test_class = test_functions_list[index].test_class
scope_parameters = test_functions_list[index].parameters
test_type = test_functions_list[index].test_type
try:
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
except Exception as e:
logger.debug(str(e))
continue
if definition and definition[0].type == "function":
definition_path = str(definition[0].module_path)
# The definition is part of this project and not defined within the original function
if (
definition_path.startswith(str(project_root_path) + os.sep)
and definition[0].module_name != name.module_name
and definition[0].full_name is not None
):
if scope_parameters is not None:
if test_framework == "pytest":
scope_test_function += "[" + scope_parameters + "]"
if test_framework == "unittest":
scope_test_function += "_" + scope_parameters
full_name_without_module_prefix = definition[0].full_name.replace(
definition[0].module_name + ".", "", 1
)
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
function_to_test_map[qualified_name_with_modules_from_root].append(
FunctionCalledInTest(
tests_in_file=TestsInFile(
test_file=test_file,
test_class=scope_test_class,
test_function=scope_test_function,
test_type=test_type,
),
position=CodePosition(line_no=name.line, col_no=name.column),
)
if scope not in test_functions_by_name:
continue

cache_key = (name.full_name, name.module_name)
try:
if cache_key in goto_cache:
definition = goto_cache[cache_key]
else:
definition = name.goto(
follow_imports=True, follow_builtin_imports=False
)
goto_cache[cache_key] = definition
except Exception as e:
logger.debug(str(e))
continue

if not definition or definition[0].type != "function":
continue

definition_path = str(definition[0].module_path)
if (
definition_path.startswith(str(project_root_path) + os.sep)
and definition[0].module_name != name.module_name
and definition[0].full_name is not None
):
for index in test_functions_by_name[scope]:
scope_test_function = test_functions_list[index].function_name
scope_test_class = test_functions_list[index].test_class
scope_parameters = test_functions_list[index].parameters
test_type = test_functions_list[index].test_type

if scope_parameters is not None:
if test_framework == "pytest":
scope_test_function += "[" + scope_parameters + "]"
if test_framework == "unittest":
scope_test_function += "_" + scope_parameters

full_name_without_module_prefix = definition[0].full_name.replace(
definition[0].module_name + ".", "", 1
)
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"

function_to_test_map[qualified_name_with_modules_from_root].add(
FunctionCalledInTest(
tests_in_file=TestsInFile(
test_file=test_file,
test_class=scope_test_class,
test_function=scope_test_function,
test_type=test_type,
),
position=CodePosition(
line_no=name.line, col_no=name.column
),
)
deduped_function_to_test_map = {}
for function, tests in function_to_test_map.items():
deduped_function_to_test_map[function] = list(set(tests))
return deduped_function_to_test_map
)

return {function: list(tests) for function, tests in function_to_test_map.items()}
7 changes: 3 additions & 4 deletions codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import TYPE_CHECKING

from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.cli_cmds.console import console, logger
from codeflash.cli_cmds.console import console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
from codeflash.code_utils.code_utils import get_run_tmp_file
Expand Down Expand Up @@ -95,9 +95,8 @@ def run(self) -> None:
return

console.rule()
logger.info(f"Discovering existing unit tests in {self.test_cfg.tests_root}…")
console.rule()
function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
with progress_bar(f"Discovering existing unit tests in {self.test_cfg.tests_root}…", transient=True):
function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
console.rule()
logger.info(f"Discovered {num_discovered_tests} existing unit tests in {self.test_cfg.tests_root}")
Expand Down
Loading