diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 1fd86acce..98223669b 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable, Optional +import jedi import pytest from pydantic.dataclasses import dataclass from rich.panel import Panel @@ -288,192 +289,160 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N return False, function_name, None -def process_test_files( - file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig -) -> dict[str, list[FunctionCalledInTest]]: - import jedi - +def _process_single_test_file( + test_file: Path, + functions: list[TestsInFile], + cfg: TestConfig, + jedi_project: jedi.Project, + function_to_test_map: defaultdict, +) -> None: project_root_path = cfg.project_root_path test_framework = cfg.test_framework - function_to_test_map = defaultdict(set) - jedi_project = jedi.Project(path=project_root_path) - goto_cache = {} - tests_cache = TestsCache() + try: + script = jedi.Script(path=test_file, project=jedi_project) + test_functions = set() - with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( - progress, - task_id, - ): - for test_file, functions in file_to_test_map.items(): - file_hash = TestsCache.compute_file_hash(test_file) - cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash) - if cached_tests: - self_cur = tests_cache.cur - self_cur.execute( - "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?", - (str(test_file), file_hash), - ) - qualified_names = [row[0] for row in self_cur.fetchall()] - for cached, qualified_name in zip(cached_tests, qualified_names): - function_to_test_map[qualified_name].add(cached) - progress.advance(task_id) - continue + all_names = script.get_names(all_scopes=True, references=True) + all_defs = script.get_names(all_scopes=True, definitions=True) + all_names_top = script.get_names(all_scopes=True) - try: - script = jedi.Script(path=test_file, project=jedi_project) - test_functions = set() + top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} + top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} + except Exception as e: + logger.debug(f"Failed to get jedi script for {test_file}: {e}") + return + + if test_framework == "pytest": + for function in functions: + if "[" in function.test_function: + function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0] + parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1] + if function_name in top_level_functions: + test_functions.add(TestFunction(function_name, function.test_class, parameters, function.test_type)) + elif function.test_function in top_level_functions: + test_functions.add(TestFunction(function.test_function, function.test_class, None, function.test_type)) + elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function): + base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function) + if base_name in top_level_functions: + test_functions.add( + TestFunction( + function_name=base_name, + test_class=function.test_class, + parameters=function.test_function, + test_type=function.test_type, + ) + ) - all_names = script.get_names(all_scopes=True, references=True) - all_defs = script.get_names(all_scopes=True, definitions=True) - all_names_top = script.get_names(all_scopes=True) + elif test_framework == "unittest": + functions_to_search = [elem.test_function for elem in functions] + test_suites = {elem.test_class for elem in functions} - top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} - top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} - except Exception as e: - logger.debug(f"Failed to get jedi script for {test_file}: {e}") - progress.advance(task_id) - continue + matching_names = test_suites & top_level_classes.keys() + for matched_name in matching_names: + for def_name in all_defs: + if ( + def_name.type == "function" + and def_name.full_name is not None + and f".{matched_name}." in def_name.full_name + ): + for function in functions_to_search: + (is_parameterized, new_function, parameters) = discover_parameters_unittest(function) - if test_framework == "pytest": - for function in functions: - if "[" in function.test_function: - function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0] - parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1] - if function_name in top_level_functions: + if is_parameterized and new_function == def_name.name: test_functions.add( - TestFunction(function_name, function.test_class, parameters, function.test_type) + TestFunction( + function_name=def_name.name, + test_class=matched_name, + parameters=parameters, + test_type=functions[0].test_type, + ) ) - elif function.test_function in top_level_functions: - test_functions.add( - TestFunction(function.test_function, function.test_class, None, function.test_type) - ) - elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function): - base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function) - if base_name in top_level_functions: + elif function == def_name.name: test_functions.add( TestFunction( - function_name=base_name, - test_class=function.test_class, - parameters=function.test_function, - test_type=function.test_type, + function_name=def_name.name, + test_class=matched_name, + parameters=None, + test_type=functions[0].test_type, ) ) - elif test_framework == "unittest": - functions_to_search = [elem.test_function for elem in functions] - test_suites = {elem.test_class for elem in functions} - - matching_names = test_suites & top_level_classes.keys() - for matched_name in matching_names: - for def_name in all_defs: - if ( - def_name.type == "function" - and def_name.full_name is not None - and f".{matched_name}." in def_name.full_name - ): - for function in functions_to_search: - (is_parameterized, new_function, parameters) = discover_parameters_unittest(function) - - if is_parameterized and new_function == def_name.name: - test_functions.add( - TestFunction( - function_name=def_name.name, - test_class=matched_name, - parameters=parameters, - test_type=functions[0].test_type, - ) - ) - elif function == def_name.name: - test_functions.add( - TestFunction( - function_name=def_name.name, - test_class=matched_name, - parameters=None, - test_type=functions[0].test_type, - ) - ) - - test_functions_list = list(test_functions) - test_functions_raw = [elem.function_name for elem in test_functions_list] - - test_functions_by_name = defaultdict(list) - for i, func_name in enumerate(test_functions_raw): - test_functions_by_name[func_name].append(i) - - for name in all_names: - if name.full_name is None: - continue - m = FUNCTION_NAME_REGEX.search(name.full_name) - if not m: - continue - - scope = m.group(1) - if scope not in test_functions_by_name: - continue - - cache_key = (name.full_name, name.module_name) - try: - if cache_key in goto_cache: - definition = goto_cache[cache_key] - else: - definition = name.goto(follow_imports=True, follow_builtin_imports=False) - goto_cache[cache_key] = definition - except Exception as e: - logger.debug(str(e)) - continue - - if not definition or definition[0].type != "function": - continue - - definition_path = str(definition[0].module_path) - if ( - definition_path.startswith(str(project_root_path) + os.sep) - and definition[0].module_name != name.module_name - and definition[0].full_name is not None - ): - for index in test_functions_by_name[scope]: - scope_test_function = test_functions_list[index].function_name - scope_test_class = test_functions_list[index].test_class - scope_parameters = test_functions_list[index].parameters - test_type = test_functions_list[index].test_type - - if scope_parameters is not None: - if test_framework == "pytest": - scope_test_function += "[" + scope_parameters + "]" - if test_framework == "unittest": - scope_test_function += "_" + scope_parameters - - full_name_without_module_prefix = definition[0].full_name.replace( - definition[0].module_name + ".", "", 1 - ) - qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" + test_functions_list = list(test_functions) + test_functions_raw = [elem.function_name for elem in test_functions_list] + + test_functions_by_name = defaultdict(list) + for i, func_name in enumerate(test_functions_raw): + test_functions_by_name[func_name].append(i) - tests_cache.insert_test( - file_path=str(test_file), - file_hash=file_hash, - qualified_name_with_modules_from_root=qualified_name_with_modules_from_root, - function_name=scope, + for name in all_names: + if name.full_name is None: + continue + m = FUNCTION_NAME_REGEX.search(name.full_name) + if not m: + continue + + scope = m.group(1) + if scope not in test_functions_by_name: + continue + + try: + definition = name.goto(follow_imports=True, follow_builtin_imports=False) + except Exception as e: + logger.debug(str(e)) + continue + + if not definition or definition[0].type != "function": + continue + + definition_path = str(definition[0].module_path) + if ( + definition_path.startswith(str(project_root_path) + os.sep) + and definition[0].module_name != name.module_name + and definition[0].full_name is not None + ): + for index in test_functions_by_name[scope]: + scope_test_function = test_functions_list[index].function_name + scope_test_class = test_functions_list[index].test_class + scope_parameters = test_functions_list[index].parameters + test_type = test_functions_list[index].test_type + + if scope_parameters is not None: + if test_framework == "pytest": + scope_test_function += "[" + scope_parameters + "]" + if test_framework == "unittest": + scope_test_function += "_" + scope_parameters + + full_name_without_module_prefix = definition[0].full_name.replace( + definition[0].module_name + ".", "", 1 + ) + qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" + + function_to_test_map[qualified_name_with_modules_from_root].add( + FunctionCalledInTest( + tests_in_file=TestsInFile( + test_file=test_file, test_class=scope_test_class, test_function=scope_test_function, test_type=test_type, - line_number=name.line, - col_number=name.column, - ) + ), + position=CodePosition(line_no=name.line, col_no=name.column), + ) + ) - function_to_test_map[qualified_name_with_modules_from_root].add( - FunctionCalledInTest( - tests_in_file=TestsInFile( - test_file=test_file, - test_class=scope_test_class, - test_function=scope_test_function, - test_type=test_type, - ), - position=CodePosition(line_no=name.line, col_no=name.column), - ) - ) +def process_test_files( + file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig +) -> dict[str, list[FunctionCalledInTest]]: + function_to_test_map = defaultdict(set) + jedi_project = jedi.Project(path=cfg.project_root_path) + + with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( + progress, + task_id, + ): + for test_file, functions in file_to_test_map.items(): + _process_single_test_file(test_file, functions, cfg, jedi_project, function_to_test_map) progress.advance(task_id) - tests_cache.close() return {function: list(tests) for function, tests in function_to_test_map.items()}