Skip to content

Commit 6867932

Browse files
committed
paralelize test discovery
1 parent 3eed53c commit 6867932

File tree

1 file changed

+93
-47
lines changed

1 file changed

+93
-47
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 93 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import hashlib
5+
import multiprocessing
56
import os
67
import pickle
78
import re
@@ -15,16 +16,9 @@
1516

1617
import pytest
1718
from pydantic.dataclasses import dataclass
18-
from rich.panel import Panel
19-
from rich.text import Text
2019

2120
from codeflash.cli_cmds.console import console, logger, test_files_progress_bar
22-
from codeflash.code_utils.code_utils import (
23-
ImportErrorPattern,
24-
custom_addopts,
25-
get_run_tmp_file,
26-
module_name_from_file_path,
27-
)
21+
from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file, module_name_from_file_path
2822
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
2923
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
3024

@@ -139,7 +133,7 @@ def close(self) -> None:
139133

140134
def discover_unit_tests(
141135
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
142-
) -> dict[str, list[FunctionCalledInTest]]:
136+
) -> tuple[dict[str, list[FunctionCalledInTest]], int]:
143137
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
144138
strategy = framework_strategies.get(cfg.test_framework, None)
145139
if not strategy:
@@ -151,7 +145,7 @@ def discover_unit_tests(
151145

152146
def discover_tests_pytest(
153147
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
154-
) -> dict[Path, list[FunctionCalledInTest]]:
148+
) -> tuple[dict[str, list[FunctionCalledInTest]], int]:
155149
tests_root = cfg.tests_root
156150
project_root = cfg.project_root_path
157151

@@ -187,10 +181,6 @@ def discover_tests_pytest(
187181
logger.warning(
188182
f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}\n {error_section}"
189183
)
190-
if "ModuleNotFoundError" in result.stdout:
191-
match = ImportErrorPattern.search(result.stdout).group()
192-
panel = Panel(Text.from_markup(f"⚠️ {match} ", style="bold red"), expand=False)
193-
console.print(panel)
194184

195185
elif 0 <= exitcode <= 5:
196186
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}")
@@ -225,7 +215,7 @@ def discover_tests_pytest(
225215

226216
def discover_tests_unittest(
227217
cfg: TestConfig, discover_only_these_tests: list[str] | None = None
228-
) -> dict[Path, list[FunctionCalledInTest]]:
218+
) -> tuple[dict[str, list[FunctionCalledInTest]], int]:
229219
tests_root: Path = cfg.tests_root
230220
loader: unittest.TestLoader = unittest.TestLoader()
231221
tests: unittest.TestSuite = loader.discover(str(tests_root))
@@ -290,27 +280,39 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
290280

291281
def _process_single_test_file(
292282
test_file: Path, functions: list[TestsInFile], project_root_path: Path, test_framework: str
293-
) -> tuple[str, list[tuple[str, FunctionCalledInTest]]]:
283+
) -> tuple[str, list[tuple[str, FunctionCalledInTest]], int, list[dict]]:
294284
import jedi
295285

296286
jedi_project = jedi.Project(path=project_root_path)
297287
goto_cache = {}
298288
results = []
289+
cache_entries = []
299290

300291
try:
301292
script = jedi.Script(path=test_file, project=jedi_project)
302293
test_functions = set()
303294

304295
all_names = script.get_names(all_scopes=True, references=True)
305-
all_defs = script.get_names(all_scopes=True, definitions=True)
306-
all_names_top = script.get_names(all_scopes=True)
307-
308-
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
309-
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
296+
top_level_functions = {}
297+
top_level_classes = {}
298+
all_defs = []
299+
reference_names = []
300+
301+
for name in all_names:
302+
if name.type == "function":
303+
top_level_functions[name.name] = name
304+
if hasattr(name, "full_name") and name.full_name:
305+
all_defs.append(name)
306+
elif name.type == "class":
307+
top_level_classes[name.name] = name
308+
309+
if name.full_name is not None:
310+
m = FUNCTION_NAME_REGEX.search(name.full_name)
311+
if m:
312+
reference_names.append((name, m.group(1)))
310313
except Exception as e:
311314
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
312-
# tests_cache.close()
313-
return str(test_file), results
315+
return str(test_file), results, len(results), cache_entries
314316

