|
1 | 1 | # ruff: noqa: SLF001 |
2 | 2 | from __future__ import annotations |
3 | 3 |
|
| 4 | +import concurrent.futures |
4 | 5 | import hashlib |
| 6 | +import multiprocessing |
5 | 7 | import os |
6 | 8 | import pickle |
7 | 9 | import re |
|
23 | 25 | from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType |
24 | 26 |
|
25 | 27 | if TYPE_CHECKING: |
| 28 | + from multiprocessing.synchronize import Lock |
| 29 | + |
26 | 30 | from codeflash.verification.verification_utils import TestConfig |
27 | 31 |
|
28 | 32 |
|
@@ -284,23 +288,33 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N |
284 | 288 |
|
285 | 289 |
|
286 | 290 | def _process_single_test_file( |
287 | | - test_file: Path, functions: list[TestsInFile], project_root_path: Path, test_framework: str |
| 291 | + test_file: Path, |
| 292 | + functions: list[TestsInFile], |
| 293 | + project_root_path: Path, |
| 294 | + test_framework: str, |
| 295 | + jedi_lock: Optional[Lock] = None, |
288 | 296 | ) -> tuple[Path, set]: |
289 | 297 | import jedi |
290 | 298 |
|
291 | 299 | local_function_to_test_map = set() |
292 | 300 |
|
293 | 301 | try: |
294 | | - jedi_project = jedi.Project(path=project_root_path) |
295 | | - script = jedi.Script(path=test_file, project=jedi_project) |
296 | | - test_functions = set() |
297 | | - |
298 | | - all_names = script.get_names(all_scopes=True, references=True) |
299 | | - all_defs = script.get_names(all_scopes=True, definitions=True) |
300 | | - all_names_top = script.get_names(all_scopes=True) |
301 | | - |
302 | | - top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} |
303 | | - top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} |
| 302 | + if jedi_lock is not None: |
| 303 | + jedi_lock.acquire() |
| 304 | + try: |
| 305 | + jedi_project = jedi.Project(path=project_root_path) |
| 306 | + script = jedi.Script(path=test_file, project=jedi_project) |
| 307 | + test_functions = set() |
| 308 | + |
| 309 | + all_names = script.get_names(all_scopes=True, references=True) |
| 310 | + all_defs = script.get_names(all_scopes=True, definitions=True) |
| 311 | + all_names_top = script.get_names(all_scopes=True) |
| 312 | + |
| 313 | + top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} |
| 314 | + top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} |
| 315 | + finally: |
| 316 | + if jedi_lock is not None: |
| 317 | + jedi_lock.release() |
304 | 318 | except Exception as e: |
305 | 319 | logger.debug(f"Failed to get jedi script for {test_file}: {e}") |
306 | 320 | return test_file, local_function_to_test_map |
@@ -379,7 +393,13 @@ def _process_single_test_file( |
379 | 393 | continue |
380 | 394 |
|
381 | 395 | try: |
382 | | - definition = name.goto(follow_imports=True, follow_builtin_imports=False) |
| 396 | + if jedi_lock is not None: |
| 397 | + jedi_lock.acquire() |
| 398 | + try: |
| 399 | + definition = name.goto(follow_imports=True, follow_builtin_imports=False) |
| 400 | + finally: |
| 401 | + if jedi_lock is not None: |
| 402 | + jedi_lock.release() |
383 | 403 | except Exception as e: |
384 | 404 | logger.debug(str(e)) |
385 | 405 | continue |
@@ -433,17 +453,25 @@ def process_test_files( |
433 | 453 | test_framework = cfg.test_framework |
434 | 454 |
|
435 | 455 | function_to_test_map = defaultdict(set) |
436 | | - |
437 | | - with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( |
438 | | - progress, |
439 | | - task_id, |
| 456 | + jedi_lock = multiprocessing.Manager().Lock() |
| 457 | + |
| 458 | + with ( |
| 459 | + test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( |
| 460 | + progress, |
| 461 | + task_id, |
| 462 | + ), |
| 463 | + concurrent.futures.ProcessPoolExecutor() as executor, |
440 | 464 | ): |
441 | | - for test_file, functions in file_to_test_map.items(): |
442 | | - _, local_results = _process_single_test_file(test_file, functions, project_root_path, test_framework) |
443 | | - |
| 465 | + futures = { |
| 466 | + executor.submit( |
| 467 | + _process_single_test_file, test_file, functions, project_root_path, test_framework, jedi_lock |
| 468 | + ): (test_file, functions) |
| 469 | + for test_file, functions in file_to_test_map.items() |
| 470 | + } |
| 471 | + for future in concurrent.futures.as_completed(futures): |
| 472 | + _, local_results = future.result() |
444 | 473 | for qualified_name, function_called_in_test in local_results: |
445 | 474 | function_to_test_map[qualified_name].add(function_called_in_test) |
446 | | - |
447 | 475 | progress.advance(task_id) |
448 | 476 |
|
449 | 477 | function_to_tests_dict = {function: list(tests) for function, tests in function_to_test_map.items()} |
|
0 commit comments