Skip to content

Commit a81bb64

Browse files
committed
extract into standalone func
1 parent 4de9323 commit a81bb64

File tree

1 file changed

+169
-164
lines changed

1 file changed

+169
-164
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 169 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pathlib import Path
1313
from typing import TYPE_CHECKING, Callable, Optional
1414

15+
import jedi
1516
import pytest
1617
from pydantic.dataclasses import dataclass
1718
from rich.panel import Panel
@@ -288,191 +289,195 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
288289
return False, function_name, None
289290

290291

291-
def process_test_files(
292-
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
293-
) -> dict[str, list[FunctionCalledInTest]]:
294-
import jedi
295-
292+
def _process_single_test_file(
293+
test_file: Path,
294+
functions: list[TestsInFile],
295+
cfg: TestConfig,
296+
jedi_project: jedi.Project,
297+
goto_cache: dict,
298+
tests_cache: TestsCache,
299+
function_to_test_map: defaultdict,
300+
) -> None:
296301
project_root_path = cfg.project_root_path
297302
test_framework = cfg.test_framework
303+
file_hash = TestsCache.compute_file_hash(test_file)
304+
cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash)
305+
if cached_tests:
306+
self_cur = tests_cache.cur
307+
self_cur.execute(
308+
"SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?",
309+
(str(test_file), file_hash),
310+
)
311+
qualified_names = [row[0] for row in self_cur.fetchall()]
312+
for cached, qualified_name in zip(cached_tests, qualified_names):
313+
function_to_test_map[qualified_name].add(cached)
314+
return
298315

299-
function_to_test_map = defaultdict(set)
300-
jedi_project = jedi.Project(path=project_root_path)
301-
goto_cache = {}
302-
tests_cache = TestsCache()
316+
try:
317+
script = jedi.Script(path=test_file, project=jedi_project)
318+
test_functions = set()
303319

304-
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
305-
progress,
306-
task_id,
307-
):
308-
for test_file, functions in file_to_test_map.items():
309-
file_hash = TestsCache.compute_file_hash(test_file)
310-
cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash)
311-
if cached_tests:
312-
self_cur = tests_cache.cur
313-
self_cur.execute(
314-
"SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?",
315-
(str(test_file), file_hash),
316-
)
317-
qualified_names = [row[0] for row in self_cur.fetchall()]
318-
for cached, qualified_name in zip(cached_tests, qualified_names):
319-
function_to_test_map[qualified_name].add(cached)
320-
progress.advance(task_id)
321-
continue
320+
all_names = script.get_names(all_scopes=True, references=True)
321+
all_defs = script.get_names(all_scopes=True, definitions=True)
322+
all_names_top = script.get_names(all_scopes=True)
322323

323-
try:
324-
script = jedi.Script(path=test_file, project=jedi_project)
325-
test_functions = set()
324+
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
325+
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
326+
except Exception as e:
327+
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
328+
return
329+
330+
if test_framework == "pytest":
331+
for function in functions:
332+
if "[" in function.test_function:
333+
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0]
334+
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1]
335+
if function_name in top_level_functions:
336+
test_functions.add(TestFunction(function_name, function.test_class, parameters, function.test_type))
337+
elif function.test_function in top_level_functions:
338+
test_functions.add(TestFunction(function.test_function, function.test_class, None, function.test_type))
339+
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function):
340+
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function)
341+
if base_name in top_level_functions:
342+
test_functions.add(
343+
TestFunction(
344+
function_name=base_name,
345+
test_class=function.test_class,
346+
parameters=function.test_function,
347+
test_type=function.test_type,
348+
)
349+
)
326350

327-
all_names = script.get_names(all_scopes=True, references=True)
328-
all_defs = script.get_names(all_scopes=True, definitions=True)
329-
all_names_top = script.get_names(all_scopes=True)
351+
elif test_framework == "unittest":
352+
functions_to_search = [elem.test_function for elem in functions]
353+
test_suites = {elem.test_class for elem in functions}
330354

