Skip to content

Commit 9321611

Browse files
committed
paralelize test discovery
1 parent ae63d4a commit 9321611

File tree

1 file changed

+173
-169
lines changed

1 file changed

+173
-169
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 173 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
import subprocess
1010
import unittest
1111
from collections import defaultdict
12+
from concurrent.futures import ProcessPoolExecutor
1213
from pathlib import Path
1314
from typing import TYPE_CHECKING, Callable, Optional
1415

16+
import jedi
1517
import pytest
1618
from pydantic.dataclasses import dataclass
1719

@@ -79,8 +81,7 @@ def insert_test(
7981
line_number: int,
8082
col_number: int,
8183
) -> None:
82-
self.cur.execute("DELETE FROM discovered_tests WHERE file_path = ?", (file_path,))
83-
test_type_value = test_type.value if hasattr(test_type, "value") else test_type
84+
assert isinstance(test_type, TestType), "test_type must be an instance of TestType"
8485
self.cur.execute(
8586
"INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
8687
(
@@ -90,7 +91,7 @@ def insert_test(
9091
function_name,
9192
test_class,
9293
test_function,
93-
test_type_value,
94+
test_type.value,
9495
line_number,
9596
col_number,
9697
),
@@ -277,192 +278,195 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
277278
return False, function_name, None
278279

279280

280-
def process_test_files(
281-
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
282-
) -> dict[str, list[FunctionCalledInTest]]:
283-
import jedi
284-
281+
def process_single_test_file(
282+
test_file: Path, functions: list[TestsInFile], cfg: TestConfig, jedi_project: jedi.Project
283+
) -> dict[str, set[FunctionCalledInTest]]:
285284
project_root_path = cfg.project_root_path
286-
test_framework = cfg.test_framework
287-
288285
function_to_test_map = defaultdict(set)
289-
jedi_project = jedi.Project(path=project_root_path)
290-
goto_cache = {}
286+
file_hash = TestsCache.compute_file_hash(test_file)
291287
tests_cache = TestsCache()
288+
cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash)
289+
if cached_tests:
290+
self_cur = tests_cache.cur
291+
self_cur.execute(
292+
"SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?",
293+
(str(test_file), file_hash),
294+
)
295+
qualified_names = [row[0] for row in self_cur.fetchall()]
296+
for cached, qualified_name in zip(cached_tests, qualified_names):
297+
function_to_test_map[qualified_name].add(cached)
298+
tests_cache.close()
299+
return function_to_test_map
300+
try:
301+
script = jedi.Script(path=test_file, project=jedi_project)
302+
test_functions = set()
292303

293-
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
294-
progress,
295-
task_id,
296-
):
297-
for test_file, functions in file_to_test_map.items():
298-
file_hash = TestsCache.compute_file_hash(test_file)
299-
cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash)
300-
if cached_tests:
301-
self_cur = tests_cache.cur
302-
self_cur.execute(
303-
"SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?",
304-
(str(test_file), file_hash),
305-
)
306-
qualified_names = [row[0] for row in self_cur.fetchall()]
307-
for cached, qualified_name in zip(cached_tests, qualified_names):
308-
function_to_test_map[qualified_name].add(cached)
309-
progress.advance(task_id)
310-
continue
304+
all_names = script.get_names(all_scopes=True, references=True)
305+
all_names_top = script.get_names(all_scopes=True)
311306

312-
try:
313-
script = jedi.Script(path=test_file, project=jedi_project)
314-
test_functions = set()
307+
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
308+
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
309+
except Exception as e:
310+
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
311+
tests_cache.close()
312+
return function_to_test_map
313+
314+
if cfg.test_framework == "pytest":
315+
for function in functions:
316+
if "[" in function.test_function:
317+
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0]
318+
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1]
319+
if function_name in top_level_functions:
320+
test_functions.add(TestFunction(function_name, function.test_class, parameters, function.test_type))
321+
elif function.test_function in top_level_functions:
322+
test_functions.add(TestFunction(function.test_function, function.test_class, None, function.test_type))
323+
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function):
324+
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function)
325+
if base_name in top_level_functions:
326+
test_functions.add(
327+
TestFunction(
328+
function_name=base_name,
329+
test_class=function.test_class,
330+
parameters=function.test_function,
331+
test_type=function.test_type,
332+
)
333+
)
334+
elif cfg.test_framework == "unittest":
335+
all_defs = script.get_names(all_scopes=True, definitions=True)
315336

