Skip to content

Commit 824ce4c

Browse files
committed
remove cache
1 parent c6139c3 commit 824ce4c

File tree

1 file changed

+171
-169
lines changed

1 file changed

+171
-169
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 171 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import subprocess
1010
import unittest
1111
from collections import defaultdict
12+
from concurrent.futures import ProcessPoolExecutor, as_completed
1213
from pathlib import Path
1314
from typing import TYPE_CHECKING, Callable, Optional
1415

@@ -276,192 +277,193 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
276277
return False, function_name, None
277278

278279

279-
def process_test_files(
280-
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
281-
) -> dict[str, list[FunctionCalledInTest]]:
280+
def _process_single_test_file(
    test_file: Path, functions: list[TestsInFile], project_root_path: Path, test_framework: str
) -> tuple[str, list[tuple[str, FunctionCalledInTest]]]:
    """Find the project functions referenced from the tests of one test file.

    The file is analysed with jedi. The discovered tests in ``functions`` are
    first matched against the function/class names jedi reports for the file;
    every reference located inside one of the matched test functions is then
    resolved to its definition. A reference is recorded only when the
    definition is a function, lives under ``project_root_path``, belongs to a
    different module than the reference, and has a full name.

    Args:
        test_file: Path of the test file to analyse.
        functions: Tests discovered in ``test_file``.
        project_root_path: Root of the project; used for the jedi project and
            to filter definitions that live outside the project.
        test_framework: ``"pytest"`` or ``"unittest"``; any other value matches
            no tests and yields an empty result.

    Returns:
        ``(str(test_file), results)`` where ``results`` is a list of
        ``(qualified_name_with_modules_from_root, FunctionCalledInTest)``
        pairs. The list is empty when jedi fails to load the file.
    """
    import jedi

    results: list[tuple[str, FunctionCalledInTest]] = []
    jedi_project = jedi.Project(path=project_root_path)

    try:
        script = jedi.Script(path=test_file, project=jedi_project)

        all_names = script.get_names(all_scopes=True, references=True)
        all_defs = script.get_names(all_scopes=True, definitions=True)
        all_names_top = script.get_names(all_scopes=True)

        top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
        top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
    except Exception as e:
        # Best effort: a file jedi cannot analyse contributes no results.
        logger.debug(f"Failed to get jedi script for {test_file}: {e}")
        return str(test_file), results

    test_functions = set()

    if test_framework == "pytest":
        for function in functions:
            test_name = function.test_function
            if "[" in test_name:
                # Parameterized test id: split once into base name and parameters.
                name_parts = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(test_name)
                function_name, parameters = name_parts[0], name_parts[1]
                if function_name in top_level_functions:
                    test_functions.add(
                        TestFunction(function_name, function.test_class, parameters, function.test_type)
                    )
            elif test_name in top_level_functions:
                test_functions.add(TestFunction(test_name, function.test_class, None, function.test_type))
            elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(test_name):
                base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", test_name)
                if base_name in top_level_functions:
                    test_functions.add(
                        TestFunction(
                            function_name=base_name,
                            test_class=function.test_class,
                            parameters=test_name,
                            test_type=function.test_type,
                        )
                    )

    elif test_framework == "unittest":
        functions_to_search = [elem.test_function for elem in functions]
        test_suites = {elem.test_class for elem in functions}

        # Only test classes that jedi actually found in the file are considered.
        for matched_name in test_suites & top_level_classes.keys():
            class_marker = f".{matched_name}."
            for def_name in all_defs:
                if def_name.type != "function" or def_name.full_name is None:
                    continue
                if class_marker not in def_name.full_name:
                    continue

                for function in functions_to_search:
                    is_parameterized, new_function, parameters = discover_parameters_unittest(function)

                    if is_parameterized and new_function == def_name.name:
                        matched_parameters = parameters
                    elif function == def_name.name:
                        matched_parameters = None
                    else:
                        continue

                    test_functions.add(
                        TestFunction(
                            function_name=def_name.name,
                            test_class=matched_name,
                            parameters=matched_parameters,
                            test_type=functions[0].test_type,
                        )
                    )

    # Group matched tests by function name so each reference's enclosing scope
    # can be looked up directly.
    test_functions_by_name = defaultdict(list)
    for test_function in test_functions:
        test_functions_by_name[test_function.function_name].append(test_function)

    project_root_prefix = str(project_root_path) + os.sep
    goto_cache = {}

    for name in all_names:
        if name.full_name is None:
            continue
        m = FUNCTION_NAME_REGEX.search(name.full_name)
        if not m:
            continue

        scope = m.group(1)
        if scope not in test_functions_by_name:
            continue

        cache_key = (name.full_name, name.module_name)
        try:
            if cache_key in goto_cache:
                definition = goto_cache[cache_key]
            else:
                definition = name.goto(follow_imports=True, follow_builtin_imports=False)
                goto_cache[cache_key] = definition
        except Exception as e:
            logger.debug(str(e))
            continue

        if not definition or definition[0].type != "function":
            continue

        target = definition[0]
        if not (
            str(target.module_path).startswith(project_root_prefix)
            and target.module_name != name.module_name
            and target.full_name is not None
        ):
            continue

        # The qualified name depends only on the resolved definition, so it is
        # built once per reference rather than once per matching test.
        full_name_without_module_prefix = target.full_name.replace(target.module_name + ".", "", 1)
        qualified_name_with_modules_from_root = (
            f"{module_name_from_file_path(target.module_path, project_root_path)}.{full_name_without_module_prefix}"
        )

        for test_function in test_functions_by_name[scope]:
            scope_test_function = test_function.function_name
            scope_parameters = test_function.parameters

            if scope_parameters is not None:
                # Re-attach the parameter suffix in the framework's naming style.
                if test_framework == "pytest":
                    scope_test_function += "[" + scope_parameters + "]"
                if test_framework == "unittest":
                    scope_test_function += "_" + scope_parameters

            function_called_in_test = FunctionCalledInTest(
                tests_in_file=TestsInFile(
                    test_file=test_file,
                    test_class=test_function.test_class,
                    test_function=scope_test_function,
                    test_type=test_function.test_type,
                ),
                position=CodePosition(line_no=name.line, col_no=name.column),
            )
            results.append((qualified_name_with_modules_from_root, function_called_in_test))

    return str(test_file), results
