Skip to content

Commit e61d445

Browse files
committed
debug
1 parent 576f58c commit e61d445

File tree

1 file changed

+48
-20
lines changed

1 file changed

+48
-20
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# ruff: noqa: SLF001
22
from __future__ import annotations
33

4+
import concurrent.futures
45
import hashlib
6+
import multiprocessing
57
import os
68
import pickle
79
import re
@@ -23,6 +25,8 @@
2325
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
2426

2527
if TYPE_CHECKING:
28+
from multiprocessing.synchronize import Lock
29+
2630
from codeflash.verification.verification_utils import TestConfig
2731

2832

@@ -284,23 +288,33 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
284288

285289

286290
def _process_single_test_file(
287-
test_file: Path, functions: list[TestsInFile], project_root_path: Path, test_framework: str
291+
test_file: Path,
292+
functions: list[TestsInFile],
293+
project_root_path: Path,
294+
test_framework: str,
295+
jedi_lock: Optional[Lock] = None,
288296
) -> tuple[Path, set]:
289297
import jedi
290298

291299
local_function_to_test_map = set()
292300

293301
try:
294-
jedi_project = jedi.Project(path=project_root_path)
295-
script = jedi.Script(path=test_file, project=jedi_project)
296-
test_functions = set()
297-
298-
all_names = script.get_names(all_scopes=True, references=True)
299-
all_defs = script.get_names(all_scopes=True, definitions=True)
300-
all_names_top = script.get_names(all_scopes=True)
301-
302-
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
303-
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
302+
if jedi_lock is not None:
303+
jedi_lock.acquire()
304+
try:
305+
jedi_project = jedi.Project(path=project_root_path)
306+
script = jedi.Script(path=test_file, project=jedi_project)
307+
test_functions = set()
308+
309+
all_names = script.get_names(all_scopes=True, references=True)
310+
all_defs = script.get_names(all_scopes=True, definitions=True)
311+
all_names_top = script.get_names(all_scopes=True)
312+
313+
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
314+
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
315+
finally:
316+
if jedi_lock is not None:
317+
jedi_lock.release()
304318
except Exception as e:
305319
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
306320
return test_file, local_function_to_test_map
@@ -379,7 +393,13 @@ def _process_single_test_file(
379393
continue
380394

381395
try:
382-
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
396+
if jedi_lock is not None:
397+
jedi_lock.acquire()
398+
try:
399+
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
400+
finally:
401+
if jedi_lock is not None:
402+
jedi_lock.release()
383403
except Exception as e:
384404
logger.debug(str(e))
385405
continue
@@ -433,17 +453,25 @@ def process_test_files(
433453
test_framework = cfg.test_framework
434454

435455
function_to_test_map = defaultdict(set)
436-
437-
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
438-
progress,
439-
task_id,
456+
jedi_lock = multiprocessing.Manager().Lock()
457+
458+
with (
459+
test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
460+
progress,
461+
task_id,
462+
),
463+
concurrent.futures.ProcessPoolExecutor() as executor,
440464
):
441-
for test_file, functions in file_to_test_map.items():
442-
_, local_results = _process_single_test_file(test_file, functions, project_root_path, test_framework)
443-
465+
futures = {
466+
executor.submit(
467+
_process_single_test_file, test_file, functions, project_root_path, test_framework, jedi_lock
468+
): (test_file, functions)
469+
for test_file, functions in file_to_test_map.items()
470+
}
471+
for future in concurrent.futures.as_completed(futures):
472+
_, local_results = future.result()
444473
for qualified_name, function_called_in_test in local_results:
445474
function_to_test_map[qualified_name].add(function_called_in_test)
446-
447475
progress.advance(task_id)
448476

449477
function_to_tests_dict = {function: list(tests) for function, tests in function_to_test_map.items()}

0 commit comments

Comments
 (0)