Skip to content

Commit c90c6b8

Browse files
committed
use a processpool
1 parent 571cfb9 commit c90c6b8

File tree

1 file changed

+160
-145
lines changed

1 file changed

+160
-145
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 160 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import subprocess
1010
import unittest
1111
from collections import defaultdict
12+
from concurrent.futures import ProcessPoolExecutor, as_completed
1213
from pathlib import Path
1314
from typing import TYPE_CHECKING, Callable, Optional
1415

@@ -18,12 +19,7 @@
1819
from rich.text import Text
1920

2021
from codeflash.cli_cmds.console import console, logger, test_files_progress_bar
21-
from codeflash.code_utils.code_utils import (
22-
ImportErrorPattern,
23-
custom_addopts,
24-
get_run_tmp_file,
25-
module_name_from_file_path,
26-
)
22+
from codeflash.code_utils.code_utils import ImportErrorPattern, custom_addopts, get_run_tmp_file
2723
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
2824
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
2925

@@ -288,157 +284,176 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
288284
return False, function_name, None
289285

290286

291-
def process_test_files(
292-
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
293-
) -> tuple[dict[str, list[FunctionCalledInTest]], int]:
287+
def _process_single_test_file(
    test_file: Path, functions: list[TestsInFile], project_root_path: Path, test_framework: str
) -> tuple[Path, set[tuple[str, FunctionCalledInTest]]]:
    """Map the project functions referenced by one test file to the tests that reference them.

    This function is submitted to a ``ProcessPoolExecutor`` by ``process_test_files``,
    so it stays a module-level function and only receives/returns plain values
    rather than a shared jedi project or a shared result mapping.

    Args:
        test_file: Path of the test file to analyse with jedi.
        functions: Tests that were discovered in ``test_file``.
        project_root_path: Project root; only definitions located under this
            directory are reported.
        test_framework: ``"pytest"`` or ``"unittest"``. Any other value matches
            no test functions and yields an empty result.

    Returns:
        ``(test_file, results)`` where ``results`` is a set of
        ``(qualified_function_name, FunctionCalledInTest)`` pairs. The set is
        empty when jedi fails to analyse the file or nothing matches.
    """
    import jedi

    # Kept function-local (the module no longer imports this name at the top),
    # but resolved once per call instead of once per matched reference.
    from codeflash.code_utils.code_utils import module_name_from_file_path

    local_function_to_test_map: set[tuple[str, FunctionCalledInTest]] = set()

    try:
        jedi_project = jedi.Project(path=project_root_path)
        script = jedi.Script(path=test_file, project=jedi_project)

        all_names = script.get_names(all_scopes=True, references=True)
        all_defs = script.get_names(all_scopes=True, definitions=True)
        all_names_top = script.get_names(all_scopes=True)

        top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
        top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
    except Exception as e:
        # Best effort: a file jedi cannot analyse contributes no results.
        logger.debug(f"Failed to get jedi script for {test_file}: {e}")
        return test_file, local_function_to_test_map

    test_functions = set()

    if test_framework == "pytest":
        for function in functions:
            if "[" in function.test_function:
                # Parameterized id such as ``test_x[param]``: split it once
                # into the bare function name and its parameter part.
                name_parts = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)
                function_name = name_parts[0]
                parameters = name_parts[1]
                if function_name in top_level_functions:
                    test_functions.add(TestFunction(function_name, function.test_class, parameters, function.test_type))
            elif function.test_function in top_level_functions:
                test_functions.add(TestFunction(function.test_function, function.test_class, None, function.test_type))
            elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function):
                base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function)
                if base_name in top_level_functions:
                    test_functions.add(
                        TestFunction(
                            function_name=base_name,
                            test_class=function.test_class,
                            parameters=function.test_function,
                            test_type=function.test_type,
                        )
                    )

    elif test_framework == "unittest":
        functions_to_search = [elem.test_function for elem in functions]
        test_suites = {elem.test_class for elem in functions}

        # Only test classes that are actually defined in this file are considered.
        matching_names = test_suites & top_level_classes.keys()
        for matched_name in matching_names:
            for def_name in all_defs:
                if def_name.type != "function" or def_name.full_name is None:
                    continue
                if f".{matched_name}." not in def_name.full_name:
                    continue

                for function in functions_to_search:
                    (is_parameterized, new_function, parameters) = discover_parameters_unittest(function)

                    if is_parameterized and new_function == def_name.name:
                        matched_parameters = parameters
                    elif function == def_name.name:
                        matched_parameters = None
                    else:
                        continue

                    test_functions.add(
                        TestFunction(
                            function_name=def_name.name,
                            test_class=matched_name,
                            parameters=matched_parameters,
                            test_type=functions[0].test_type,
                        )
                    )

    if not test_functions:
        # No reference can match a test scope, so skip walking every name.
        return test_file, local_function_to_test_map

    # Several TestFunction entries can share a name (e.g. one per parameter set).
    test_functions_by_name = defaultdict(list)
    for test_function in test_functions:
        test_functions_by_name[test_function.function_name].append(test_function)

    project_root_prefix = str(project_root_path) + os.sep

    for name in all_names:
        reference_full_name = name.full_name
        if reference_full_name is None:
            continue
        m = FUNCTION_NAME_REGEX.search(reference_full_name)
        if not m:
            continue

        scope = m.group(1)
        if scope not in test_functions_by_name:
            continue

        try:
            definition = name.goto(follow_imports=True, follow_builtin_imports=False)
        except Exception as e:
            logger.debug(str(e))
            continue

        if not definition or definition[0].type != "function":
            continue

        target = definition[0]
        definition_path = str(target.module_path)
        if not (
            definition_path.startswith(project_root_prefix)
            and target.module_name != name.module_name
            and target.full_name is not None
        ):
            continue

        # The qualified name depends only on the resolved definition, so it is
        # built once and shared by every test function in this scope.
        full_name_without_module_prefix = target.full_name.replace(target.module_name + ".", "", 1)
        qualified_name_with_modules_from_root = (
            f"{module_name_from_file_path(target.module_path, project_root_path)}.{full_name_without_module_prefix}"
        )

        for scoped_test in test_functions_by_name[scope]:
            scope_test_function = scoped_test.function_name
            scope_parameters = scoped_test.parameters

            if scope_parameters is not None:
                if test_framework == "pytest":
                    scope_test_function += "[" + scope_parameters + "]"
                elif test_framework == "unittest":
                    scope_test_function += "_" + scope_parameters

            function_called_in_test = FunctionCalledInTest(
                tests_in_file=TestsInFile(
                    test_file=test_file,
                    test_class=scoped_test.test_class,
                    test_function=scope_test_function,
                    test_type=scoped_test.test_type,
                ),
                position=CodePosition(line_no=name.line, col_no=name.column),
            )
            local_function_to_test_map.add((qualified_name_with_modules_from_root, function_called_in_test))

    return test_file, local_function_to_test_map
428+
429+
430+
def process_test_files(
431+
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
432+
) -> tuple[dict[str, list[FunctionCalledInTest]], int]:
433+
project_root_path = cfg.project_root_path
434+
test_framework = cfg.test_framework
435+
436+
function_to_test_map = defaultdict(set)
437+
438+
with (
439+
test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
440+
progress,
441+
task_id,
442+
),
443+
ProcessPoolExecutor() as executor,
444+
):
445+
future_to_file = {
446+
executor.submit(
447+
_process_single_test_file, test_file, functions, project_root_path, test_framework
448+
): test_file
449+
for test_file, functions in file_to_test_map.items()
450+
}
451+
452+
for future in as_completed(future_to_file):
453+
_, local_results = future.result()
454+
455+
for qualified_name, function_called_in_test in local_results:
456+
function_to_test_map[qualified_name].add(function_called_in_test)
442457

443458
progress.advance(task_id)
444459

0 commit comments

Comments
 (0)