331-
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
332-
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
333-
except Exception as e:
334-
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
335-
progress.advance(task_id)
336-
continue
355+
matching_names = test_suites & top_level_classes.keys()
356+
for matched_name in matching_names:
357+
for def_name in all_defs:
358+
if (
359+
def_name.type == "function"
360+
and def_name.full_name is not None
361+
and f".{matched_name}." in def_name.full_name
362+
):
363+
for function in functions_to_search:
364+
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
337365

338-
if test_framework == "pytest":
339-
for function in functions:
340-
if "[" in function.test_function:
341-
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0]
342-
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1]
343-
if function_name in top_level_functions:
366+
if is_parameterized and new_function == def_name.name:
344367
test_functions.add(
345-
TestFunction(function_name, function.test_class, parameters, function.test_type)
368+
TestFunction(
369+
function_name=def_name.name,
370+
test_class=matched_name,
371+
parameters=parameters,
372+
test_type=functions[0].test_type,
373+
)
346374
)
347-
elif function.test_function in top_level_functions:
348-
test_functions.add(
349-
TestFunction(function.test_function, function.test_class, None, function.test_type)
350-
)
351-
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function):
352-
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function)
353-
if base_name in top_level_functions:
375+
elif function == def_name.name:
354376
test_functions.add(
355377
TestFunction(
356-
function_name=base_name,
357-
test_class=function.test_class,
358-
parameters=function.test_function,
359-
test_type=function.test_type,
378+
function_name=def_name.name,
379+
test_class=matched_name,
380+
parameters=None,
381+
test_type=functions[0].test_type,
360382
)
361383
)
362384

363-
elif test_framework == "unittest":
364-
functions_to_search = [elem.test_function for elem in functions]
365-
test_suites = {elem.test_class for elem in functions}
366-
367-
matching_names = test_suites & top_level_classes.keys()
368-
for matched_name in matching_names:
369-
for def_name in all_defs:
370-
if (
371-
def_name.type == "function"
372-
and def_name.full_name is not None
373-
and f".{matched_name}." in def_name.full_name
374-
):
375-
for function in functions_to_search:
376-
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
377-
378-
if is_parameterized and new_function == def_name.name:
379-
test_functions.add(
380-
TestFunction(
381-
function_name=def_name.name,
382-
test_class=matched_name,
383-
parameters=parameters,
384-
test_type=functions[0].test_type,
385-
)
386-
)
387-
elif function == def_name.name:
388-
test_functions.add(
389-
TestFunction(
390-
function_name=def_name.name,
391-
test_class=matched_name,
392-
parameters=None,
393-
test_type=functions[0].test_type,
394-
)
395-
)
396-
397-
test_functions_list = list(test_functions)
398-
test_functions_raw = [elem.function_name for elem in test_functions_list]
399-
400-
test_functions_by_name = defaultdict(list)
401-
for i, func_name in enumerate(test_functions_raw):
402-
test_functions_by_name[func_name].append(i)
403-
404-
for name in all_names:
405-
if name.full_name is None:
406-
continue
407-
m = FUNCTION_NAME_REGEX.search(name.full_name)
408-
if not m:
409-
continue
410-
411-
scope = m.group(1)
412-
if scope not in test_functions_by_name:
413-
continue
414-
415-
cache_key = (name.full_name, name.module_name)
416-
try:
417-
if cache_key in goto_cache:
418-
definition = goto_cache[cache_key]
419-
else:
420-
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
421-
goto_cache[cache_key] = definition
422-
except Exception as e:
423-
logger.debug(str(e))
424-
continue
425-
426-
if not definition or definition[0].type != "function":
427-
continue
428-
429-
definition_path = str(definition[0].module_path)
430-
if (
431-
definition_path.startswith(str(project_root_path) + os.sep)
432-
and definition[0].module_name != name.module_name
433-
and definition[0].full_name is not None
434-
):
435-
for index in test_functions_by_name[scope]:
436-
scope_test_function = test_functions_list[index].function_name
437-
scope_test_class = test_functions_list[index].test_class
438-
scope_parameters = test_functions_list[index].parameters
439-
test_type = test_functions_list[index].test_type
440-
441-
if scope_parameters is not None:
442-
if test_framework == "pytest":
443-
scope_test_function += "[" + scope_parameters + "]"
444-
if test_framework == "unittest":
445-
scope_test_function += "_" + scope_parameters
446-
447-
full_name_without_module_prefix = definition[0].full_name.replace(
448-
definition[0].module_name + ".", "", 1
449-
)
450-
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
385+
test_functions_list = list(test_functions)
386+
test_functions_raw = [elem.function_name for elem in test_functions_list]
387+
388+
test_functions_by_name = defaultdict(list)
389+
for i, func_name in enumerate(test_functions_raw):
390+
test_functions_by_name[func_name].append(i)
391+
392+
for name in all_names:
393+
if name.full_name is None:
394+
continue
395+
m = FUNCTION_NAME_REGEX.search(name.full_name)
396+
if not m:
397+
continue
398+
399+
scope = m.group(1)
400+
if scope not in test_functions_by_name:
401+
continue
402+
403+
cache_key = (name.full_name, name.module_name)
404+
try:
405+
if cache_key in goto_cache:
406+
definition = goto_cache[cache_key]
407+
else:
408+
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
409+
goto_cache[cache_key] = definition
410+
except Exception as e:
411+
logger.debug(str(e))
412+
continue
413+
414+
if not definition or definition[0].type != "function":
415+
continue
451416