425+
426+
427+
def process_test_files(
    file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
) -> dict[str, list[FunctionCalledInTest]]:
    """Build a map from qualified function names to the tests that call them.

    Every test file in ``file_to_test_map`` is analysed by
    ``_process_single_test_file``. When only one worker would be used (at most
    one file, or a single CPU) the files are processed in this process;
    otherwise they are fanned out to a ``ProcessPoolExecutor``. In both modes
    a failure while processing one file is logged and that file is skipped,
    and the progress bar advances exactly once per file.

    Args:
        file_to_test_map: Tests discovered so far, keyed by test file path.
        cfg: Test configuration; ``project_root_path`` and ``test_framework``
            are read from it.

    Returns:
        A dict mapping each qualified function name to the list of
        ``FunctionCalledInTest`` entries that reference it.
    """
    import multiprocessing

    project_root_path = cfg.project_root_path
    test_framework = cfg.test_framework
    function_to_test_map = defaultdict(set)

    # Never more workers than files, and always at least one (empty input).
    max_workers = max(1, min(len(file_to_test_map), multiprocessing.cpu_count()))

    def merge(file_results: list[tuple[str, FunctionCalledInTest]]) -> None:
        # Fold one file's (qualified_name, call) pairs into the shared map.
        for qualified_name, function_called in file_results:
            function_to_test_map[qualified_name].add(function_called)

    with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
        progress,
        task_id,
    ):
        if max_workers == 1:
            # A single file or a single CPU: skip the process-pool overhead.
            for test_file, functions in file_to_test_map.items():
                try:
                    _, results = _process_single_test_file(test_file, functions, project_root_path, test_framework)
                    merge(results)
                except Exception as e:  # noqa: PERF203
                    # Same policy as the parallel path: log and keep going.
                    logger.error(f"Error processing test file {test_file}: {e}")
                finally:
                    progress.advance(task_id)
        else:
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                future_to_file = {
                    executor.submit(
                        _process_single_test_file, test_file, functions, project_root_path, test_framework
                    ): test_file
                    for test_file, functions in file_to_test_map.items()
                }

                for future in as_completed(future_to_file):
                    try:
                        _, results = future.result()
                        merge(results)
                    except Exception as e:  # noqa: PERF203
                        test_file = future_to_file[future]
                        logger.error(f"Error processing test file {test_file}: {e}")
                    finally:
                        progress.advance(task_id)

    return {function: list(tests) for function, tests in function_to_test_map.items()}

0 commit comments

Comments
 (0)