Skip to content

Commit 6f3f7af

Browse files
committed
Update discover_unit_tests.py
1 parent 9a49b0f commit 6f3f7af

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import tempfile
1111
import unittest
1212
from collections import defaultdict
13+
from concurrent.futures import ProcessPoolExecutor, as_completed
1314
from pathlib import Path
1415
from typing import TYPE_CHECKING, Callable, Optional
1516

@@ -438,17 +439,28 @@ def process_test_files(
438439

439440
function_to_test_map = defaultdict(set)
440441

441-
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
442-
progress,
443-
task_id,
442+
with (
443+
ProcessPoolExecutor() as executor,
444+
test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
445+
progress,
446+
task_id,
447+
),
444448
):
449+
futures_map = {}
445450
for test_file, functions in file_to_test_map.items():
446-
_, local_results = _process_single_test_file(test_file, functions, project_root_path, test_framework)
447-
448-
for qualified_name, function_called_in_test in local_results:
449-
function_to_test_map[qualified_name].add(function_called_in_test)
450-
451-
progress.advance(task_id)
451+
future = executor.submit(_process_single_test_file, test_file, functions, project_root_path, test_framework)
452+
futures_map[future] = test_file
453+
454+
for future in as_completed(futures_map):
455+
original_test_file = futures_map[future]
456+
try:
457+
_, local_results = future.result()
458+
for qualified_name, function_called_in_test in local_results:
459+
function_to_test_map[qualified_name].add(function_called_in_test)
460+
except Exception as e:
461+
logger.error(f"Error processing test file {original_test_file}: {e}")
462+
finally:
463+
progress.advance(task_id)
452464

453465
function_to_tests_dict = {function: list(tests) for function, tests in function_to_test_map.items()}
454466
num_discovered_tests = sum(len(tests) for tests in function_to_tests_dict.values())

0 commit comments

Comments
 (0)