Skip to content

Commit a6a2440

Browse files
committed
Revert "paralelize test discovery"
This reverts commit 9321611.
1 parent 9321611 commit a6a2440

File tree

1 file changed

+169
-173
lines changed

1 file changed

+169
-173
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 169 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@
99
import subprocess
1010
import unittest
1111
from collections import defaultdict
12-
from concurrent.futures import ProcessPoolExecutor
1312
from pathlib import Path
1413
from typing import TYPE_CHECKING, Callable, Optional
1514

16-
import jedi
1715
import pytest
1816
from pydantic.dataclasses import dataclass
1917

@@ -81,7 +79,8 @@ def insert_test(
8179
line_number: int,
8280
col_number: int,
8381
) -> None:
84-
assert isinstance(test_type, TestType), "test_type must be an instance of TestType"
82+
self.cur.execute("DELETE FROM discovered_tests WHERE file_path = ?", (file_path,))
83+
test_type_value = test_type.value if hasattr(test_type, "value") else test_type
8584
self.cur.execute(
8685
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
8786
(
@@ -91,7 +90,7 @@ def insert_test(
9190
function_name,
9291
test_class,
9392
test_function,
94-
test_type.value,
93+
test_type_value,
9594
line_number,
9695
col_number,
9796
),
@@ -278,195 +277,192 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
278277
return False, function_name, None
279278

280279

281-
def process_single_test_file(
282-
test_file: Path, functions: list[TestsInFile], cfg: TestConfig, jedi_project: jedi.Project
283-
) -> dict[str, set[FunctionCalledInTest]]:
280+
def process_test_files(
281+
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
282+
) -> dict[str, list[FunctionCalledInTest]]:
283+
import jedi
284+
284285
project_root_path = cfg.project_root_path
286+
test_framework = cfg.test_framework
287+
285288
function_to_test_map = defaultdict(set)
286-
file_hash = TestsCache.compute_file_hash(test_file)
289+
jedi_project = jedi.Project(path=project_root_path)
290+
goto_cache = {}
287291
tests_cache = TestsCache()
288-
cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash)
289-
if cached_tests:
290-
self_cur = tests_cache.cur
291-
self_cur.execute(
292-
"SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?",
293-
(str(test_file), file_hash),
294-
)
295-
qualified_names = [row[0] for row in self_cur.fetchall()]
296-
for cached, qualified_name in zip(cached_tests, qualified_names):
297-
function_to_test_map[qualified_name].add(cached)
298-
tests_cache.close()
299-
return function_to_test_map
300-
try:
301-
script = jedi.Script(path=test_file, project=jedi_project)
302-
test_functions = set()
303292

304-
all_names = script.get_names(all_scopes=True, references=True)
305-
all_names_top = script.get_names(all_scopes=True)
293+
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
294+
progress,
295+
task_id,
296+
):
297+
for test_file, functions in file_to_test_map.items():
298+
file_hash = TestsCache.compute_file_hash(test_file)
299+
cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash)
300+
if cached_tests:
301+
self_cur = tests_cache.cur
302+
self_cur.execute(
303+
"SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?",
304+
(str(test_file), file_hash),
305+
)
306+
qualified_names = [row[0] for row in self_cur.fetchall()]
307+
for cached, qualified_name in zip(cached_tests, qualified_names):
308+
function_to_test_map[qualified_name].add(cached)
309+
progress.advance(task_id)
310+
continue
306311

307-
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
308-
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
309-
except Exception as e:
310-
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
311-
tests_cache.close()
312-
return function_to_test_map
313-
314-
if cfg.test_framework == "pytest":
315-
for function in functions:
316-
if "[" in function.test_function:
317-
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0]
318-
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1]
319-
if function_name in top_level_functions:
320-
test_functions.add(TestFunction(function_name, function.test_class, parameters, function.test_type))
321-
elif function.test_function in top_level_functions:
322-
test_functions.add(TestFunction(function.test_function, function.test_class, None, function.test_type))
323-
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function):
324-
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function)
325-
if base_name in top_level_functions:
326-
test_functions.add(
327-
TestFunction(
328-
function_name=base_name,
329-
test_class=function.test_class,
330-
parameters=function.test_function,
331-
test_type=function.test_type,
332-
)
333-
)
334-
elif cfg.test_framework == "unittest":
335-
all_defs = script.get_names(all_scopes=True, definitions=True)
312+
try:
313+
script = jedi.Script(path=test_file, project=jedi_project)
314+
test_functions = set()
336315

