Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 28 additions & 31 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,25 +107,31 @@ def insert_test(
)
self.connection.commit()

def get_tests_for_file(self, file_path: str, file_hash: str) -> list[FunctionCalledInTest] | None:
def get_function_to_test_map_for_file(
self, file_path: str, file_hash: str
) -> dict[str, set[FunctionCalledInTest]] | None:
cache_key = (file_path, file_hash)
if cache_key in self._memory_cache:
if cache_key in self.memory_cache:
return self.memory_cache[cache_key]

self.cur.execute("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?", (file_path, file_hash))
rows = self.cur.fetchall()
if not rows:
return None

result = [
FunctionCalledInTest(
function_to_test_map = defaultdict(set)

for row in rows:
qualified_name_with_modules_from_root = row[2]
function_called_in_test = FunctionCalledInTest(
tests_in_file=TestsInFile(
test_file=Path(row[0]), test_class=row[4], test_function=row[5], test_type=TestType(int(row[6]))
),
position=CodePosition(line_no=row[7], col_no=row[8]),
)
for row in rows
]
function_to_test_map[qualified_name_with_modules_from_root].add(function_called_in_test)

result = dict(function_to_test_map)
self.memory_cache[cache_key] = result
return result

Expand Down Expand Up @@ -517,9 +523,12 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:


def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]:
function_parts = function_name.split("_")
if len(function_parts) > 1 and function_parts[-1].isdigit():
return True, "_".join(function_parts[:-1]), function_parts[-1]
if "_" not in function_name:
return False, function_name, None

function_parts = function_name.rsplit("_", 1)
if len(function_parts) == 2 and function_parts[1].isdigit():
return True, function_parts[0], function_parts[1]

return False, function_name, None

Expand Down Expand Up @@ -552,28 +561,16 @@ def process_test_files(
for test_file, functions in file_to_test_map.items():
file_hash = TestsCache.compute_file_hash(test_file)

cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash)

if cached_tests:
# Rebuild function_to_test_map from cached data
tests_cache.cur.execute(
"SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?", (str(test_file), file_hash)
)
for row in tests_cache.cur.fetchall():
qualified_name_with_modules_from_root = row[2]
test_type = TestType(int(row[6]))

function_called_in_test = FunctionCalledInTest(
tests_in_file=TestsInFile(
test_file=test_file, test_class=row[4], test_function=row[5], test_type=test_type
),
position=CodePosition(line_no=row[7], col_no=row[8]),
)

function_to_test_map[qualified_name_with_modules_from_root].add(function_called_in_test)
if test_type == TestType.REPLAY_TEST:
num_discovered_replay_tests += 1
num_discovered_tests += 1
cached_function_to_test_map = tests_cache.get_function_to_test_map_for_file(str(test_file), file_hash)

if cached_function_to_test_map:
for qualified_name, test_set in cached_function_to_test_map.items():
function_to_test_map[qualified_name].update(test_set)

for function_called_in_test in test_set:
if function_called_in_test.tests_in_file.test_type == TestType.REPLAY_TEST:
num_discovered_replay_tests += 1
num_discovered_tests += 1

progress.advance(task_id)
continue
Expand Down
1 change: 0 additions & 1 deletion codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,6 @@ class FunctionCalledInTest:
tests_in_file: TestsInFile
position: CodePosition


@dataclass(frozen=True)
class CodePosition:
line_no: int
Expand Down
Loading