diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index 7cd09c843..b4bfda3ff 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -7,7 +7,16 @@ from rich.console import Console from rich.logging import RichHandler -from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn +from rich.progress import ( + BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) from codeflash.cli_cmds.console_constants import SPINNER_TYPES from codeflash.cli_cmds.logging_config import BARE_LOGGING_FORMAT @@ -22,7 +31,15 @@ console = Console() logging.basicConfig( level=logging.INFO, - handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)], + handlers=[ + RichHandler( + rich_tracebacks=True, + markup=False, + console=console, + show_path=False, + show_time=False, + ) + ], format=BARE_LOGGING_FORMAT, ) @@ -31,7 +48,9 @@ def paneled_text( - text: str, panel_args: dict[str, str | bool] | None = None, text_args: dict[str, str] | None = None + text: str, + panel_args: dict[str, str | bool] | None = None, + text_args: dict[str, str] | None = None, ) -> None: """Print text in a panel.""" from rich.panel import Panel @@ -58,7 +77,9 @@ def code_print(code_str: str) -> None: @contextmanager -def progress_bar(message: str, *, transient: bool = False) -> Generator[TaskID, None, None]: +def progress_bar( + message: str, *, transient: bool = False +) -> Generator[TaskID, None, None]: """Display a progress bar with a spinner and elapsed time.""" progress = Progress( SpinnerColumn(next(spinners)), @@ -70,3 +91,25 @@ def progress_bar(message: str, *, transient: bool = False) -> Generator[TaskID, task = progress.add_task(message, total=None) with progress: yield task + + +@contextmanager +def test_files_progress_bar( + total: int, description: str +) -> Generator[tuple[Progress, TaskID], None, None]: + """Progress bar for test files.""" + with Progress( + SpinnerColumn(next(spinners)), + TextColumn("[progress.description]{task.description}"), + BarColumn( + complete_style="cyan", + finished_style="green", + pulse_style="yellow", + ), + MofNCompleteColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + transient=True, + ) as progress: + task_id = progress.add_task(description, total=total) + yield progress, task_id diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 3b05c8d49..e26680e1a 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass from pytest import ExitCode -from codeflash.cli_cmds.console import console, logger +from codeflash.cli_cmds.console import console, logger, test_files_progress_bar from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType @@ -30,14 +30,25 @@ class TestFunction: test_type: TestType +ERROR_PATTERN = re.compile(r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)") +PYTEST_PARAMETERIZED_TEST_NAME_REGEX = re.compile(r"[\[\]]") +UNITTEST_PARAMETERIZED_TEST_NAME_REGEX = re.compile(r"^test_\w+_\d+(?:_\w+)*") +UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX = re.compile(r"_\d+(?:_\w+)*$") +FUNCTION_NAME_REGEX = re.compile(r"([^.]+)\.([a-zA-Z0-9_]+)$") + + def discover_unit_tests( cfg: TestConfig, discover_only_these_tests: list[Path] | None = None ) -> dict[str, list[FunctionCalledInTest]]: - framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest} + framework_strategies: dict[str, Callable] = { + "pytest": discover_tests_pytest, + "unittest": discover_tests_unittest, + } strategy = framework_strategies.get(cfg.test_framework, None) if not strategy: error_message = f"Unsupported test framework: {cfg.test_framework}" raise ValueError(error_message) + return strategy(cfg, discover_only_these_tests) @@ -72,8 +83,7 @@ def discover_tests_pytest( if exitcode != 0: if exitcode == 2 and "ERROR collecting" in result.stdout: # Pattern matches "===== ERRORS =====" (any number of =) and captures everything after - error_pattern = r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)" - match = re.search(error_pattern, result.stdout) + match = ERROR_PATTERN.search(result.stdout) error_section = match.group(1) if match else result.stdout logger.warning( @@ -81,7 +91,9 @@ def discover_tests_pytest( ) elif 0 <= exitcode <= 5: - logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}") + logger.warning( + f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}" + ) else: logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}") console.rule() @@ -104,7 +116,10 @@ def discover_tests_pytest( test_function=test["test_function"], test_type=test_type, ) - if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests: + if ( + discover_only_these_tests + and test_obj.test_file not in discover_only_these_tests + ): continue file_to_test_map[test_obj.test_file].append(test_obj) # Within these test files, find the project functions they are referring to and return their names/locations @@ -129,7 +144,8 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: _test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py") _test_module_path = tests_root / _test_module_path if not _test_module_path.exists() or ( - discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests + discover_only_these_tests + and str(_test_module_path) not in discover_only_these_tests ): return None if "__replay_test" in str(_test_module_path): @@ -156,7 +172,9 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: if not hasattr(test, "_testMethodName") and hasattr(test, "_tests"): for test_2 in test._tests: if not hasattr(test_2, "_testMethodName"): - logger.warning(f"Didn't find tests for {test_2}") # it goes deeper? + logger.warning( + f"Didn't find tests for {test_2}" + ) # it goes deeper? continue details = get_test_details(test_2) if details is not None: @@ -181,124 +199,171 @@ def process_test_files( ) -> dict[str, list[FunctionCalledInTest]]: project_root_path = cfg.project_root_path test_framework = cfg.test_framework - function_to_test_map = defaultdict(list) + function_to_test_map = defaultdict(set) jedi_project = jedi.Project(path=project_root_path) + goto_cache = {} + + with test_files_progress_bar( + total=len(file_to_test_map), description="Processing test files" + ) as (progress, task_id): + + for test_file, functions in file_to_test_map.items(): + try: + script = jedi.Script(path=test_file, project=jedi_project) + test_functions = set() + + all_names = script.get_names(all_scopes=True, references=True) + all_defs = script.get_names(all_scopes=True, definitions=True) + all_names_top = script.get_names(all_scopes=True) + + top_level_functions = { + name.name: name for name in all_names_top if name.type == "function" + } + top_level_classes = { + name.name: name for name in all_names_top if name.type == "class" + } + except Exception as e: + logger.debug(f"Failed to get jedi script for {test_file}: {e}") + progress.advance(task_id) + continue - for test_file, functions in file_to_test_map.items(): - try: - script = jedi.Script(path=test_file, project=jedi_project) - test_functions = set() - - all_names = script.get_names(all_scopes=True, references=True) - all_defs = script.get_names(all_scopes=True, definitions=True) - all_names_top = script.get_names(all_scopes=True) - - top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} - top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} - except Exception as e: - logger.debug(f"Failed to get jedi script for {test_file}: {e}") - continue - - if test_framework == "pytest": - for function in functions: - if "[" in function.test_function: - function_name = re.split(r"[\[\]]", function.test_function)[0] - parameters = re.split(r"[\[\]]", function.test_function)[1] - if function_name in top_level_functions: - test_functions.add( - TestFunction(function_name, function.test_class, parameters, function.test_type) - ) - elif function.test_function in top_level_functions: - test_functions.add( - TestFunction(function.test_function, function.test_class, None, function.test_type) - ) - elif re.match(r"^test_\w+_\d+(?:_\w+)*", function.test_function): - # Try to match parameterized unittest functions here, although we can't get the parameters. - # Extract base name by removing the numbered suffix and any additional descriptions - base_name = re.sub(r"_\d+(?:_\w+)*$", "", function.test_function) - if base_name in top_level_functions: + if test_framework == "pytest": + for function in functions: + if "[" in function.test_function: + function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split( + function.test_function + )[0] + parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split( + function.test_function + )[1] + if function_name in top_level_functions: + test_functions.add( + TestFunction( + function_name, + function.test_class, + parameters, + function.test_type, + ) + ) + elif function.test_function in top_level_functions: test_functions.add( TestFunction( - function_name=base_name, - test_class=function.test_class, - parameters=function.test_function, - test_type=function.test_type, + function.test_function, + function.test_class, + None, + function.test_type, ) ) - - elif test_framework == "unittest": - functions_to_search = [elem.test_function for elem in functions] - test_suites = [elem.test_class for elem in functions] - - matching_names = test_suites & top_level_classes.keys() - for matched_name in matching_names: - for def_name in all_defs: - if ( - def_name.type == "function" - and def_name.full_name is not None - and f".{matched_name}." in def_name.full_name + elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match( + function.test_function ): - for function in functions_to_search: - (is_parameterized, new_function, parameters) = discover_parameters_unittest(function) - - if is_parameterized and new_function == def_name.name: - test_functions.add( - TestFunction( - function_name=def_name.name, - test_class=matched_name, - parameters=parameters, - test_type=functions[0].test_type, - ) # A test file must not have more than one test type + base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub( + "", function.test_function + ) + if base_name in top_level_functions: + test_functions.add( + TestFunction( + function_name=base_name, + test_class=function.test_class, + parameters=function.test_function, + test_type=function.test_type, ) - elif function == def_name.name: - test_functions.add( - TestFunction( - function_name=def_name.name, - test_class=matched_name, - parameters=None, - test_type=functions[0].test_type, - ) + ) + + elif test_framework == "unittest": + functions_to_search = [elem.test_function for elem in functions] + test_suites = {elem.test_class for elem in functions} + + matching_names = test_suites & top_level_classes.keys() + for matched_name in matching_names: + for def_name in all_defs: + if ( + def_name.type == "function" + and def_name.full_name is not None + and f".{matched_name}." in def_name.full_name + ): + for function in functions_to_search: + (is_parameterized, new_function, parameters) = ( + discover_parameters_unittest(function) ) - test_functions_list = list(test_functions) - test_functions_raw = [elem.function_name for elem in test_functions_list] + if is_parameterized and new_function == def_name.name: + test_functions.add( + TestFunction( + function_name=def_name.name, + test_class=matched_name, + parameters=parameters, + test_type=functions[0].test_type, + ) + ) + elif function == def_name.name: + test_functions.add( + TestFunction( + function_name=def_name.name, + test_class=matched_name, + parameters=None, + test_type=functions[0].test_type, + ) + ) - for name in all_names: - if name.full_name is None: - continue - m = re.search(r"([^.]+)\." + f"{name.name}$", name.full_name) - if not m: - continue - scope = m.group(1) - indices = [i for i, x in enumerate(test_functions_raw) if x == scope] - for index in indices: - scope_test_function = test_functions_list[index].function_name - scope_test_class = test_functions_list[index].test_class - scope_parameters = test_functions_list[index].parameters - test_type = test_functions_list[index].test_type + test_functions_list = list(test_functions) + test_functions_raw = [elem.function_name for elem in test_functions_list] + + test_functions_by_name = defaultdict(list) + for i, func_name in enumerate(test_functions_raw): + test_functions_by_name[func_name].append(i) + + for name in all_names: + if name.full_name is None: + continue + m = FUNCTION_NAME_REGEX.search(name.full_name) + if not m: + continue + + scope = m.group(1) + if scope not in test_functions_by_name: + continue + + cache_key = (name.full_name, name.module_name) try: - definition = name.goto(follow_imports=True, follow_builtin_imports=False) + if cache_key in goto_cache: + definition = goto_cache[cache_key] + else: + definition = name.goto( + follow_imports=True, follow_builtin_imports=False + ) + goto_cache[cache_key] = definition except Exception as e: logger.debug(str(e)) continue - if definition and definition[0].type == "function": - definition_path = str(definition[0].module_path) - # The definition is part of this project and not defined within the original function - if ( - definition_path.startswith(str(project_root_path) + os.sep) - and definition[0].module_name != name.module_name - and definition[0].full_name is not None - ): + + if not definition or definition[0].type != "function": + continue + + definition_path = str(definition[0].module_path) + if ( + definition_path.startswith(str(project_root_path) + os.sep) + and definition[0].module_name != name.module_name + and definition[0].full_name is not None + ): + for index in test_functions_by_name[scope]: + scope_test_function = test_functions_list[index].function_name + scope_test_class = test_functions_list[index].test_class + scope_parameters = test_functions_list[index].parameters + test_type = test_functions_list[index].test_type + if scope_parameters is not None: if test_framework == "pytest": scope_test_function += "[" + scope_parameters + "]" if test_framework == "unittest": scope_test_function += "_" + scope_parameters - full_name_without_module_prefix = definition[0].full_name.replace( - definition[0].module_name + ".", "", 1 - ) + + full_name_without_module_prefix = definition[ + 0 + ].full_name.replace(definition[0].module_name + ".", "", 1) qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" - function_to_test_map[qualified_name_with_modules_from_root].append( + + function_to_test_map[qualified_name_with_modules_from_root].add( FunctionCalledInTest( tests_in_file=TestsInFile( test_file=test_file, @@ -306,10 +371,12 @@ def process_test_files( test_function=scope_test_function, test_type=test_type, ), - position=CodePosition(line_no=name.line, col_no=name.column), + position=CodePosition( + line_no=name.line, col_no=name.column + ), ) ) - deduped_function_to_test_map = {} - for function, tests in function_to_test_map.items(): - deduped_function_to_test_map[function] = list(set(tests)) - return deduped_function_to_test_map + + progress.advance(task_id) + + return {function: list(tests) for function, tests in function_to_test_map.items()} diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 9adf3723f..8eae1014a 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient -from codeflash.cli_cmds.console import console, logger +from codeflash.cli_cmds.console import console, logger, progress_bar from codeflash.code_utils import env_utils from codeflash.code_utils.code_replacer import normalize_code, normalize_node from codeflash.code_utils.code_utils import get_run_tmp_file @@ -93,8 +93,6 @@ def run(self) -> None: logger.info("No functions found to optimize. Exiting…") return - console.rule() - logger.info(f"Discovering existing unit tests in {self.test_cfg.tests_root}…") console.rule() function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg) num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])