337-
functions_to_search = [elem.test_function for elem in functions]
338-
test_suites = {elem.test_class for elem in functions}
316+
all_names = script.get_names(all_scopes=True, references=True)
317+
all_defs = script.get_names(all_scopes=True, definitions=True)
318+
all_names_top = script.get_names(all_scopes=True)
339319

340-
matching_names = test_suites & top_level_classes.keys()
341-
for matched_name in matching_names:
342-
for def_name in all_defs:
343-
if (
344-
def_name.type == "function"
345-
and def_name.full_name is not None
346-
and f".{matched_name}." in def_name.full_name
347-
):
348-
for function in functions_to_search:
349-
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
320+
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
321+
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
322+
except Exception as e:
323+
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
324+
progress.advance(task_id)
325+
continue
350326

351-
if is_parameterized and new_function == def_name.name:
327+
if test_framework == "pytest":
328+
for function in functions:
329+
if "[" in function.test_function:
330+
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0]
331+
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1]
332+
if function_name in top_level_functions:
352333
test_functions.add(
353-
TestFunction(
354-
function_name=def_name.name,
355-
test_class=matched_name,
356-
parameters=parameters,
357-
test_type=functions[0].test_type,
358-
)
334+
TestFunction(function_name, function.test_class, parameters, function.test_type)
359335
)
360-
elif function == def_name.name:
336+
elif function.test_function in top_level_functions:
337+
test_functions.add(
338+
TestFunction(function.test_function, function.test_class, None, function.test_type)
339+
)
340+
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function):
341+
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function)
342+
if base_name in top_level_functions:
361343
test_functions.add(
362344
TestFunction(
363-
function_name=def_name.name,
364-
test_class=matched_name,
365-
parameters=None,
366-
test_type=functions[0].test_type,
345+
function_name=base_name,
346+
test_class=function.test_class,
347+
parameters=function.test_function,
348+
test_type=function.test_type,
367349
)
368350
)
369351

370-
test_functions_list = list(test_functions)
371-
test_functions_raw = [elem.function_name for elem in test_functions_list]
372-
373-
test_functions_by_name = defaultdict(list)
374-
for i, func_name in enumerate(test_functions_raw):
375-
test_functions_by_name[func_name].append(i)
376-
377-
for name in all_names:
378-
if name.full_name is None:
379-
continue
380-
m = FUNCTION_NAME_REGEX.search(name.full_name)
381-
if not m:
382-
continue
383-
384-
scope = m.group(1)
385-
if scope not in test_functions_by_name:
386-
continue
387-
388-
try:
389-
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
390-
except Exception as e:
391-
logger.debug(str(e))
392-
continue
393-
394-
if not definition or definition[0].type != "function":
395-
continue
396-
397-
definition_path = str(definition[0].module_path)
398-
if (
399-
definition_path.startswith(str(project_root_path) + os.sep)
400-
and definition[0].module_name != name.module_name
401-
and definition[0].full_name is not None
402-
):
403-
for index in test_functions_by_name[scope]:
404-
scope_test_function = test_functions_list[index].function_name
405-
scope_test_class = test_functions_list[index].test_class
406-
scope_parameters = test_functions_list[index].parameters
407-
test_type = test_functions_list[index].test_type
408-
409-
if scope_parameters is not None:
410-
if cfg.test_framework == "pytest":
411-
scope_test_function += "[" + scope_parameters + "]"
412-
if cfg.test_framework == "unittest":
413-
scope_test_function += "_" + scope_parameters
414-
415-
full_name_without_module_prefix = definition[0].full_name.replace(
416-
definition[0].module_name + ".", "", 1
417-
)
418-
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
419-
420-
tests_cache.insert_test(
421-
file_path=str(test_file),
422-
file_hash=file_hash,
423-
qualified_name_with_modules_from_root=qualified_name_with_modules_from_root,
424-
function_name=scope,
425-
test_class=scope_test_class,
426-
test_function=scope_test_function,
427-
test_type=test_type,
428-
line_number=name.line,
429-
col_number=name.column,
430-
)
352+
elif test_framework == "unittest":
353+
functions_to_search = [elem.test_function for elem in functions]
354+
test_suites = {elem.test_class for elem in functions}
355+
356+
matching_names = test_suites & top_level_classes.keys()
357+
for matched_name in matching_names:
358+
for def_name in all_defs:
359+
if (
360+
def_name.type == "function"
361+
and def_name.full_name is not None
362+
and f".{matched_name}." in def_name.full_name
363+
):
364+
for function in functions_to_search:
365+
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
366+
367+
if is_parameterized and new_function == def_name.name:
368+
test_functions.add(
369+
TestFunction(
370+
function_name=def_name.name,
371+
test_class=matched_name,
372+
parameters=parameters,
373+
test_type=functions[0].test_type,
374+
)
375+
)
376+
elif function == def_name.name:
377+
test_functions.add(
378+
TestFunction(
379+
function_name=def_name.name,
380+
test_class=matched_name,
381+
parameters=None,
382+
test_type=functions[0].test_type,
383+
)
384+
)
385+
386+
test_functions_list = list(test_functions)
387+
test_functions_raw = [elem.function_name for elem in test_functions_list]
388+
389+
test_functions_by_name = defaultdict(list)
390+
for i, func_name in enumerate(test_functions_raw):
391+
test_functions_by_name[func_name].append(i)
392+
393+
for name in all_names:
394+
if name.full_name is None:
395+
continue
396+
m = FUNCTION_NAME_REGEX.search(name.full_name)
397+
if not m:
398+
continue
399+
400+
scope = m.group(1)
401+
if scope not in test_functions_by_name:
402+
continue
403+
404+
cache_key = (name.full_name, name.module_name)
405+
try:
406+
if cache_key in goto_cache:
407+
definition = goto_cache[cache_key]
408+
else:
409+
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
410+
goto_cache[cache_key] = definition
411+
except Exception as e:
412+
logger.debug(str(e))
413+
continue
414+
415+
if not definition or definition[0].type != "function":
416+
continue
417+
418+
definition_path = str(definition[0].module_path)
419+
if (
420+
definition_path.startswith(str(project_root_path) + os.sep)
421+
and definition[0].module_name != name.module_name
422+
and definition[0].full_name is not None
423+
):
424+
for index in test_functions_by_name[scope]:
425+
scope_test_function = test_functions_list[index].function_name
426+
scope_test_class = test_functions_list[index].test_class
427+
scope_parameters = test_functions_list[index].parameters
428+
test_type = test_functions_list[index].test_type
429+
430+
if scope_parameters is not None:
431+
if test_framework == "pytest":
432+
scope_test_function += "[" + scope_parameters + "]"
433+
if test_framework == "unittest":
434+
scope_test_function += "_" + scope_parameters
435+
436+
full_name_without_module_prefix = definition[0].full_name.replace(
437+
definition[0].module_name + ".", "", 1
438+
)
439+
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
431440

