Skip to content

Commit 0000c29

Browse files
author
Archan Das
committed
revert to single-threaded if <25 files
1 parent c6d9d71 commit 0000c29

File tree

1 file changed

+143
-0
lines changed

1 file changed

+143
-0
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ def discover_tests_pytest(
109109
continue
110110
file_to_test_map[test_obj.test_file].append(test_obj)
111111
# Within these test files, find the project functions they are referring to and return their names/locations
112+
if len(file_to_test_map) < 25: #default to single-threaded if there aren't that many files
113+
return process_test_files_single_threaded(file_to_test_map, cfg)
112114
return process_test_files(file_to_test_map, cfg)
113115

114116

@@ -166,6 +168,9 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
166168
details = get_test_details(test)
167169
if details is not None:
168170
file_to_test_map[str(details.test_file)].append(details)
171+
172+
if len(file_to_test_map) < 25: #default to single-threaded if there aren't that many files
173+
return process_test_files_single_threaded(file_to_test_map, cfg)
169174
return process_test_files(file_to_test_map, cfg)
170175

171176

@@ -357,6 +362,144 @@ def process_file_worker(args_tuple):
357362
'results': {}
358363
}
359364

365+
def process_test_files_single_threaded(
366+
file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
367+
) -> dict[str, list[FunctionCalledInTest]]:
368+
project_root_path = cfg.project_root_path
369+
test_framework = cfg.test_framework
370+
function_to_test_map = defaultdict(list)
371+
jedi_project = jedi.Project(path=project_root_path)
372+
373+
for test_file, functions in file_to_test_map.items():
374+
try:
375+
script = jedi.Script(path=test_file, project=jedi_project)
376+
test_functions = set()
377+
378+
all_names = script.get_names(all_scopes=True, references=True)
379+
all_defs = script.get_names(all_scopes=True, definitions=True)
380+
all_names_top = script.get_names(all_scopes=True)
381+
382+
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
383+
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
384+
except Exception as e:
385+
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
386+
continue
387+
388+
if test_framework == "pytest":
389+
for function in functions:
390+
if "[" in function.test_function:
391+
function_name = re.split(r"[\[\]]", function.test_function)[0]
392+
parameters = re.split(r"[\[\]]", function.test_function)[1]
393+
if function_name in top_level_functions:
394+
test_functions.add(
395+
TestFunction(function_name, function.test_class, parameters, function.test_type)
396+
)
397+
elif function.test_function in top_level_functions:
398+
test_functions.add(
399+
TestFunction(function.test_function, function.test_class, None, function.test_type)
400+
)
401+
elif re.match(r"^test_\w+_\d+(?:_\w+)*", function.test_function):
402+
# Try to match parameterized unittest functions here, although we can't get the parameters.
403+
# Extract base name by removing the numbered suffix and any additional descriptions
404+
base_name = re.sub(r"_\d+(?:_\w+)*$", "", function.test_function)
405+
if base_name in top_level_functions:
406+
test_functions.add(
407+
TestFunction(
408+
function_name=base_name,
409+
test_class=function.test_class,
410+
parameters=function.test_function,
411+
test_type=function.test_type,
412+
)
413+
)
414+
415+
elif test_framework == "unittest":
416+
functions_to_search = [elem.test_function for elem in functions]
417+
test_suites = [elem.test_class for elem in functions]
418+
419+
matching_names = test_suites & top_level_classes.keys()
420+
for matched_name in matching_names:
421+
for def_name in all_defs:
422+
if (
423+
def_name.type == "function"
424+
and def_name.full_name is not None
425+
and f".{matched_name}." in def_name.full_name
426+
):
427+
for function in functions_to_search:
428+
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
429+
430+
if is_parameterized and new_function == def_name.name:
431+
test_functions.add(
432+
TestFunction(
433+
function_name=def_name.name,
434+
test_class=matched_name,
435+
parameters=parameters,
436+
test_type=functions[0].test_type,
437+
) # A test file must not have more than one test type
438+
)
439+
elif function == def_name.name:
440+
test_functions.add(
441+
TestFunction(
442+
function_name=def_name.name,
443+
test_class=matched_name,
444+
parameters=None,
445+
test_type=functions[0].test_type,
446+
)
447+
)
448+
449+
test_functions_list = list(test_functions)
450+
test_functions_raw = [elem.function_name for elem in test_functions_list]
451+
452+
for name in all_names:
453+
if name.full_name is None:
454+
continue
455+
m = re.search(r"([^.]+)\." + f"{name.name}$", name.full_name)
456+
if not m:
457+
continue
458+
scope = m.group(1)
459+
indices = [i for i, x in enumerate(test_functions_raw) if x == scope]
460+
for index in indices:
461+
scope_test_function = test_functions_list[index].function_name
462+
scope_test_class = test_functions_list[index].test_class
463+
scope_parameters = test_functions_list[index].parameters
464+
test_type = test_functions_list[index].test_type
465+
try:
466+
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
467+
except Exception as e:
468+
logger.debug(str(e))
469+
continue
470+
if definition and definition[0].type == "function":
471+
definition_path = str(definition[0].module_path)
472+
# The definition is part of this project and not defined within the original function
473+
if (
474+
definition_path.startswith(str(project_root_path) + os.sep)
475+
and definition[0].module_name != name.module_name
476+
and definition[0].full_name is not None
477+
):
478+
if scope_parameters is not None:
479+
if test_framework == "pytest":
480+
scope_test_function += "[" + scope_parameters + "]"
481+
if test_framework == "unittest":
482+
scope_test_function += "_" + scope_parameters
483+
full_name_without_module_prefix = definition[0].full_name.replace(
484+
definition[0].module_name + ".", "", 1
485+
)
486+
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
487+
function_to_test_map[qualified_name_with_modules_from_root].append(
488+
FunctionCalledInTest(
489+
tests_in_file=TestsInFile(
490+
test_file=test_file,
491+
test_class=scope_test_class,
492+
test_function=scope_test_function,
493+
test_type=test_type,
494+
),
495+
position=CodePosition(line_no=name.line, col_no=name.column),
496+
)
497+
)
498+
deduped_function_to_test_map = {}
499+
for function, tests in function_to_test_map.items():
500+
deduped_function_to_test_map[function] = list(set(tests))
501+
return deduped_function_to_test_map
502+
360503

361504
def process_test_files(
362505
file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig

0 commit comments

Comments
 (0)