316-
all_names = script.get_names(all_scopes=True, references=True)
317-
all_defs = script.get_names(all_scopes=True, definitions=True)
318-
all_names_top = script.get_names(all_scopes=True)
337+
functions_to_search = [elem.test_function for elem in functions]
338+
test_suites = {elem.test_class for elem in functions}
319339

320-
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
321-
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
322-
except Exception as e:
323-
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
324-
progress.advance(task_id)
325-
continue
340+
matching_names = test_suites & top_level_classes.keys()
341+
for matched_name in matching_names:
342+
for def_name in all_defs:
343+
if (
344+
def_name.type == "function"
345+
and def_name.full_name is not None
346+
and f".{matched_name}." in def_name.full_name
347+
):
348+
for function in functions_to_search:
349+
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
326350

327-
if test_framework == "pytest":
328-
for function in functions:
329-
if "[" in function.test_function:
330-
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[0]
331-
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(function.test_function)[1]
332-
if function_name in top_level_functions:
351+
if is_parameterized and new_function == def_name.name:
333352
test_functions.add(
334-
TestFunction(function_name, function.test_class, parameters, function.test_type)
353+
TestFunction(
354+
function_name=def_name.name,
355+
test_class=matched_name,
356+
parameters=parameters,
357+
test_type=functions[0].test_type,
358+
)
335359
)
336-
elif function.test_function in top_level_functions:
337-
test_functions.add(
338-
TestFunction(function.test_function, function.test_class, None, function.test_type)
339-
)
340-
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(function.test_function):
341-
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub("", function.test_function)
342-
if base_name in top_level_functions:
360+
elif function == def_name.name:
343361
test_functions.add(
344362
TestFunction(
345-
function_name=base_name,
346-
test_class=function.test_class,
347-
parameters=function.test_function,
348-
test_type=function.test_type,
363+
function_name=def_name.name,
364+
test_class=matched_name,
365+
parameters=None,
366+
test_type=functions[0].test_type,
349367
)
350368
)
351369

352-
elif test_framework == "unittest":
353-
functions_to_search = [elem.test_function for elem in functions]
354-
test_suites = {elem.test_class for elem in functions}
355-
356-
matching_names = test_suites & top_level_classes.keys()
357-
for matched_name in matching_names:
358-
for def_name in all_defs:
359-
if (
360-
def_name.type == "function"
361-
and def_name.full_name is not None
362-
and f".{matched_name}." in def_name.full_name
363-
):
364-
for function in functions_to_search:
365-
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
366-
367-
if is_parameterized and new_function == def_name.name:
368-
test_functions.add(
369-
TestFunction(
370-
function_name=def_name.name,
371-
test_class=matched_name,
372-
parameters=parameters,
373-
test_type=functions[0].test_type,
374-
)
375-
)
376-
elif function == def_name.name:
377-
test_functions.add(
378-
TestFunction(
379-
function_name=def_name.name,
380-
test_class=matched_name,
381-
parameters=None,
382-
test_type=functions[0].test_type,
383-
)
384-
)
385-
386-
test_functions_list = list(test_functions)
387-
test_functions_raw = [elem.function_name for elem in test_functions_list]
388-
389-
test_functions_by_name = defaultdict(list)
390-
for i, func_name in enumerate(test_functions_raw):
391-
test_functions_by_name[func_name].append(i)
392-
393-
for name in all_names:
394-
if name.full_name is None:
395-
continue
396-
m = FUNCTION_NAME_REGEX.search(name.full_name)
397-
if not m:
398-
continue
399-
400-
scope = m.group(1)
401-
if scope not in test_functions_by_name:
402-
continue
403-
404-
cache_key = (name.full_name, name.module_name)
405-
try:
406-
if cache_key in goto_cache:
407-
definition = goto_cache[cache_key]
408-
else:
409-
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
410-
goto_cache[cache_key] = definition
411-
except Exception as e:
412-
logger.debug(str(e))
413-
continue
414-
415-
if not definition or definition[0].type != "function":
416-
continue
417-
418-
definition_path = str(definition[0].module_path)
419-
if (
420-
definition_path.startswith(str(project_root_path) + os.sep)
421-
and definition[0].module_name != name.module_name
422-
and definition[0].full_name is not None
423-
):
424-
for index in test_functions_by_name[scope]:
425-
scope_test_function = test_functions_list[index].function_name
426-
scope_test_class = test_functions_list[index].test_class
427-
scope_parameters = test_functions_list[index].parameters
428-
test_type = test_functions_list[index].test_type
429-
430-
if scope_parameters is not None:
431-
if test_framework == "pytest":
432-
scope_test_function += "[" + scope_parameters + "]"
433-
if test_framework == "unittest":
434-
scope_test_function += "_" + scope_parameters
435-
436-
full_name_without_module_prefix = definition[0].full_name.replace(
437-
definition[0].module_name + ".", "", 1
438-
)
439-
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
370+
test_functions_list = list(test_functions)
371+
test_functions_raw = [elem.function_name for elem in test_functions_list]
372+
373+
test_functions_by_name = defaultdict(list)
374+
for i, func_name in enumerate(test_functions_raw):
375+
test_functions_by_name[func_name].append(i)
376+
377+
for name in all_names:
378+
if name.full_name is None:
379+
continue
380+
m = FUNCTION_NAME_REGEX.search(name.full_name)
381+
if not m:
382+
continue
440383