432-
function_to_test_map[qualified_name_with_modules_from_root].add(
433-
FunctionCalledInTest(
434-
tests_in_file=TestsInFile(
435-
test_file=test_file,
441+
tests_cache.insert_test(
442+
file_path=str(test_file),
443+
file_hash=file_hash,
444+
qualified_name_with_modules_from_root=qualified_name_with_modules_from_root,
445+
function_name=scope,
436446
test_class=scope_test_class,
437447
test_function=scope_test_function,
438448
test_type=test_type,
439-
),
440-
position=CodePosition(line_no=name.line, col_no=name.column),
441-
)
442-
)
443-
444-
tests_cache.close()
445-
return function_to_test_map
446-
449+
line_number=name.line,
450+
col_number=name.column,
451+
)
447452

448-
def process_test_files(
449-
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
450-
) -> dict[str, list[FunctionCalledInTest]]:
451-
project_root_path = cfg.project_root_path
453+
function_to_test_map[qualified_name_with_modules_from_root].add(
454+
FunctionCalledInTest(
455+
tests_in_file=TestsInFile(
456+
test_file=test_file,
457+
test_class=scope_test_class,
458+
test_function=scope_test_function,
459+
test_type=test_type,
460+
),
461+
position=CodePosition(line_no=name.line, col_no=name.column),
462+
)
463+
)
452464

453-
function_to_test_map = defaultdict(set)
454-
jedi_project = jedi.Project(path=project_root_path)
455-
with (
456-
test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
457-
progress,
458-
task_id,
459-
),
460-
ProcessPoolExecutor() as executor,
461-
):
462-
futures = {
463-
executor.submit(process_single_test_file, test_file, functions, cfg, jedi_project): test_file
464-
for test_file, functions in file_to_test_map.items()
465-
}
466-
for future in futures:
467-
result = future.result()
468-
for k, v in result.items():
469-
function_to_test_map[k].update(v)
470-
progress.update(task_id)
465+
progress.advance(task_id)
471466

467+
tests_cache.close()
472468
return {function: list(tests) for function, tests in function_to_test_map.items()}

0 commit comments

Comments
 (0)