Skip to content

Commit b2ab78c

Browse files
committed
efficient passes
1 parent 9f246ec commit b2ab78c

File tree

1 file changed

+51
-42
lines changed

1 file changed

+51
-42
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 51 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -503,12 +503,20 @@ def process_test_files(
503503
script = jedi.Script(path=test_file, project=jedi_project)
504504
test_functions = set()
505505

506-
all_names = script.get_names(all_scopes=True, references=True)
507-
all_defs = script.get_names(all_scopes=True, definitions=True)
508-
all_names_top = script.get_names(all_scopes=True)
509-
510-
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
511-
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
506+
# Single call to get all names with references and definitions
507+
all_names = script.get_names(all_scopes=True, references=True, definitions=True)
508+
509+
# Filter once and create lookup dictionaries
510+
top_level_functions = {}
511+
top_level_classes = {}
512+
all_defs = []
513+
514+
for name in all_names:
515+
if name.type == "function":
516+
top_level_functions[name.name] = name
517+
all_defs.append(name)
518+
elif name.type == "class":
519+
top_level_classes[name.name] = name
512520
except Exception as e:
513521
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
514522
progress.advance(task_id)
@@ -574,23 +582,24 @@ def process_test_files(
574582
)
575583

576584
test_functions_list = list(test_functions)
577-
test_functions_raw = [elem.function_name for elem in test_functions_list]
578-
579-
test_functions_by_name = defaultdict(list)
580-
for i, func_name in enumerate(test_functions_raw):
581-
test_functions_by_name[func_name].append(i)
582-
583-
for name in all_names:
584-
if name.full_name is None:
585-
continue
586-
m = FUNCTION_NAME_REGEX.search(name.full_name)
587-
if not m:
588-
continue
589-
590-
scope = m.group(1)
591-
if scope not in test_functions_by_name:
592-
continue
593-
585+
test_functions_by_name = {}
586+
for i, func in enumerate(test_functions_list):
587+
if func.function_name not in test_functions_by_name:
588+
test_functions_by_name[func.function_name] = []
589+
test_functions_by_name[func.function_name].append((i, func))
590+
591+
relevant_names = [
592+
(name, m.group(1))
593+
for name in all_names
594+
if (
595+
name.full_name is not None
596+
and (m := FUNCTION_NAME_REGEX.search(name.full_name))
597+
and m.group(1) in test_functions_by_name
598+
)
599+
]
600+
601+
# Process relevant names in batch
602+
for name, scope in relevant_names:
594603
try:
595604
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
596605
except Exception as e:
@@ -600,36 +609,36 @@ def process_test_files(
600609
if not definition or definition[0].type != "function":
601610
continue
602611

603-
definition_path = str(definition[0].module_path)
612+
definition_obj = definition[0]
613+
definition_path = str(definition_obj.module_path)
614+
604615
if (
605616
definition_path.startswith(str(project_root_path) + os.sep)
606-
and definition[0].module_name != name.module_name
607-
and definition[0].full_name is not None
617+
and definition_obj.module_name != name.module_name
618+
and definition_obj.full_name is not None
608619
):
609-
for index in test_functions_by_name[scope]:
610-
scope_test_function = test_functions_list[index].function_name
611-
scope_test_class = test_functions_list[index].test_class
612-
scope_parameters = test_functions_list[index].parameters
613-
test_type = test_functions_list[index].test_type
620+
# Pre-compute common values outside the inner loop
621+
full_name_without_module_prefix = definition_obj.full_name.replace(
622+
definition_obj.module_name + ".", "", 1
623+
)
624+
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition_obj.module_path, project_root_path)}.{full_name_without_module_prefix}"
614625

615-
if scope_parameters is not None:
616-
if test_framework == "pytest":
617-
scope_test_function += "[" + scope_parameters + "]"
618-
if test_framework == "unittest":
619-
scope_test_function += "_" + scope_parameters
626+
for _index, test_func in test_functions_by_name[scope]:
627+
scope_test_function = test_func.function_name
620628

621-
full_name_without_module_prefix = definition[0].full_name.replace(
622-
definition[0].module_name + ".", "", 1
623-
)
624-
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
629+
if test_func.parameters is not None:
630+
if test_framework == "pytest":
631+
scope_test_function += "[" + test_func.parameters + "]"
632+
elif test_framework == "unittest":
633+
scope_test_function += "_" + test_func.parameters
625634

626635
function_to_test_map[qualified_name_with_modules_from_root].add(
627636
FunctionCalledInTest(
628637
tests_in_file=TestsInFile(
629638
test_file=test_file,
630-
test_class=scope_test_class,
639+
test_class=test_func.test_class,
631640
test_function=scope_test_function,
632-
test_type=test_type,
641+
test_type=test_func.test_type,
633642
),
634643
position=CodePosition(line_no=name.line, col_no=name.column),
635644
)

0 commit comments

Comments
 (0)