Skip to content

Commit 86aa1cd

Browse files
author
Archan Das
committed
execute test discovery in threads
1 parent a035af8 commit 86aa1cd

File tree

2 files changed

+42
-13
lines changed

2 files changed

+42
-13
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,19 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
178178

179179

180180
def process_test_files(
181-
file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
181+
file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
182182
) -> dict[str, list[FunctionCalledInTest]]:
183+
from concurrent.futures import ThreadPoolExecutor
184+
import os
185+
183186
project_root_path = cfg.project_root_path
184187
test_framework = cfg.test_framework
185188
function_to_test_map = defaultdict(list)
186189
jedi_project = jedi.Project(path=project_root_path)
187190

188-
for test_file, functions in file_to_test_map.items():
191+
# Define a function to process a single test file
192+
def process_single_file(test_file, functions):
193+
local_results = defaultdict(list)
189194
try:
190195
script = jedi.Script(path=test_file, project=jedi_project)
191196
test_functions = set()
@@ -198,7 +203,7 @@ def process_test_files(
198203
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
199204
except Exception as e:
200205
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
201-
continue
206+
return local_results
202207

203208
if test_framework == "pytest":
204209
for function in functions:
@@ -229,15 +234,15 @@ def process_test_files(
229234

230235
elif test_framework == "unittest":
231236
functions_to_search = [elem.test_function for elem in functions]
232-
test_suites = [elem.test_class for elem in functions]
237+
test_suites = {elem.test_class for elem in functions}
233238

234239
matching_names = test_suites & top_level_classes.keys()
235240
for matched_name in matching_names:
236241
for def_name in all_defs:
237242
if (
238-
def_name.type == "function"
239-
and def_name.full_name is not None
240-
and f".{matched_name}." in def_name.full_name
243+
def_name.type == "function"
244+
and def_name.full_name is not None
245+
and f".{matched_name}." in def_name.full_name
241246
):
242247
for function in functions_to_search:
243248
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
@@ -286,9 +291,9 @@ def process_test_files(
286291
definition_path = str(definition[0].module_path)
287292
# The definition is part of this project and not defined within the original function
288293
if (
289-
definition_path.startswith(str(project_root_path) + os.sep)
290-
and definition[0].module_name != name.module_name
291-
and definition[0].full_name is not None
294+
definition_path.startswith(str(project_root_path) + os.sep)
295+
and definition[0].module_name != name.module_name
296+
and definition[0].full_name is not None
292297
):
293298
if scope_parameters is not None:
294299
if test_framework == "pytest":
@@ -299,7 +304,7 @@ def process_test_files(
299304
definition[0].module_name + ".", "", 1
300305
)
301306
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
302-
function_to_test_map[qualified_name_with_modules_from_root].append(
307+
local_results[qualified_name_with_modules_from_root].append(
303308
FunctionCalledInTest(
304309
tests_in_file=TestsInFile(
305310
test_file=test_file,
@@ -310,7 +315,31 @@ def process_test_files(
310315
position=CodePosition(line_no=name.line, col_no=name.column),
311316
)
312317
)
318+
return local_results
319+
320+
# Determine number of workers (threads) - use fewer than processes since these are I/O bound
321+
max_workers = min(os.cpu_count() * 2 or 8, len(file_to_test_map), 16)
322+
323+
# Process files in parallel using threads (shared memory)
324+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
325+
futures = {
326+
executor.submit(process_single_file, test_file, functions): test_file
327+
for test_file, functions in file_to_test_map.items()
328+
}
329+
330+
# Collect results
331+
for future in futures:
332+
try:
333+
file_results = future.result()
334+
# Merge results
335+
for function, tests in file_results.items():
336+
function_to_test_map[function].extend(tests)
337+
except Exception as e:
338+
logger.warning(f"Error processing file {futures[future]}: {e}")
339+
340+
# Deduplicate results
313341
deduped_function_to_test_map = {}
314342
for function, tests in function_to_test_map.items():
315343
deduped_function_to_test_map[function] = list(set(tests))
344+
316345
return deduped_function_to_test_map

codeflash/version.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`.
2-
__version__ = "0.10.3"
3-
__version_tuple__ = (0, 10, 3)
2+
__version__ = "0.10.3.post6.dev0+a035af84"
3+
__version_tuple__ = (0, 10, 3, "post6", "dev0", "a035af84")

0 commit comments

Comments
 (0)