452-
tests_cache.insert_test(
453-
file_path=str(test_file),
454-
file_hash=file_hash,
455-
qualified_name_with_modules_from_root=qualified_name_with_modules_from_root,
456-
function_name=scope,
417+
definition_path = str(definition[0].module_path)
418+
if (
419+
definition_path.startswith(str(project_root_path) + os.sep)
420+
and definition[0].module_name != name.module_name
421+
and definition[0].full_name is not None
422+
):
423+
for index in test_functions_by_name[scope]:
424+
scope_test_function = test_functions_list[index].function_name
425+
scope_test_class = test_functions_list[index].test_class
426+
scope_parameters = test_functions_list[index].parameters
427+
test_type = test_functions_list[index].test_type
428+
429+
if scope_parameters is not None:
430+
if test_framework == "pytest":
431+
scope_test_function += "[" + scope_parameters + "]"
432+
if test_framework == "unittest":
433+
scope_test_function += "_" + scope_parameters
434+
435+
full_name_without_module_prefix = definition[0].full_name.replace(
436+
definition[0].module_name + ".", "", 1
437+
)
438+
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
439+
440+
tests_cache.insert_test(
441+
file_path=str(test_file),
442+
file_hash=file_hash,
443+
qualified_name_with_modules_from_root=qualified_name_with_modules_from_root,
444+
function_name=scope,
445+
test_class=scope_test_class,
446+
test_function=scope_test_function,
447+
test_type=test_type,
448+
line_number=name.line,
449+
col_number=name.column,
450+
)
451+
452+
function_to_test_map[qualified_name_with_modules_from_root].add(
453+
FunctionCalledInTest(
454+
tests_in_file=TestsInFile(
455+
test_file=test_file,
457456
test_class=scope_test_class,
458457
test_function=scope_test_function,
459458
test_type=test_type,
460-
line_number=name.line,
461-
col_number=name.column,
462-
)
459+
),
460+
position=CodePosition(line_no=name.line, col_no=name.column),
461+
)
462+
)
463463

464-
function_to_test_map[qualified_name_with_modules_from_root].add(
465-
FunctionCalledInTest(
466-
tests_in_file=TestsInFile(
467-
test_file=test_file,
468-
test_class=scope_test_class,
469-
test_function=scope_test_function,
470-
test_type=test_type,
471-
),
472-
position=CodePosition(line_no=name.line, col_no=name.column),
473-
)
474-
)
475464

465+
def process_test_files(
466+
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
467+
) -> dict[str, list[FunctionCalledInTest]]:
468+
function_to_test_map = defaultdict(set)
469+
jedi_project = jedi.Project(path=cfg.project_root_path)
470+
goto_cache = {}
471+
tests_cache = TestsCache()
472+
473+
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
474+
progress,
475+
task_id,
476+
):
477+
for test_file, functions in file_to_test_map.items():
478+
_process_single_test_file(
479+
test_file, functions, cfg, jedi_project, goto_cache, tests_cache, function_to_test_map
480+
)
476481
progress.advance(task_id)
477482

478483
tests_cache.close()

0 commit comments

Comments
 (0)