@@ -177,35 +177,51 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]:
     return False, function_name, None


-def process_test_files(
-    file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
-) -> dict[str, list[FunctionCalledInTest]]:
-    from concurrent.futures import ThreadPoolExecutor
+# Worker function defined at the module level (outside any other function)
+def process_file_worker(args_tuple):
+    """Worker function for processing a single test file in a separate process.
+
+    This must be at the module level (not nested) for multiprocessing to work.
+    """
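+    # Note: Pool workers receive this callable by pickling a reference to it,
+    # and nested functions cannot be pickled, hence the module-level placement.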
+    import jedi
+    import re
     import os
+    from collections import defaultdict
+    from pathlib import Path

-    project_root_path = cfg.project_root_path
-    test_framework = cfg.test_framework
-    function_to_test_map = defaultdict(list)
-    jedi_project = jedi.Project(path=project_root_path)
+    # Unpack the arguments
+    test_file, functions, config = args_tuple
+
+    try:
+        # Each process creates its own Jedi project
+        jedi_project = jedi.Project(path=config['project_root_path'])

-    # Define a function to process a single test file
-    def process_single_file(test_file, functions):
         local_results = defaultdict(list)
+        tests_found_in_file = 0
+
+        # Convert test_file back to Path if necessary
+        test_file_path = test_file if isinstance(test_file, Path) else Path(test_file)
+
         try:
             script = jedi.Script(path=test_file, project=jedi_project)
-            test_functions = set()
-
             all_names = script.get_names(all_scopes=True, references=True)
             all_defs = script.get_names(all_scopes=True, definitions=True)
             all_names_top = script.get_names(all_scopes=True)

             top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
             top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
         except Exception as e:
-            logger.debug(f"Failed to get jedi script for {test_file}: {e}")
-            return local_results
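+            # Report the failure as a plain, picklable dict; re-raising would
+            # rely on the exception type unpickling cleanly in the parent.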
+            return {
+                'status': 'error',
+                'error_type': 'jedi_script_error',
+                'error_message': str(e),
+                'test_file': test_file,
+                'results': {}
+            }
+
+        test_functions = set()

-        if test_framework == "pytest":
+        if config['test_framework'] == "pytest":
             for function in functions:
                 if "[" in function.test_function:
                     function_name = re.split(r"[\[\]]", function.test_function)[0]
@@ -219,8 +235,7 @@ def process_single_file(test_file, functions):
                         TestFunction(function.test_function, function.test_class, None, function.test_type)
                     )
                 elif re.match(r"^test_\w+_\d+(?:_\w+)*", function.test_function):
-                    # Try to match parameterized unittest functions here, although we can't get the parameters.
-                    # Extract base name by removing the numbered suffix and any additional descriptions
+                    # Try to match parameterized unittest functions here
                     base_name = re.sub(r"_\d+(?:_\w+)*$", "", function.test_function)
                     if base_name in top_level_functions:
                         test_functions.add(
@@ -232,11 +247,11 @@
                             )
                         )

-        elif test_framework == "unittest":
+        elif config['test_framework'] == "unittest":
             functions_to_search = [elem.test_function for elem in functions]
             test_suites = {elem.test_class for elem in functions}

-            matching_names = test_suites & top_level_classes.keys()
+            matching_names = set(test_suites) & set(top_level_classes.keys())
             for matched_name in matching_names:
                 for def_name in all_defs:
                     if (
@@ -254,7 +269,7 @@
                                         test_class=matched_name,
                                         parameters=parameters,
                                         test_type=functions[0].test_type,
-                                    )  # A test file must not have more than one test type
+                                    )
                                 )
                             elif function == def_name.name:
                                 test_functions.add(
@@ -285,61 +300,181 @@
                 try:
                     definition = name.goto(follow_imports=True, follow_builtin_imports=False)
                 except Exception as e:
-                    logger.debug(str(e))
                     continue
                 if definition and definition[0].type == "function":
                     definition_path = str(definition[0].module_path)
                     # The definition is part of this project and not defined within the original function
                     if (
-                        definition_path.startswith(str(project_root_path) + os.sep)
+                        definition_path.startswith(config['project_root_path'] + os.sep)
                         and definition[0].module_name != name.module_name
                         and definition[0].full_name is not None
                     ):
                         if scope_parameters is not None:
-                            if test_framework == "pytest":
+                            if config['test_framework'] == "pytest":
                                 scope_test_function += "[" + scope_parameters + "]"
-                            if test_framework == "unittest":
+                            if config['test_framework'] == "unittest":
                                 scope_test_function += "_" + scope_parameters
+
+                        # Get module name relative to project root
+                        module_name = module_name_from_file_path(definition[0].module_path, config['project_root_path'])
+
                         full_name_without_module_prefix = definition[0].full_name.replace(
                             definition[0].module_name + ".", "", 1
                         )
-                        qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
-                        local_results[qualified_name_with_modules_from_root].append(
-                            FunctionCalledInTest(
-                                tests_in_file=TestsInFile(
-                                    test_file=test_file,
-                                    test_class=scope_test_class,
-                                    test_function=scope_test_function,
-                                    test_type=test_type,
-                                ),
-                                position=CodePosition(line_no=name.line, col_no=name.column),
-                            )
-                        )
-        return local_results
-
-    # Determine number of workers (threads) - use fewer than processes since these are I/O bound
-    max_workers = min(os.cpu_count() * 2 or 8, len(file_to_test_map), 16)
+                        qualified_name_with_modules_from_root = f"{module_name}.{full_name_without_module_prefix}"
+
+                        # Create a serializable representation of the result
+                        result_entry = {
+                            'test_file': str(test_file),
+                            'test_class': scope_test_class,
+                            'test_function': scope_test_function,
+                            'test_type': test_type,
+                            'line_no': name.line,
+                            'col_no': name.column
+                        }
+
+                        # Add to local results
+                        if qualified_name_with_modules_from_root not in local_results:
+                            local_results[qualified_name_with_modules_from_root] = []
+                        local_results[qualified_name_with_modules_from_root].append(result_entry)
+                        tests_found_in_file += 1
+
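+        # The success payload is built only from picklable values, so it
+        # survives the trip back to the parent process unchanged.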
+        return {
+            'status': 'success',
+            'test_file': test_file,
+            'tests_found': tests_found_in_file,
+            'results': dict(local_results)  # Convert defaultdict to dict for serialization
+        }

-    # Process files in parallel using threads (shared memory)
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = {
-            executor.submit(process_single_file, test_file, functions): test_file
-            for test_file, functions in file_to_test_map.items()
+    except Exception as e:
+        import traceback
+        return {
+            'status': 'error',
+            'error_type': 'general_error',
+            'error_message': str(e),
+            'traceback': traceback.format_exc(),
+            'test_file': test_file,
+            'results': {}
         }

-        # Collect results
-        for future in futures:
-            try:
-                file_results = future.result()
-                # Merge results
-                for function, tests in file_results.items():
-                    function_to_test_map[function].extend(tests)
-            except Exception as e:
-                logger.warning(f"Error processing file {futures[future]}: {e}")
+
+def process_test_files(
+    file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
+) -> dict[str, list[FunctionCalledInTest]]:
+    from multiprocessing import Pool, cpu_count
+    import os
+    import pickle
+
+    project_root_path = cfg.project_root_path
+    test_framework = cfg.test_framework
+
+    logger.info(f"Starting to process {len(file_to_test_map)} test files with multiprocessing")
+
+    # Create a configuration dictionary to pass to worker processes
+    config_dict = {
+        'project_root_path': str(project_root_path),
+        'test_framework': test_framework
+    }
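+    # Only simple, picklable values go into the per-worker config; each worker
+    # rebuilds heavier objects (such as the Jedi project) from them.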
+
+    # Prepare data for processing - create a list of (test_file, functions, config) tuples
+    process_inputs = []
+    for test_file, functions in file_to_test_map.items():
+        # Convert TestsInFile objects to serializable form if needed
+        serializable_functions = []
+        for func in functions:
+            # Ensure test_file is a string (needed for pickling)
+            if hasattr(func, 'test_file') and not isinstance(func.test_file, str):
+                func_dict = func._asdict() if hasattr(func, '_asdict') else func.__dict__.copy()
+                func_dict['test_file'] = str(func_dict['test_file'])
+                serializable_functions.append(TestsInFile(**func_dict))
+            else:
+                serializable_functions.append(func)
+        process_inputs.append((str(test_file), serializable_functions, config_dict))
+
+    # Determine optimal number of processes
+    max_processes = min(cpu_count() * 2, len(process_inputs), 16)
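+    # For example, an 8-core machine processing 1,000 test files gets
+    # min(8 * 2, 1000, 16) = 16 worker processes.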
+    logger.info(f"Using {max_processes} processes for parallel test file processing")
+
+    # Create a Pool and process the files
+    processed_files = 0
+    error_count = 0
+    function_to_test_map = defaultdict(list)
+
+    # Use smaller chunk size for better load balancing
+    chunk_size = max(1, len(process_inputs) // (max_processes * 4))
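+    # E.g. 1,000 inputs with 16 processes yield max(1, 1000 // 64) = 15 files
+    # per chunk, so one batch of slow files cannot starve the other workers.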
+
+    with Pool(processes=max_processes) as pool:
+        # Use imap_unordered for better performance (we don't care about order)
+        for result in pool.imap_unordered(process_file_worker, process_inputs, chunk_size):
+            processed_files += 1
+
+            # Log progress
+            if processed_files % 100 == 0 or processed_files == len(process_inputs):
+                logger.info(f"Processed {processed_files}/{len(process_inputs)} files")
+
+            if result['status'] == 'error':
+                error_count += 1
+                logger.warning(f"Error processing file {result['test_file']}: {result['error_message']}")
+                if 'traceback' in result:
+                    logger.debug(f"Traceback: {result['traceback']}")
+                continue
+
+            # Process results from this file
+            for qualified_name, test_entries in result['results'].items():
+                for entry in test_entries:
+                    # Reconstruct FunctionCalledInTest from the serialized data
+                    test_in_file = TestsInFile(
+                        test_file=entry['test_file'],
+                        test_class=entry['test_class'],
+                        test_function=entry['test_function'],
+                        test_type=entry['test_type']
+                    )

+                    position = CodePosition(line_no=entry['line_no'], col_no=entry['col_no'])
+
+                    function_to_test_map[qualified_name].append(
+                        FunctionCalledInTest(
+                            tests_in_file=test_in_file,
+                            position=position
+                        )
+                    )
+
+    logger.info(f"Processing complete. Processed {processed_files}/{len(process_inputs)} files")
+    logger.info(f"Files with errors: {error_count}")
+
+    # Log metrics before deduplication
+    total_tests_before_dedup = sum(len(tests) for tests in function_to_test_map.values())
+    logger.info(
+        f"Found {len(function_to_test_map)} unique functions with {total_tests_before_dedup} total tests before deduplication")

     # Deduplicate results
     deduped_function_to_test_map = {}
     for function, tests in function_to_test_map.items():
-        deduped_function_to_test_map[function] = list(set(tests))
+        # Deduplicate on a tuple of identifying fields rather than on the
+        # objects themselves, normalizing test_file to str along the way
+        unique_tests = []
+        seen = set()
+
+        for test in tests:
+            # Create a hashable representation of the test
+            test_hash = (
+                str(test.tests_in_file.test_file),
+                test.tests_in_file.test_class,
+                test.tests_in_file.test_function,
+                test.tests_in_file.test_type,
+                test.position.line_no,
+                test.position.col_no
+            )
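+            # A tuple of strings and ints is hashable, and str() normalizes
+            # test_file so Path- and string-valued entries compare equal.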
+
+            if test_hash not in seen:
+                seen.add(test_hash)
+                unique_tests.append(test)
+
+        deduped_function_to_test_map[function] = unique_tests
+
+    # Log metrics after deduplication
+    total_tests_after_dedup = sum(len(tests) for tests in deduped_function_to_test_map.values())
+    logger.info(
+        f"After deduplication: {len(deduped_function_to_test_map)} unique functions with {total_tests_after_dedup} total tests")

     return deduped_function_to_test_map