@@ -177,35 +177,51 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
177177 return False , function_name , None
178178
179179
180- def process_test_files (
181- file_to_test_map : dict [str , list [TestsInFile ]], cfg : TestConfig
182- ) -> dict [str , list [FunctionCalledInTest ]]:
183- from concurrent .futures import ThreadPoolExecutor
180+ # Add this worker function at the module level (outside any other function)
181+ def process_file_worker (args_tuple ):
182+ """Worker function for processing a single test file in a separate process.
183+
184+ This must be at the module level (not nested) for multiprocessing to work.
185+ """
186+ import jedi
187+ import re
184188 import os
189+ from collections import defaultdict
190+ from pathlib import Path
185191
186- project_root_path = cfg .project_root_path
187- test_framework = cfg .test_framework
188- function_to_test_map = defaultdict (list )
189- jedi_project = jedi .Project (path = project_root_path )
192+ # Unpack the arguments
193+ test_file , functions , config = args_tuple
194+
195+ try :
196+ # Each process creates its own Jedi project
197+ jedi_project = jedi .Project (path = config ['project_root_path' ])
190198
191- # Define a function to process a single test file
192- def process_single_file (test_file , functions ):
193199 local_results = defaultdict (list )
200+ tests_found_in_file = 0
201+
202+ # Convert test_file back to Path if necessary
203+ test_file_path = test_file if isinstance (test_file , Path ) else Path (test_file )
204+
194205 try :
195206 script = jedi .Script (path = test_file , project = jedi_project )
196- test_functions = set ()
197-
198207 all_names = script .get_names (all_scopes = True , references = True )
199208 all_defs = script .get_names (all_scopes = True , definitions = True )
200209 all_names_top = script .get_names (all_scopes = True )
201210
202211 top_level_functions = {name .name : name for name in all_names_top if name .type == "function" }
203212 top_level_classes = {name .name : name for name in all_names_top if name .type == "class" }
204213 except Exception as e :
205- logger .debug (f"Failed to get jedi script for { test_file } : { e } " )
206- return local_results
214+ return {
215+ 'status' : 'error' ,
216+ 'error_type' : 'jedi_script_error' ,
217+ 'error_message' : str (e ),
218+ 'test_file' : test_file ,
219+ 'results' : {}
220+ }
221+
222+ test_functions = set ()
207223
208- if test_framework == "pytest" :
224+ if config [ ' test_framework' ] == "pytest" :
209225 for function in functions :
210226 if "[" in function .test_function :
211227 function_name = re .split (r"[\[\]]" , function .test_function )[0 ]
@@ -219,8 +235,7 @@ def process_single_file(test_file, functions):
219235 TestFunction (function .test_function , function .test_class , None , function .test_type )
220236 )
221237 elif re .match (r"^test_\w+_\d+(?:_\w+)*" , function .test_function ):
222- # Try to match parameterized unittest functions here, although we can't get the parameters.
223- # Extract base name by removing the numbered suffix and any additional descriptions
238+ # Try to match parameterized unittest functions here
224239 base_name = re .sub (r"_\d+(?:_\w+)*$" , "" , function .test_function )
225240 if base_name in top_level_functions :
226241 test_functions .add (
@@ -232,11 +247,11 @@ def process_single_file(test_file, functions):
232247 )
233248 )
234249
235- elif test_framework == "unittest" :
250+ elif config [ ' test_framework' ] == "unittest" :
236251 functions_to_search = [elem .test_function for elem in functions ]
237252 test_suites = {elem .test_class for elem in functions }
238253
239- matching_names = test_suites & top_level_classes .keys ()
254+ matching_names = set ( test_suites ) & set ( top_level_classes .keys () )
240255 for matched_name in matching_names :
241256 for def_name in all_defs :
242257 if (
@@ -254,7 +269,7 @@ def process_single_file(test_file, functions):
254269 test_class = matched_name ,
255270 parameters = parameters ,
256271 test_type = functions [0 ].test_type ,
257- ) # A test file must not have more than one test type
272+ )
258273 )
259274 elif function == def_name .name :
260275 test_functions .add (
@@ -285,61 +300,181 @@ def process_single_file(test_file, functions):
285300 try :
286301 definition = name .goto (follow_imports = True , follow_builtin_imports = False )
287302 except Exception as e :
288- logger .debug (str (e ))
289303 continue
290304 if definition and definition [0 ].type == "function" :
291305 definition_path = str (definition [0 ].module_path )
292306 # The definition is part of this project and not defined within the original function
293307 if (
294- definition_path .startswith (str ( project_root_path ) + os .sep )
308+ definition_path .startswith (config [ ' project_root_path' ] + os .sep )
295309 and definition [0 ].module_name != name .module_name
296310 and definition [0 ].full_name is not None
297311 ):
298312 if scope_parameters is not None :
299- if test_framework == "pytest" :
313+ if config [ ' test_framework' ] == "pytest" :
300314 scope_test_function += "[" + scope_parameters + "]"
301- if test_framework == "unittest" :
315+ if config [ ' test_framework' ] == "unittest" :
302316 scope_test_function += "_" + scope_parameters
317+
318+ # Get module name relative to project root
319+ module_name = module_name_from_file_path (definition [0 ].module_path , config ['project_root_path' ])
320+
303321 full_name_without_module_prefix = definition [0 ].full_name .replace (
304322 definition [0 ].module_name + "." , "" , 1
305323 )
306- qualified_name_with_modules_from_root = f"{ module_name_from_file_path (definition [0 ].module_path , project_root_path )} .{ full_name_without_module_prefix } "
307- local_results [qualified_name_with_modules_from_root ].append (
308- FunctionCalledInTest (
309- tests_in_file = TestsInFile (
310- test_file = test_file ,
311- test_class = scope_test_class ,
312- test_function = scope_test_function ,
313- test_type = test_type ,
314- ),
315- position = CodePosition (line_no = name .line , col_no = name .column ),
316- )
317- )
318- return local_results
319-
320- # Determine number of workers (threads) - use fewer than processes since these are I/O bound
321- max_workers = min (os .cpu_count () * 2 or 8 , len (file_to_test_map ), 16 )
324+ qualified_name_with_modules_from_root = f"{ module_name } .{ full_name_without_module_prefix } "
325+
326+ # Create a serializable representation of the result
327+ result_entry = {
328+ 'test_file' : str (test_file ),
329+ 'test_class' : scope_test_class ,
330+ 'test_function' : scope_test_function ,
331+ 'test_type' : test_type ,
332+ 'line_no' : name .line ,
333+ 'col_no' : name .column
334+ }
335+
336+ # Add to local results
337+ if qualified_name_with_modules_from_root not in local_results :
338+ local_results [qualified_name_with_modules_from_root ] = []
339+ local_results [qualified_name_with_modules_from_root ].append (result_entry )
340+ tests_found_in_file += 1
341+
342+ return {
343+ 'status' : 'success' ,
344+ 'test_file' : test_file ,
345+ 'tests_found' : tests_found_in_file ,
346+ 'results' : dict (local_results ) # Convert defaultdict to dict for serialization
347+ }
322348
323- # Process files in parallel using threads (shared memory)
324- with ThreadPoolExecutor (max_workers = max_workers ) as executor :
325- futures = {
326- executor .submit (process_single_file , test_file , functions ): test_file
327- for test_file , functions in file_to_test_map .items ()
349+ except Exception as e :
350+ import traceback
351+ return {
352+ 'status' : 'error' ,
353+ 'error_type' : 'general_error' ,
354+ 'error_message' : str (e ),
355+ 'traceback' : traceback .format_exc (),
356+ 'test_file' : test_file ,
357+ 'results' : {}
328358 }
329359
330- # Collect results
331- for future in futures :
332- try :
333- file_results = future .result ()
334- # Merge results
335- for function , tests in file_results .items ():
336- function_to_test_map [function ].extend (tests )
337- except Exception as e :
338- logger .warning (f"Error processing file { futures [future ]} : { e } " )
360+
def process_test_files(
    file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
) -> dict[str, list[FunctionCalledInTest]]:
    """Discover, in parallel, which project functions are called by each test file.

    Fans the per-file jedi analysis out to a multiprocessing pool (the worker is
    the module-level ``process_file_worker``), then merges the workers' plain-dict
    results back into real objects and deduplicates them.

    Args:
        file_to_test_map: Maps each test-file path to the tests discovered in it.
        cfg: Test configuration supplying ``project_root_path`` and
            ``test_framework``.

    Returns:
        Mapping from a function's module-qualified name to the deduplicated list
        of test invocations (``FunctionCalledInTest``) that call it.
    """
    from multiprocessing import Pool
    import os

    project_root_path = cfg.project_root_path
    test_framework = cfg.test_framework

    logger.info("Starting to process %d test files with multiprocessing", len(file_to_test_map))

    # Workers run in separate processes, so everything handed to them must be
    # picklable; the config travels as a plain dict of strings.
    config_dict = {
        "project_root_path": str(project_root_path),
        "test_framework": test_framework,
    }

    # Build one (test_file, functions, config) task per file, coercing any
    # Path-valued ``test_file`` attribute to str so the payload pickles cleanly.
    process_inputs = []
    for test_file, functions in file_to_test_map.items():
        serializable_functions = []
        for func in functions:
            if hasattr(func, "test_file") and not isinstance(func.test_file, str):
                func_dict = func._asdict() if hasattr(func, "_asdict") else func.__dict__.copy()
                func_dict["test_file"] = str(func_dict["test_file"])
                serializable_functions.append(TestsInFile(**func_dict))
            else:
                serializable_functions.append(func)
        process_inputs.append((str(test_file), serializable_functions, config_dict))

    # Pool(processes=0) raises ValueError, so bail out early on empty input.
    if not process_inputs:
        logger.info("No test files to process")
        return {}

    # The jedi analysis is CPU-bound, so cap the pool at the core count rather
    # than oversubscribing; os.cpu_count() may return None, hence the fallback.
    max_processes = min(os.cpu_count() or 1, len(process_inputs), 16)
    logger.info("Using %d processes for parallel test file processing", max_processes)

    processed_files = 0
    error_count = 0
    function_to_test_map = defaultdict(list)

    # Small chunks give better load balancing across files of uneven size.
    chunk_size = max(1, len(process_inputs) // (max_processes * 4))

    with Pool(processes=max_processes) as pool:
        # Completion order does not matter, so imap_unordered lets fast files
        # report without waiting on slow ones.
        for result in pool.imap_unordered(process_file_worker, process_inputs, chunk_size):
            processed_files += 1

            if processed_files % 100 == 0 or processed_files == len(process_inputs):
                logger.info("Processed %d/%d files", processed_files, len(process_inputs))

            if result["status"] == "error":
                error_count += 1
                logger.warning(
                    "Error processing file %s: %s", result["test_file"], result["error_message"]
                )
                if "traceback" in result:
                    logger.debug("Traceback: %s", result["traceback"])
                continue

            # Rehydrate the worker's serialized entries into real objects.
            for qualified_name, test_entries in result["results"].items():
                for entry in test_entries:
                    test_in_file = TestsInFile(
                        test_file=entry["test_file"],
                        test_class=entry["test_class"],
                        test_function=entry["test_function"],
                        test_type=entry["test_type"],
                    )
                    position = CodePosition(line_no=entry["line_no"], col_no=entry["col_no"])
                    function_to_test_map[qualified_name].append(
                        FunctionCalledInTest(tests_in_file=test_in_file, position=position)
                    )

    logger.info("Processing complete. Processed %d/%d files", processed_files, len(process_inputs))
    logger.info("Files with errors: %d", error_count)

    total_tests_before_dedup = sum(len(tests) for tests in function_to_test_map.values())
    logger.info(
        "Found %d unique functions with %d total tests before deduplication",
        len(function_to_test_map),
        total_tests_before_dedup,
    )

    # Deduplicate per function while preserving first-seen order. The hashable
    # key mirrors every field that distinguishes one test invocation from
    # another (the objects themselves may not be reliably hashable).
    deduped_function_to_test_map = {}
    for function, tests in function_to_test_map.items():
        unique_tests = []
        seen = set()
        for test in tests:
            test_key = (
                str(test.tests_in_file.test_file),
                test.tests_in_file.test_class,
                test.tests_in_file.test_function,
                test.tests_in_file.test_type,
                test.position.line_no,
                test.position.col_no,
            )
            if test_key not in seen:
                seen.add(test_key)
                unique_tests.append(test)
        deduped_function_to_test_map[function] = unique_tests

    total_tests_after_dedup = sum(len(tests) for tests in deduped_function_to_test_map.values())
    logger.info(
        "After deduplication: %d unique functions with %d total tests",
        len(deduped_function_to_test_map),
        total_tests_after_dedup,
    )

    return deduped_function_to_test_map
0 commit comments