99import subprocess
1010import unittest
1111from collections import defaultdict
12- from concurrent .futures import ProcessPoolExecutor
1312from pathlib import Path
1413from typing import TYPE_CHECKING , Callable , Optional
1514
16- import jedi
1715import pytest
1816from pydantic .dataclasses import dataclass
1917
@@ -81,7 +79,8 @@ def insert_test(
8179 line_number : int ,
8280 col_number : int ,
8381 ) -> None :
84- assert isinstance (test_type , TestType ), "test_type must be an instance of TestType"
82+ self .cur .execute ("DELETE FROM discovered_tests WHERE file_path = ?" , (file_path ,))
83+ test_type_value = test_type .value if hasattr (test_type , "value" ) else test_type
8584 self .cur .execute (
8685 "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
8786 (
@@ -91,7 +90,7 @@ def insert_test(
9190 function_name ,
9291 test_class ,
9392 test_function ,
94- test_type . value ,
93+ test_type_value ,
9594 line_number ,
9695 col_number ,
9796 ),
@@ -278,195 +277,192 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
278277 return False , function_name , None
279278
280279
281- def process_single_test_file (
282- test_file : Path , functions : list [TestsInFile ], cfg : TestConfig , jedi_project : jedi .Project
283- ) -> dict [str , set [FunctionCalledInTest ]]:
280+ def process_test_files (
281+ file_to_test_map : dict [Path , list [TestsInFile ]], cfg : TestConfig
282+ ) -> dict [str , list [FunctionCalledInTest ]]:
283+ import jedi
284+
284285 project_root_path = cfg .project_root_path
286+ test_framework = cfg .test_framework
287+
285288 function_to_test_map = defaultdict (set )
286- file_hash = TestsCache .compute_file_hash (test_file )
289+ jedi_project = jedi .Project (path = project_root_path )
290+ goto_cache = {}
287291 tests_cache = TestsCache ()
288- cached_tests = tests_cache .get_tests_for_file (str (test_file ), file_hash )
289- if cached_tests :
290- self_cur = tests_cache .cur
291- self_cur .execute (
292- "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?" ,
293- (str (test_file ), file_hash ),
294- )
295- qualified_names = [row [0 ] for row in self_cur .fetchall ()]
296- for cached , qualified_name in zip (cached_tests , qualified_names ):
297- function_to_test_map [qualified_name ].add (cached )
298- tests_cache .close ()
299- return function_to_test_map
300- try :
301- script = jedi .Script (path = test_file , project = jedi_project )
302- test_functions = set ()
303292
304- all_names = script .get_names (all_scopes = True , references = True )
305- all_names_top = script .get_names (all_scopes = True )
293+ with test_files_progress_bar (total = len (file_to_test_map ), description = "Processing test files" ) as (
294+ progress ,
295+ task_id ,
296+ ):
297+ for test_file , functions in file_to_test_map .items ():
298+ file_hash = TestsCache .compute_file_hash (test_file )
299+ cached_tests = tests_cache .get_tests_for_file (str (test_file ), file_hash )
300+ if cached_tests :
301+ self_cur = tests_cache .cur
302+ self_cur .execute (
303+ "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?" ,
304+ (str (test_file ), file_hash ),
305+ )
306+ qualified_names = [row [0 ] for row in self_cur .fetchall ()]
307+ for cached , qualified_name in zip (cached_tests , qualified_names ):
308+ function_to_test_map [qualified_name ].add (cached )
309+ progress .advance (task_id )
310+ continue
306311
307- top_level_functions = {name .name : name for name in all_names_top if name .type == "function" }
308- top_level_classes = {name .name : name for name in all_names_top if name .type == "class" }
309- except Exception as e :
310- logger .debug (f"Failed to get jedi script for { test_file } : { e } " )
311- tests_cache .close ()
312- return function_to_test_map
313-
314- if cfg .test_framework == "pytest" :
315- for function in functions :
316- if "[" in function .test_function :
317- function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (function .test_function )[0 ]
318- parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (function .test_function )[1 ]
319- if function_name in top_level_functions :
320- test_functions .add (TestFunction (function_name , function .test_class , parameters , function .test_type ))
321- elif function .test_function in top_level_functions :
322- test_functions .add (TestFunction (function .test_function , function .test_class , None , function .test_type ))
323- elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX .match (function .test_function ):
324- base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX .sub ("" , function .test_function )
325- if base_name in top_level_functions :
326- test_functions .add (
327- TestFunction (
328- function_name = base_name ,
329- test_class = function .test_class ,
330- parameters = function .test_function ,
331- test_type = function .test_type ,
332- )
333- )
334- elif cfg .test_framework == "unittest" :
335- all_defs = script .get_names (all_scopes = True , definitions = True )
312+ try :
313+ script = jedi .Script (path = test_file , project = jedi_project )
314+ test_functions = set ()
336315
337- functions_to_search = [elem .test_function for elem in functions ]
338- test_suites = {elem .test_class for elem in functions }
316+ all_names = script .get_names (all_scopes = True , references = True )
317+ all_defs = script .get_names (all_scopes = True , definitions = True )
318+ all_names_top = script .get_names (all_scopes = True )
339319
340- matching_names = test_suites & top_level_classes .keys ()
341- for matched_name in matching_names :
342- for def_name in all_defs :
343- if (
344- def_name .type == "function"
345- and def_name .full_name is not None
346- and f".{ matched_name } ." in def_name .full_name
347- ):
348- for function in functions_to_search :
349- (is_parameterized , new_function , parameters ) = discover_parameters_unittest (function )
320+ top_level_functions = {name .name : name for name in all_names_top if name .type == "function" }
321+ top_level_classes = {name .name : name for name in all_names_top if name .type == "class" }
322+ except Exception as e :
323+ logger .debug (f"Failed to get jedi script for { test_file } : { e } " )
324+ progress .advance (task_id )
325+ continue
350326
351- if is_parameterized and new_function == def_name .name :
327+ if test_framework == "pytest" :
328+ for function in functions :
329+ if "[" in function .test_function :
330+ function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (function .test_function )[0 ]
331+ parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (function .test_function )[1 ]
332+ if function_name in top_level_functions :
352333 test_functions .add (
353- TestFunction (
354- function_name = def_name .name ,
355- test_class = matched_name ,
356- parameters = parameters ,
357- test_type = functions [0 ].test_type ,
358- )
334+ TestFunction (function_name , function .test_class , parameters , function .test_type )
359335 )
360- elif function == def_name .name :
336+ elif function .test_function in top_level_functions :
337+ test_functions .add (
338+ TestFunction (function .test_function , function .test_class , None , function .test_type )
339+ )
340+ elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX .match (function .test_function ):
341+ base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX .sub ("" , function .test_function )
342+ if base_name in top_level_functions :
361343 test_functions .add (
362344 TestFunction (
363- function_name = def_name . name ,
364- test_class = matched_name ,
365- parameters = None ,
366- test_type = functions [ 0 ] .test_type ,
345+ function_name = base_name ,
346+ test_class = function . test_class ,
347+ parameters = function . test_function ,
348+ test_type = function .test_type ,
367349 )
368350 )
369351
370- test_functions_list = list (test_functions )
371- test_functions_raw = [elem .function_name for elem in test_functions_list ]
372-
373- test_functions_by_name = defaultdict (list )
374- for i , func_name in enumerate (test_functions_raw ):
375- test_functions_by_name [func_name ].append (i )
376-
377- for name in all_names :
378- if name .full_name is None :
379- continue
380- m = FUNCTION_NAME_REGEX .search (name .full_name )
381- if not m :
382- continue
383-
384- scope = m .group (1 )
385- if scope not in test_functions_by_name :
386- continue
387-
388- try :
389- definition = name .goto (follow_imports = True , follow_builtin_imports = False )
390- except Exception as e :
391- logger .debug (str (e ))
392- continue
393-
394- if not definition or definition [0 ].type != "function" :
395- continue
396-
397- definition_path = str (definition [0 ].module_path )
398- if (
399- definition_path .startswith (str (project_root_path ) + os .sep )
400- and definition [0 ].module_name != name .module_name
401- and definition [0 ].full_name is not None
402- ):
403- for index in test_functions_by_name [scope ]:
404- scope_test_function = test_functions_list [index ].function_name
405- scope_test_class = test_functions_list [index ].test_class
406- scope_parameters = test_functions_list [index ].parameters
407- test_type = test_functions_list [index ].test_type
408-
409- if scope_parameters is not None :
410- if cfg .test_framework == "pytest" :
411- scope_test_function += "[" + scope_parameters + "]"
412- if cfg .test_framework == "unittest" :
413- scope_test_function += "_" + scope_parameters
414-
415- full_name_without_module_prefix = definition [0 ].full_name .replace (
416- definition [0 ].module_name + "." , "" , 1
417- )
418- qualified_name_with_modules_from_root = f"{ module_name_from_file_path (definition [0 ].module_path , project_root_path )} .{ full_name_without_module_prefix } "
419-
420- tests_cache .insert_test (
421- file_path = str (test_file ),
422- file_hash = file_hash ,
423- qualified_name_with_modules_from_root = qualified_name_with_modules_from_root ,
424- function_name = scope ,
425- test_class = scope_test_class ,
426- test_function = scope_test_function ,
427- test_type = test_type ,
428- line_number = name .line ,
429- col_number = name .column ,
430- )
352+ elif test_framework == "unittest" :
353+ functions_to_search = [elem .test_function for elem in functions ]
354+ test_suites = {elem .test_class for elem in functions }
355+
356+ matching_names = test_suites & top_level_classes .keys ()
357+ for matched_name in matching_names :
358+ for def_name in all_defs :
359+ if (
360+ def_name .type == "function"
361+ and def_name .full_name is not None
362+ and f".{ matched_name } ." in def_name .full_name
363+ ):
364+ for function in functions_to_search :
365+ (is_parameterized , new_function , parameters ) = discover_parameters_unittest (function )
366+
367+ if is_parameterized and new_function == def_name .name :
368+ test_functions .add (
369+ TestFunction (
370+ function_name = def_name .name ,
371+ test_class = matched_name ,
372+ parameters = parameters ,
373+ test_type = functions [0 ].test_type ,
374+ )
375+ )
376+ elif function == def_name .name :
377+ test_functions .add (
378+ TestFunction (
379+ function_name = def_name .name ,
380+ test_class = matched_name ,
381+ parameters = None ,
382+ test_type = functions [0 ].test_type ,
383+ )
384+ )
385+
386+ test_functions_list = list (test_functions )
387+ test_functions_raw = [elem .function_name for elem in test_functions_list ]
388+
389+ test_functions_by_name = defaultdict (list )
390+ for i , func_name in enumerate (test_functions_raw ):
391+ test_functions_by_name [func_name ].append (i )
392+
393+ for name in all_names :
394+ if name .full_name is None :
395+ continue
396+ m = FUNCTION_NAME_REGEX .search (name .full_name )
397+ if not m :
398+ continue
399+
400+ scope = m .group (1 )
401+ if scope not in test_functions_by_name :
402+ continue
403+
404+ cache_key = (name .full_name , name .module_name )
405+ try :
406+ if cache_key in goto_cache :
407+ definition = goto_cache [cache_key ]
408+ else :
409+ definition = name .goto (follow_imports = True , follow_builtin_imports = False )
410+ goto_cache [cache_key ] = definition
411+ except Exception as e :
412+ logger .debug (str (e ))
413+ continue
414+
415+ if not definition or definition [0 ].type != "function" :
416+ continue
417+
418+ definition_path = str (definition [0 ].module_path )
419+ if (
420+ definition_path .startswith (str (project_root_path ) + os .sep )
421+ and definition [0 ].module_name != name .module_name
422+ and definition [0 ].full_name is not None
423+ ):
424+ for index in test_functions_by_name [scope ]:
425+ scope_test_function = test_functions_list [index ].function_name
426+ scope_test_class = test_functions_list [index ].test_class
427+ scope_parameters = test_functions_list [index ].parameters
428+ test_type = test_functions_list [index ].test_type
429+
430+ if scope_parameters is not None :
431+ if test_framework == "pytest" :
432+ scope_test_function += "[" + scope_parameters + "]"
433+ if test_framework == "unittest" :
434+ scope_test_function += "_" + scope_parameters
435+
436+ full_name_without_module_prefix = definition [0 ].full_name .replace (
437+ definition [0 ].module_name + "." , "" , 1
438+ )
439+ qualified_name_with_modules_from_root = f"{ module_name_from_file_path (definition [0 ].module_path , project_root_path )} .{ full_name_without_module_prefix } "
431440
432- function_to_test_map [qualified_name_with_modules_from_root ].add (
433- FunctionCalledInTest (
434- tests_in_file = TestsInFile (
435- test_file = test_file ,
441+ tests_cache .insert_test (
442+ file_path = str (test_file ),
443+ file_hash = file_hash ,
444+ qualified_name_with_modules_from_root = qualified_name_with_modules_from_root ,
445+ function_name = scope ,
436446 test_class = scope_test_class ,
437447 test_function = scope_test_function ,
438448 test_type = test_type ,
439- ),
440- position = CodePosition (line_no = name .line , col_no = name .column ),
441- )
442- )
443-
444- tests_cache .close ()
445- return function_to_test_map
446-
449+ line_number = name .line ,
450+ col_number = name .column ,
451+ )
447452
448- def process_test_files (
449- file_to_test_map : dict [Path , list [TestsInFile ]], cfg : TestConfig
450- ) -> dict [str , list [FunctionCalledInTest ]]:
451- project_root_path = cfg .project_root_path
453+ function_to_test_map [qualified_name_with_modules_from_root ].add (
454+ FunctionCalledInTest (
455+ tests_in_file = TestsInFile (
456+ test_file = test_file ,
457+ test_class = scope_test_class ,
458+ test_function = scope_test_function ,
459+ test_type = test_type ,
460+ ),
461+ position = CodePosition (line_no = name .line , col_no = name .column ),
462+ )
463+ )
452464
453- function_to_test_map = defaultdict (set )
454- jedi_project = jedi .Project (path = project_root_path )
455- with (
456- test_files_progress_bar (total = len (file_to_test_map ), description = "Processing test files" ) as (
457- progress ,
458- task_id ,
459- ),
460- ProcessPoolExecutor () as executor ,
461- ):
462- futures = {
463- executor .submit (process_single_test_file , test_file , functions , cfg , jedi_project ): test_file
464- for test_file , functions in file_to_test_map .items ()
465- }
466- for future in futures :
467- result = future .result ()
468- for k , v in result .items ():
469- function_to_test_map [k ].update (v )
470- progress .update (task_id )
465+ progress .advance (task_id )
471466
467+ tests_cache .close ()
472468 return {function : list (tests ) for function , tests in function_to_test_map .items ()}
0 commit comments