315317
if test_framework == "pytest":
316318
for function in functions:
@@ -340,11 +342,8 @@ def _process_single_test_file(
340342
matching_names = test_suites & top_level_classes.keys()
341343
for matched_name in matching_names:
342344
for def_name in all_defs:
343-
if (
344-
def_name.type == "function"
345-
and def_name.full_name is not None
346-
and f".{matched_name}." in def_name.full_name
347-
):
345+
# all_defs already contains only functions, no need to check type
346+
if def_name.full_name is not None and f".{matched_name}." in def_name.full_name:
348347
for function in functions_to_search:
349348
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
350349

@@ -374,14 +373,7 @@ def _process_single_test_file(
374373
for i, func_name in enumerate(test_functions_raw):
375374
test_functions_by_name[func_name].append(i)
376375

377-
for name in all_names:
378-
if name.full_name is None:
379-
continue
380-
m = FUNCTION_NAME_REGEX.search(name.full_name)
381-
if not m:
382-
continue
383-
384-
scope = m.group(1)
376+
for name, scope in reference_names:
385377
if scope not in test_functions_by_name:
386378
continue
387379

@@ -432,28 +424,73 @@ def _process_single_test_file(
432424
)
433425
results.append((qualified_name_with_modules_from_root, function_called_in_test))
434426

435-
return str(test_file), results
427+
cache_entries.append(
428+
{
429+
"qualified_name_with_modules_from_root": qualified_name_with_modules_from_root,
430+
"function_name": scope,
431+
"test_class": scope_test_class,
432+
"test_function": scope_test_function,
433+
"test_type": test_type,
434+
"line_number": name.line,
435+
"col_number": name.column,
436+
}
437+
)
438+
439+
return str(test_file), results, len(results), cache_entries
436440

437441

438442
def process_test_files(
439443
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
440-
) -> dict[str, list[FunctionCalledInTest]]:
444+
) -> tuple[dict[str, list[FunctionCalledInTest]], int]:
441445
project_root_path = cfg.project_root_path
442446
test_framework = cfg.test_framework
443447
function_to_test_map = defaultdict(set)
448+
total_count = 0
444449

445-
import multiprocessing
450+
tests_cache = TestsCache()
446451

447-
max_workers = min(len(file_to_test_map), multiprocessing.cpu_count())
448-
max_workers = max(1, max_workers)
452+
max_workers = min(len(file_to_test_map) or 1, multiprocessing.cpu_count())
449453

450454
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
451455
progress,
452456
task_id,
453457
):
454-
if len(file_to_test_map) == 1 or max_workers == 1:
455-
for test_file, functions in file_to_test_map.items():
456-
_, results = _process_single_test_file(test_file, functions, project_root_path, test_framework)
458+
cached_files = {}
459+
uncached_files = {}
460+
461+
for test_file, functions in file_to_test_map.items():
462+
file_hash = TestsCache.compute_file_hash(str(test_file))
463+
cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash)
464+
465+
if cached_tests:
466+
cached_files[test_file] = (functions, cached_tests, file_hash)
467+
else:
468+
uncached_files[test_file] = functions
469+
470+
# Process cached files first
471+
for test_file, (_functions, cached_tests, file_hash) in cached_files.items():
472+
cur = tests_cache.cur
473+
cur.execute(
474+
"SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?",
475+
(str(test_file), file_hash),
476+
)
477+
qualified_names = [row[0] for row in cur.fetchall()]
478+
for cached_test, qualified_name in zip(cached_tests, qualified_names):
479+
function_to_test_map[qualified_name].add(cached_test)
480+
total_count += len(cached_tests)
481+
progress.advance(task_id)
482+
483+
if len(uncached_files) == 1 or max_workers == 1:
484+
for test_file, functions in uncached_files.items():
485+
_, results, count, cache_entries = _process_single_test_file(
486+
test_file, functions, project_root_path, test_framework
487+
)
488+
total_count += count
489+
490+
file_hash = TestsCache.compute_file_hash(str(test_file))
491+
for cache_entry in cache_entries:
492+
tests_cache.insert_test(file_path=str(test_file), file_hash=file_hash, **cache_entry)
493+
457494
for qualified_name, function_called in results:
458495
function_to_test_map[qualified_name].add(function_called)
459496
progress.advance(task_id)
@@ -463,12 +500,19 @@ def process_test_files(
463500
executor.submit(
464501
_process_single_test_file, test_file, functions, project_root_path, test_framework
465502
): test_file
466-
for test_file, functions in file_to_test_map.items()
503+
for test_file, functions in uncached_files.items()
467504
}
468505

469506
for future in as_completed(future_to_file):
470507
try:
471-
_, results = future.result()
508+
_, results, count, cache_entries = future.result()
509+
total_count += count
510+
511+
test_file = future_to_file[future]
512+
file_hash = TestsCache.compute_file_hash(str(test_file))
513+
for cache_entry in cache_entries:
514+
tests_cache.insert_test(file_path=str(test_file), file_hash=file_hash, **cache_entry)
515+
472516
for qualified_name, function_called in results:
473517
function_to_test_map[qualified_name].add(function_called)
474518
progress.advance(task_id)
@@ -477,4 +521,6 @@ def process_test_files(
477521
logger.error(f"Error processing test file {test_file}: {e}")
478522
progress.advance(task_id)
479523

480-
return {function: list(tests) for function, tests in function_to_test_map.items()}
524+
tests_cache.close()
525+
function_to_tests_dict = {function: list(tests) for function, tests in function_to_test_map.items()}
526+
return function_to_tests_dict, total_count

0 commit comments

Comments
 (0)