diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 7b9681854..0acfccd0e 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -107,9 +107,11 @@ def insert_test( ) self.connection.commit() - def get_tests_for_file(self, file_path: str, file_hash: str) -> list[FunctionCalledInTest] | None: + def get_function_to_test_map_for_file( + self, file_path: str, file_hash: str + ) -> dict[str, set[FunctionCalledInTest]] | None: cache_key = (file_path, file_hash) - if cache_key in self._memory_cache: + if cache_key in self.memory_cache: return self.memory_cache[cache_key] self.cur.execute("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?", (file_path, file_hash)) @@ -117,15 +119,19 @@ def get_tests_for_file(self, file_path: str, file_hash: str) -> list[FunctionCal if not rows: return None - result = [ - FunctionCalledInTest( + function_to_test_map = defaultdict(set) + + for row in rows: + qualified_name_with_modules_from_root = row[2] + function_called_in_test = FunctionCalledInTest( tests_in_file=TestsInFile( test_file=Path(row[0]), test_class=row[4], test_function=row[5], test_type=TestType(int(row[6])) ), position=CodePosition(line_no=row[7], col_no=row[8]), ) - for row in rows - ] + function_to_test_map[qualified_name_with_modules_from_root].add(function_called_in_test) + + result = dict(function_to_test_map) self.memory_cache[cache_key] = result return result @@ -517,9 +523,12 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]: - function_parts = function_name.split("_") - if len(function_parts) > 1 and function_parts[-1].isdigit(): - return True, "_".join(function_parts[:-1]), function_parts[-1] + if "_" not in function_name: + return False, function_name, None + + function_parts = function_name.rsplit("_", 1) + if len(function_parts) == 2 and function_parts[1].isdigit(): + return True, function_parts[0], function_parts[1] return False, function_name, None @@ -552,28 +561,16 @@ def process_test_files( for test_file, functions in file_to_test_map.items(): file_hash = TestsCache.compute_file_hash(test_file) - cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash) - - if cached_tests: - # Rebuild function_to_test_map from cached data - tests_cache.cur.execute( - "SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?", (str(test_file), file_hash) - ) - for row in tests_cache.cur.fetchall(): - qualified_name_with_modules_from_root = row[2] - test_type = TestType(int(row[6])) - - function_called_in_test = FunctionCalledInTest( - tests_in_file=TestsInFile( - test_file=test_file, test_class=row[4], test_function=row[5], test_type=test_type - ), - position=CodePosition(line_no=row[7], col_no=row[8]), - ) - - function_to_test_map[qualified_name_with_modules_from_root].add(function_called_in_test) - if test_type == TestType.REPLAY_TEST: - num_discovered_replay_tests += 1 - num_discovered_tests += 1 + cached_function_to_test_map = tests_cache.get_function_to_test_map_for_file(str(test_file), file_hash) + + if cached_function_to_test_map: + for qualified_name, test_set in cached_function_to_test_map.items(): + function_to_test_map[qualified_name].update(test_set) + + for function_called_in_test in test_set: + if function_called_in_test.tests_in_file.test_type == TestType.REPLAY_TEST: + num_discovered_replay_tests += 1 + num_discovered_tests += 1 progress.advance(task_id) continue diff --git a/codeflash/models/models.py b/codeflash/models/models.py index c1a563672..687dc002f 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -363,7 +363,6 @@ class FunctionCalledInTest: tests_in_file: TestsInFile position: CodePosition - @dataclass(frozen=True) class CodePosition: line_no: int