Skip to content

Commit 17c7886

Browse files
committed
efficient passes
1 parent 9f246ec commit 17c7886

File tree

1 file changed

+45
-40
lines changed

1 file changed

+45
-40
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -503,12 +503,20 @@ def process_test_files(
503503
script = jedi.Script(path=test_file, project=jedi_project)
504504
test_functions = set()
505505

506-
all_names = script.get_names(all_scopes=True, references=True)
507-
all_defs = script.get_names(all_scopes=True, definitions=True)
508-
all_names_top = script.get_names(all_scopes=True)
509-
510-
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
511-
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
506+
# Single call to get all names with references and definitions
507+
all_names = script.get_names(all_scopes=True, references=True, definitions=True)
508+
509+
# Filter once and create lookup dictionaries
510+
top_level_functions = {}
511+
top_level_classes = {}
512+
all_defs = []
513+
514+
for name in all_names:
515+
if name.type == "function":
516+
top_level_functions[name.name] = name
517+
all_defs.append(name)
518+
elif name.type == "class":
519+
top_level_classes[name.name] = name
512520
except Exception as e:
513521
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
514522
progress.advance(task_id)
@@ -573,24 +581,21 @@ def process_test_files(
573581
)
574582
)
575583

576-
test_functions_list = list(test_functions)
577-
test_functions_raw = [elem.function_name for elem in test_functions_list]
578-
579584
test_functions_by_name = defaultdict(list)
580-
for i, func_name in enumerate(test_functions_raw):
581-
test_functions_by_name[func_name].append(i)
585+
for func in test_functions:
586+
test_functions_by_name[func.function_name].append(func)
582587

583-
for name in all_names:
584-
if name.full_name is None:
585-
continue
586-
m = FUNCTION_NAME_REGEX.search(name.full_name)
587-
if not m:
588-
continue
588+
test_function_names_set = set(test_functions_by_name.keys())
589+
relevant_names = []
589590

590-
scope = m.group(1)
591-
if scope not in test_functions_by_name:
592-
continue
591+
names_with_full_name = [name for name in all_names if name.full_name is not None]
593592

593+
for name in names_with_full_name:
594+
match = FUNCTION_NAME_REGEX.search(name.full_name)
595+
if match and match.group(1) in test_function_names_set:
596+
relevant_names.append((name, match.group(1)))
597+
598+
for name, scope in relevant_names:
594599
try:
595600
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
596601
except Exception as e:
@@ -600,36 +605,36 @@ def process_test_files(
600605
if not definition or definition[0].type != "function":
601606
continue
602607

603-
definition_path = str(definition[0].module_path)
608+
definition_obj = definition[0]
609+
definition_path = str(definition_obj.module_path)
610+
611+
project_root_str = str(project_root_path)
604612
if (
605-
definition_path.startswith(str(project_root_path) + os.sep)
606-
and definition[0].module_name != name.module_name
607-
and definition[0].full_name is not None
613+
definition_path.startswith(project_root_str + os.sep)
614+
and definition_obj.module_name != name.module_name
615+
and definition_obj.full_name is not None
608616
):
609-
for index in test_functions_by_name[scope]:
610-
scope_test_function = test_functions_list[index].function_name
611-
scope_test_class = test_functions_list[index].test_class
612-
scope_parameters = test_functions_list[index].parameters
613-
test_type = test_functions_list[index].test_type
617+
# Pre-compute common values outside the inner loop
618+
module_prefix = definition_obj.module_name + "."
619+
full_name_without_module_prefix = definition_obj.full_name.replace(module_prefix, "", 1)
620+
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition_obj.module_path, project_root_path)}.{full_name_without_module_prefix}"
614621

615-
if scope_parameters is not None:
622+
for test_func in test_functions_by_name[scope]:
623+
if test_func.parameters is not None:
616624
if test_framework == "pytest":
617-
scope_test_function += "[" + scope_parameters + "]"
618-
if test_framework == "unittest":
619-
scope_test_function += "_" + scope_parameters
620-
621-
full_name_without_module_prefix = definition[0].full_name.replace(
622-
definition[0].module_name + ".", "", 1
623-
)
624-
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
625+
scope_test_function = f"{test_func.function_name}[{test_func.parameters}]"
626+
else: # unittest
627+
scope_test_function = f"{test_func.function_name}_{test_func.parameters}"
628+
else:
629+
scope_test_function = test_func.function_name
625630

626631
function_to_test_map[qualified_name_with_modules_from_root].add(
627632
FunctionCalledInTest(
628633
tests_in_file=TestsInFile(
629634
test_file=test_file,
630-
test_class=scope_test_class,
635+
test_class=test_func.test_class,
631636
test_function=scope_test_function,
632-
test_type=test_type,
637+
test_type=test_func.test_type,
633638
),
634639
position=CodePosition(line_no=name.line, col_no=name.column),
635640
)

0 commit comments

Comments
 (0)