441-
tests_cache.insert_test(
442-
file_path=str(test_file),
443-
file_hash=file_hash,
444-
qualified_name_with_modules_from_root=qualified_name_with_modules_from_root,
445-
function_name=scope,
384+
scope = m.group(1)
385+
if scope not in test_functions_by_name:
386+
continue
387+
388+
try:
389+
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
390+
except Exception as e:
391+
logger.debug(str(e))
392+
continue
393+
394+
if not definition or definition[0].type != "function":
395+
continue
396+
397+
definition_path = str(definition[0].module_path)
398+
if (
399+
definition_path.startswith(str(project_root_path) + os.sep)
400+
and definition[0].module_name != name.module_name
401+
and definition[0].full_name is not None
402+
):
403+
for index in test_functions_by_name[scope]:
404+
scope_test_function = test_functions_list[index].function_name
405+
scope_test_class = test_functions_list[index].test_class
406+
scope_parameters = test_functions_list[index].parameters
407+
test_type = test_functions_list[index].test_type
408+
409+
if scope_parameters is not None:
410+
if cfg.test_framework == "pytest":
411+
scope_test_function += "[" + scope_parameters + "]"
412+
if cfg.test_framework == "unittest":
413+
scope_test_function += "_" + scope_parameters
414+
415+
full_name_without_module_prefix = definition[0].full_name.replace(
416+
definition[0].module_name + ".", "", 1
417+
)
418+
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
419+
420+
tests_cache.insert_test(
421+
file_path=str(test_file),
422+
file_hash=file_hash,
423+
qualified_name_with_modules_from_root=qualified_name_with_modules_from_root,
424+
function_name=scope,
425+
test_class=scope_test_class,
426+
test_function=scope_test_function,
427+
test_type=test_type,
428+
line_number=name.line,
429+
col_number=name.column,
430+
)
431+
432+
function_to_test_map[qualified_name_with_modules_from_root].add(
433+
FunctionCalledInTest(
434+
tests_in_file=TestsInFile(
435+
test_file=test_file,
446436
test_class=scope_test_class,
447437
test_function=scope_test_function,
448438
test_type=test_type,
449-
line_number=name.line,
450-
col_number=name.column,
451-
)
439+
),
440+
position=CodePosition(line_no=name.line, col_no=name.column),
441+
)
442+
)
452443

453-
function_to_test_map[qualified_name_with_modules_from_root].add(
454-
FunctionCalledInTest(
455-
tests_in_file=TestsInFile(
456-
test_file=test_file,
457-
test_class=scope_test_class,
458-
test_function=scope_test_function,
459-
test_type=test_type,
460-
),
461-
position=CodePosition(line_no=name.line, col_no=name.column),
462-
)
463-
)
444+
tests_cache.close()
445+
return function_to_test_map
464446

465-
progress.advance(task_id)
466447

467-
tests_cache.close()
448+
def process_test_files(
449+
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
450+
) -> dict[str, list[FunctionCalledInTest]]:
451+
project_root_path = cfg.project_root_path
452+
453+
function_to_test_map = defaultdict(set)
454+
jedi_project = jedi.Project(path=project_root_path)
455+
with (
456+
test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
457+
progress,
458+
task_id,
459+
),
460+
ProcessPoolExecutor() as executor,
461+
):
462+
futures = {
463+
executor.submit(process_single_test_file, test_file, functions, cfg, jedi_project): test_file
464+
for test_file, functions in file_to_test_map.items()
465+
}
466+
for future in futures:
467+
result = future.result()
468+
for k, v in result.items():
469+
function_to_test_map[k].update(v)
470+
progress.update(task_id)
471+
468472
return {function: list(tests) for function, tests in function_to_test_map.items()}

0 commit comments

Comments
 (0)