99import subprocess
1010import unittest
1111from collections import defaultdict
12+ from concurrent .futures import ProcessPoolExecutor
1213from pathlib import Path
1314from typing import TYPE_CHECKING , Callable , Optional
1415
16+ import jedi
1517import pytest
1618from pydantic .dataclasses import dataclass
1719
@@ -79,8 +81,7 @@ def insert_test(
7981 line_number : int ,
8082 col_number : int ,
8183 ) -> None :
82- self .cur .execute ("DELETE FROM discovered_tests WHERE file_path = ?" , (file_path ,))
83- test_type_value = test_type .value if hasattr (test_type , "value" ) else test_type
84+ assert isinstance (test_type , TestType ), "test_type must be an instance of TestType"
8485 self .cur .execute (
8586 "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
8687 (
@@ -90,7 +91,7 @@ def insert_test(
9091 function_name ,
9192 test_class ,
9293 test_function ,
93- test_type_value ,
94+ test_type . value ,
9495 line_number ,
9596 col_number ,
9697 ),
@@ -277,192 +278,195 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
277278 return False , function_name , None
278279
279280
280- def process_test_files (
281- file_to_test_map : dict [Path , list [TestsInFile ]], cfg : TestConfig
282- ) -> dict [str , list [FunctionCalledInTest ]]:
283- import jedi
284-
281+ def process_single_test_file (
282+ test_file : Path , functions : list [TestsInFile ], cfg : TestConfig , jedi_project : jedi .Project
283+ ) -> dict [str , set [FunctionCalledInTest ]]:
285284 project_root_path = cfg .project_root_path
286- test_framework = cfg .test_framework
287-
288285 function_to_test_map = defaultdict (set )
289- jedi_project = jedi .Project (path = project_root_path )
290- goto_cache = {}
286+ file_hash = TestsCache .compute_file_hash (test_file )
291287 tests_cache = TestsCache ()
288+ cached_tests = tests_cache .get_tests_for_file (str (test_file ), file_hash )
289+ if cached_tests :
290+ self_cur = tests_cache .cur
291+ self_cur .execute (
292+ "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?" ,
293+ (str (test_file ), file_hash ),
294+ )
295+ qualified_names = [row [0 ] for row in self_cur .fetchall ()]
296+ for cached , qualified_name in zip (cached_tests , qualified_names ):
297+ function_to_test_map [qualified_name ].add (cached )
298+ tests_cache .close ()
299+ return function_to_test_map
300+ try :
301+ script = jedi .Script (path = test_file , project = jedi_project )
302+ test_functions = set ()
292303
293- with test_files_progress_bar (total = len (file_to_test_map ), description = "Processing test files" ) as (
294- progress ,
295- task_id ,
296- ):
297- for test_file , functions in file_to_test_map .items ():
298- file_hash = TestsCache .compute_file_hash (test_file )
299- cached_tests = tests_cache .get_tests_for_file (str (test_file ), file_hash )
300- if cached_tests :
301- self_cur = tests_cache .cur
302- self_cur .execute (
303- "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?" ,
304- (str (test_file ), file_hash ),
305- )
306- qualified_names = [row [0 ] for row in self_cur .fetchall ()]
307- for cached , qualified_name in zip (cached_tests , qualified_names ):
308- function_to_test_map [qualified_name ].add (cached )
309- progress .advance (task_id )
310- continue
304+ all_names = script .get_names (all_scopes = True , references = True )
305+ all_names_top = script .get_names (all_scopes = True )
311306
312- try :
313- script = jedi .Script (path = test_file , project = jedi_project )
314- test_functions = set ()
307+ top_level_functions = {name .name : name for name in all_names_top if name .type == "function" }
308+ top_level_classes = {name .name : name for name in all_names_top if name .type == "class" }
309+ except Exception as e :
310+ logger .debug (f"Failed to get jedi script for { test_file } : { e } " )
311+ tests_cache .close ()
312+ return function_to_test_map
313+
314+ if cfg .test_framework == "pytest" :
315+ for function in functions :
316+ if "[" in function .test_function :
317+ function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (function .test_function )[0 ]
318+ parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (function .test_function )[1 ]
319+ if function_name in top_level_functions :
320+ test_functions .add (TestFunction (function_name , function .test_class , parameters , function .test_type ))
321+ elif function .test_function in top_level_functions :
322+ test_functions .add (TestFunction (function .test_function , function .test_class , None , function .test_type ))
323+ elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX .match (function .test_function ):
324+ base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX .sub ("" , function .test_function )
325+ if base_name in top_level_functions :
326+ test_functions .add (
327+ TestFunction (
328+ function_name = base_name ,
329+ test_class = function .test_class ,
330+ parameters = function .test_function ,
331+ test_type = function .test_type ,
332+ )
333+ )
334+ elif cfg .test_framework == "unittest" :
335+ all_defs = script .get_names (all_scopes = True , definitions = True )
315336
316- all_names = script .get_names (all_scopes = True , references = True )
317- all_defs = script .get_names (all_scopes = True , definitions = True )
318- all_names_top = script .get_names (all_scopes = True )
337+ functions_to_search = [elem .test_function for elem in functions ]
338+ test_suites = {elem .test_class for elem in functions }
319339
320- top_level_functions = {name .name : name for name in all_names_top if name .type == "function" }
321- top_level_classes = {name .name : name for name in all_names_top if name .type == "class" }
322- except Exception as e :
323- logger .debug (f"Failed to get jedi script for { test_file } : { e } " )
324- progress .advance (task_id )
325- continue
340+ matching_names = test_suites & top_level_classes .keys ()
341+ for matched_name in matching_names :
342+ for def_name in all_defs :
343+ if (
344+ def_name .type == "function"
345+ and def_name .full_name is not None
346+ and f".{ matched_name } ." in def_name .full_name
347+ ):
348+ for function in functions_to_search :
349+ (is_parameterized , new_function , parameters ) = discover_parameters_unittest (function )
326350
327- if test_framework == "pytest" :
328- for function in functions :
329- if "[" in function .test_function :
330- function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (function .test_function )[0 ]
331- parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (function .test_function )[1 ]
332- if function_name in top_level_functions :
351+ if is_parameterized and new_function == def_name .name :
333352 test_functions .add (
334- TestFunction (function_name , function .test_class , parameters , function .test_type )
353+ TestFunction (
354+ function_name = def_name .name ,
355+ test_class = matched_name ,
356+ parameters = parameters ,
357+ test_type = functions [0 ].test_type ,
358+ )
335359 )
336- elif function .test_function in top_level_functions :
337- test_functions .add (
338- TestFunction (function .test_function , function .test_class , None , function .test_type )
339- )
340- elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX .match (function .test_function ):
341- base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX .sub ("" , function .test_function )
342- if base_name in top_level_functions :
360+ elif function == def_name .name :
343361 test_functions .add (
344362 TestFunction (
345- function_name = base_name ,
346- test_class = function . test_class ,
347- parameters = function . test_function ,
348- test_type = function .test_type ,
363+ function_name = def_name . name ,
364+ test_class = matched_name ,
365+ parameters = None ,
366+ test_type = functions [ 0 ] .test_type ,
349367 )
350368 )
351369
352- elif test_framework == "unittest" :
353- functions_to_search = [elem .test_function for elem in functions ]
354- test_suites = {elem .test_class for elem in functions }
355-
356- matching_names = test_suites & top_level_classes .keys ()
357- for matched_name in matching_names :
358- for def_name in all_defs :
359- if (
360- def_name .type == "function"
361- and def_name .full_name is not None
362- and f".{ matched_name } ." in def_name .full_name
363- ):
364- for function in functions_to_search :
365- (is_parameterized , new_function , parameters ) = discover_parameters_unittest (function )
366-
367- if is_parameterized and new_function == def_name .name :
368- test_functions .add (
369- TestFunction (
370- function_name = def_name .name ,
371- test_class = matched_name ,
372- parameters = parameters ,
373- test_type = functions [0 ].test_type ,
374- )
375- )
376- elif function == def_name .name :
377- test_functions .add (
378- TestFunction (
379- function_name = def_name .name ,
380- test_class = matched_name ,
381- parameters = None ,
382- test_type = functions [0 ].test_type ,
383- )
384- )
385-
386- test_functions_list = list (test_functions )
387- test_functions_raw = [elem .function_name for elem in test_functions_list ]
388-
389- test_functions_by_name = defaultdict (list )
390- for i , func_name in enumerate (test_functions_raw ):
391- test_functions_by_name [func_name ].append (i )
392-
393- for name in all_names :
394- if name .full_name is None :
395- continue
396- m = FUNCTION_NAME_REGEX .search (name .full_name )
397- if not m :
398- continue
399-
400- scope = m .group (1 )
401- if scope not in test_functions_by_name :
402- continue
403-
404- cache_key = (name .full_name , name .module_name )
405- try :
406- if cache_key in goto_cache :
407- definition = goto_cache [cache_key ]
408- else :
409- definition = name .goto (follow_imports = True , follow_builtin_imports = False )
410- goto_cache [cache_key ] = definition
411- except Exception as e :
412- logger .debug (str (e ))
413- continue
414-
415- if not definition or definition [0 ].type != "function" :
416- continue
417-
418- definition_path = str (definition [0 ].module_path )
419- if (
420- definition_path .startswith (str (project_root_path ) + os .sep )
421- and definition [0 ].module_name != name .module_name
422- and definition [0 ].full_name is not None
423- ):
424- for index in test_functions_by_name [scope ]:
425- scope_test_function = test_functions_list [index ].function_name
426- scope_test_class = test_functions_list [index ].test_class
427- scope_parameters = test_functions_list [index ].parameters
428- test_type = test_functions_list [index ].test_type
429-
430- if scope_parameters is not None :
431- if test_framework == "pytest" :
432- scope_test_function += "[" + scope_parameters + "]"
433- if test_framework == "unittest" :
434- scope_test_function += "_" + scope_parameters
435-
436- full_name_without_module_prefix = definition [0 ].full_name .replace (
437- definition [0 ].module_name + "." , "" , 1
438- )
439- qualified_name_with_modules_from_root = f"{ module_name_from_file_path (definition [0 ].module_path , project_root_path )} .{ full_name_without_module_prefix } "
370+ test_functions_list = list (test_functions )
371+ test_functions_raw = [elem .function_name for elem in test_functions_list ]
372+
373+ test_functions_by_name = defaultdict (list )
374+ for i , func_name in enumerate (test_functions_raw ):
375+ test_functions_by_name [func_name ].append (i )
376+
377+ for name in all_names :
378+ if name .full_name is None :
379+ continue
380+ m = FUNCTION_NAME_REGEX .search (name .full_name )
381+ if not m :
382+ continue
440383
441- tests_cache .insert_test (
442- file_path = str (test_file ),
443- file_hash = file_hash ,
444- qualified_name_with_modules_from_root = qualified_name_with_modules_from_root ,
445- function_name = scope ,
384+ scope = m .group (1 )
385+ if scope not in test_functions_by_name :
386+ continue
387+
388+ try :
389+ definition = name .goto (follow_imports = True , follow_builtin_imports = False )
390+ except Exception as e :
391+ logger .debug (str (e ))
392+ continue
393+
394+ if not definition or definition [0 ].type != "function" :
395+ continue
396+
397+ definition_path = str (definition [0 ].module_path )
398+ if (
399+ definition_path .startswith (str (project_root_path ) + os .sep )
400+ and definition [0 ].module_name != name .module_name
401+ and definition [0 ].full_name is not None
402+ ):
403+ for index in test_functions_by_name [scope ]:
404+ scope_test_function = test_functions_list [index ].function_name
405+ scope_test_class = test_functions_list [index ].test_class
406+ scope_parameters = test_functions_list [index ].parameters
407+ test_type = test_functions_list [index ].test_type
408+
409+ if scope_parameters is not None :
410+ if cfg .test_framework == "pytest" :
411+ scope_test_function += "[" + scope_parameters + "]"
412+ if cfg .test_framework == "unittest" :
413+ scope_test_function += "_" + scope_parameters
414+
415+ full_name_without_module_prefix = definition [0 ].full_name .replace (
416+ definition [0 ].module_name + "." , "" , 1
417+ )
418+ qualified_name_with_modules_from_root = f"{ module_name_from_file_path (definition [0 ].module_path , project_root_path )} .{ full_name_without_module_prefix } "
419+
420+ tests_cache .insert_test (
421+ file_path = str (test_file ),
422+ file_hash = file_hash ,
423+ qualified_name_with_modules_from_root = qualified_name_with_modules_from_root ,
424+ function_name = scope ,
425+ test_class = scope_test_class ,
426+ test_function = scope_test_function ,
427+ test_type = test_type ,
428+ line_number = name .line ,
429+ col_number = name .column ,
430+ )
431+
432+ function_to_test_map [qualified_name_with_modules_from_root ].add (
433+ FunctionCalledInTest (
434+ tests_in_file = TestsInFile (
435+ test_file = test_file ,
446436 test_class = scope_test_class ,
447437 test_function = scope_test_function ,
448438 test_type = test_type ,
449- line_number = name .line ,
450- col_number = name .column ,
451- )
439+ ),
440+ position = CodePosition (line_no = name .line , col_no = name .column ),
441+ )
442+ )
452443
453- function_to_test_map [qualified_name_with_modules_from_root ].add (
454- FunctionCalledInTest (
455- tests_in_file = TestsInFile (
456- test_file = test_file ,
457- test_class = scope_test_class ,
458- test_function = scope_test_function ,
459- test_type = test_type ,
460- ),
461- position = CodePosition (line_no = name .line , col_no = name .column ),
462- )
463- )
444+ tests_cache .close ()
445+ return function_to_test_map
464446
465- progress .advance (task_id )
466447
467- tests_cache .close ()
448+ def process_test_files (
449+ file_to_test_map : dict [Path , list [TestsInFile ]], cfg : TestConfig
450+ ) -> dict [str , list [FunctionCalledInTest ]]:
451+ project_root_path = cfg .project_root_path
452+
453+ function_to_test_map = defaultdict (set )
454+ jedi_project = jedi .Project (path = project_root_path )
455+ with (
456+ test_files_progress_bar (total = len (file_to_test_map ), description = "Processing test files" ) as (
457+ progress ,
458+ task_id ,
459+ ),
460+ ProcessPoolExecutor () as executor ,
461+ ):
462+ futures = {
463+ executor .submit (process_single_test_file , test_file , functions , cfg , jedi_project ): test_file
464+ for test_file , functions in file_to_test_map .items ()
465+ }
466+ for future in futures :
467+ result = future .result ()
468+ for k , v in result .items ():
469+ function_to_test_map [k ].update (v )
470+ progress .update (task_id )
471+
468472 return {function : list (tests ) for function , tests in function_to_test_map .items ()}
0 commit comments