
Commit c6d9d71

Author: Archan Das (committed)
Commit message: working multiproc

1 parent: 86aa1cd

File tree: 2 files changed (+192 −57 lines)


codeflash/discovery/discover_unit_tests.py

Lines changed: 190 additions & 55 deletions
@@ -177,35 +177,51 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]:
         return False, function_name, None


-def process_test_files(
-    file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
-) -> dict[str, list[FunctionCalledInTest]]:
-    from concurrent.futures import ThreadPoolExecutor
+# This worker function must live at the module level (outside any other function)
+def process_file_worker(args_tuple):
+    """Worker function for processing a single test file in a separate process.
+
+    This must be at the module level (not nested) for multiprocessing to work.
+    """
+    import jedi
+    import re
     import os
+    from collections import defaultdict
+    from pathlib import Path

-    project_root_path = cfg.project_root_path
-    test_framework = cfg.test_framework
-    function_to_test_map = defaultdict(list)
-    jedi_project = jedi.Project(path=project_root_path)
+    # Unpack the arguments
+    test_file, functions, config = args_tuple
+
+    try:
+        # Each process creates its own Jedi project
+        jedi_project = jedi.Project(path=config['project_root_path'])

-    # Define a function to process a single test file
-    def process_single_file(test_file, functions):
         local_results = defaultdict(list)
+        tests_found_in_file = 0
+
+        # Convert test_file back to Path if necessary
+        test_file_path = test_file if isinstance(test_file, Path) else Path(test_file)
+
         try:
             script = jedi.Script(path=test_file, project=jedi_project)
-            test_functions = set()
-
             all_names = script.get_names(all_scopes=True, references=True)
             all_defs = script.get_names(all_scopes=True, definitions=True)
             all_names_top = script.get_names(all_scopes=True)

             top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
             top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
         except Exception as e:
-            logger.debug(f"Failed to get jedi script for {test_file}: {e}")
-            return local_results
+            return {
+                'status': 'error',
+                'error_type': 'jedi_script_error',
+                'error_message': str(e),
+                'test_file': test_file,
+                'results': {}
+            }
+
+        test_functions = set()

-        if test_framework == "pytest":
+        if config['test_framework'] == "pytest":
             for function in functions:
                 if "[" in function.test_function:
                     function_name = re.split(r"[\[\]]", function.test_function)[0]
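
Note: multiprocessing dispatches the worker callable to child processes by pickling a reference to it, which only works for module-level functions; a nested function (like the old process_single_file) cannot be sent to a Pool. A minimal standalone sketch of the rule, with hypothetical names (double, broken), not code from this commit:

    from multiprocessing import Pool

    def double(x):  # module-level: picklable by qualified name
        return 2 * x

    def main():
        def broken(x):  # nested: cannot be pickled for a worker process
            return 2 * x

        with Pool(processes=2) as pool:
            print(pool.map(double, [1, 2, 3]))  # [2, 4, 6]
            # pool.map(broken, [1, 2, 3])  # raises: Can't pickle local object

    if __name__ == "__main__":
        main()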
@@ -219,8 +235,7 @@ def process_single_file(test_file, functions):
                         TestFunction(function.test_function, function.test_class, None, function.test_type)
                     )
                 elif re.match(r"^test_\w+_\d+(?:_\w+)*", function.test_function):
-                    # Try to match parameterized unittest functions here, although we can't get the parameters.
-                    # Extract base name by removing the numbered suffix and any additional descriptions
+                    # Try to match parameterized unittest functions here
                     base_name = re.sub(r"_\d+(?:_\w+)*$", "", function.test_function)
                     if base_name in top_level_functions:
                         test_functions.add(
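
Note: the re.sub(r"_\d+(?:_\w+)*$", "", ...) call strips the numbered suffix that parameterized unittest runners append to test names, recovering the base function name. A quick illustration with hypothetical test names:

    import re

    for name in ["test_addition_0", "test_addition_1_negative_inputs", "test_plain"]:
        print(re.sub(r"_\d+(?:_\w+)*$", "", name))
    # test_addition
    # test_addition
    # test_plain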
@@ -232,11 +247,11 @@ def process_single_file(test_file, functions):
                             )
                         )

-        elif test_framework == "unittest":
+        elif config['test_framework'] == "unittest":
             functions_to_search = [elem.test_function for elem in functions]
             test_suites = {elem.test_class for elem in functions}

-            matching_names = test_suites & top_level_classes.keys()
+            matching_names = set(test_suites) & set(top_level_classes.keys())
             for matched_name in matching_names:
                 for def_name in all_defs:
                     if (
@@ -254,7 +269,7 @@ def process_single_file(test_file, functions):
                                     test_class=matched_name,
                                     parameters=parameters,
                                     test_type=functions[0].test_type,
-                                )  # A test file must not have more than one test type
+                                )
                             )
                     elif function == def_name.name:
                         test_functions.add(
@@ -285,61 +300,181 @@ def process_single_file(test_file, functions):
             try:
                 definition = name.goto(follow_imports=True, follow_builtin_imports=False)
             except Exception as e:
-                logger.debug(str(e))
                 continue
             if definition and definition[0].type == "function":
                 definition_path = str(definition[0].module_path)
                 # The definition is part of this project and not defined within the original function
                 if (
-                    definition_path.startswith(str(project_root_path) + os.sep)
+                    definition_path.startswith(config['project_root_path'] + os.sep)
                     and definition[0].module_name != name.module_name
                     and definition[0].full_name is not None
                 ):
                     if scope_parameters is not None:
-                        if test_framework == "pytest":
+                        if config['test_framework'] == "pytest":
                             scope_test_function += "[" + scope_parameters + "]"
-                        if test_framework == "unittest":
+                        if config['test_framework'] == "unittest":
                             scope_test_function += "_" + scope_parameters
+
+                    # Get module name relative to project root
+                    module_name = module_name_from_file_path(definition[0].module_path, config['project_root_path'])
+
                     full_name_without_module_prefix = definition[0].full_name.replace(
                         definition[0].module_name + ".", "", 1
                     )
-                    qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
-                    local_results[qualified_name_with_modules_from_root].append(
-                        FunctionCalledInTest(
-                            tests_in_file=TestsInFile(
-                                test_file=test_file,
-                                test_class=scope_test_class,
-                                test_function=scope_test_function,
-                                test_type=test_type,
-                            ),
-                            position=CodePosition(line_no=name.line, col_no=name.column),
-                        )
-                    )
-        return local_results
-
-    # Determine number of workers (threads) - use fewer than processes since these are I/O bound
-    max_workers = min(os.cpu_count() * 2 or 8, len(file_to_test_map), 16)
+                    qualified_name_with_modules_from_root = f"{module_name}.{full_name_without_module_prefix}"
+
+                    # Create a serializable representation of the result
+                    result_entry = {
+                        'test_file': str(test_file),
+                        'test_class': scope_test_class,
+                        'test_function': scope_test_function,
+                        'test_type': test_type,
+                        'line_no': name.line,
+                        'col_no': name.column
+                    }
+
+                    # Add to local results
+                    if qualified_name_with_modules_from_root not in local_results:
+                        local_results[qualified_name_with_modules_from_root] = []
+                    local_results[qualified_name_with_modules_from_root].append(result_entry)
+                    tests_found_in_file += 1
+
+        return {
+            'status': 'success',
+            'test_file': test_file,
+            'tests_found': tests_found_in_file,
+            'results': dict(local_results)  # Convert defaultdict to dict for serialization
+        }

-    # Process files in parallel using threads (shared memory)
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = {
-            executor.submit(process_single_file, test_file, functions): test_file
-            for test_file, functions in file_to_test_map.items()
+    except Exception as e:
+        import traceback
+        return {
+            'status': 'error',
+            'error_type': 'general_error',
+            'error_message': str(e),
+            'traceback': traceback.format_exc(),
+            'test_file': test_file,
+            'results': {}
         }

-    # Collect results
-    for future in futures:
-        try:
-            file_results = future.result()
-            # Merge results
-            for function, tests in file_results.items():
-                function_to_test_map[function].extend(tests)
-        except Exception as e:
-            logger.warning(f"Error processing file {futures[future]}: {e}")
+
+def process_test_files(
+    file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
+) -> dict[str, list[FunctionCalledInTest]]:
+    from multiprocessing import Pool, cpu_count
+    import os
+    import pickle
+
+    project_root_path = cfg.project_root_path
+    test_framework = cfg.test_framework
+
+    logger.info(f"Starting to process {len(file_to_test_map)} test files with multiprocessing")
+
+    # Create a configuration dictionary to pass to worker processes
+    config_dict = {
+        'project_root_path': str(project_root_path),
+        'test_framework': test_framework
+    }
+
+    # Prepare data for processing - create a list of (test_file, functions, config) tuples
+    process_inputs = []
+    for test_file, functions in file_to_test_map.items():
+        # Convert TestsInFile objects to serializable form if needed
+        serializable_functions = []
+        for func in functions:
+            # Ensure test_file is a string (needed for pickling)
+            if hasattr(func, 'test_file') and not isinstance(func.test_file, str):
+                func_dict = func._asdict() if hasattr(func, '_asdict') else func.__dict__.copy()
+                func_dict['test_file'] = str(func_dict['test_file'])
+                serializable_functions.append(TestsInFile(**func_dict))
+            else:
+                serializable_functions.append(func)
+        process_inputs.append((str(test_file), serializable_functions, config_dict))
+
+    # Determine optimal number of processes
+    max_processes = min(cpu_count() * 2, len(process_inputs), 16)
+    logger.info(f"Using {max_processes} processes for parallel test file processing")
+
+    # Create a Pool and process the files
+    processed_files = 0
+    error_count = 0
+    function_to_test_map = defaultdict(list)
+
+    # Use smaller chunk size for better load balancing
+    chunk_size = max(1, len(process_inputs) // (max_processes * 4))
+
+    with Pool(processes=max_processes) as pool:
+        # Use imap_unordered for better performance (we don't care about order)
+        for i, result in enumerate(pool.imap_unordered(process_file_worker, process_inputs, chunk_size)):
+            processed_files += 1
+
+            # Log progress
+            if processed_files % 100 == 0 or processed_files == len(process_inputs):
+                logger.info(f"Processed {processed_files}/{len(process_inputs)} files")
+
+            if result['status'] == 'error':
+                error_count += 1
+                logger.warning(f"Error processing file {result['test_file']}: {result['error_message']}")
+                if 'traceback' in result:
+                    logger.debug(f"Traceback: {result['traceback']}")
+                continue
+
+            # Process results from this file
+            for qualified_name, test_entries in result['results'].items():
+                for entry in test_entries:
+                    # Reconstruct FunctionCalledInTest from the serialized data
+                    test_in_file = TestsInFile(
+                        test_file=entry['test_file'],
+                        test_class=entry['test_class'],
+                        test_function=entry['test_function'],
+                        test_type=entry['test_type']
+                    )
+
+                    position = CodePosition(line_no=entry['line_no'], col_no=entry['col_no'])
+
+                    function_to_test_map[qualified_name].append(
+                        FunctionCalledInTest(
+                            tests_in_file=test_in_file,
+                            position=position
+                        )
+                    )
+
+    logger.info(f"Processing complete. Processed {processed_files}/{len(process_inputs)} files")
+    logger.info(f"Files with errors: {error_count}")
+
+    # Log metrics before deduplication
+    total_tests_before_dedup = sum(len(tests) for tests in function_to_test_map.values())
+    logger.info(
+        f"Found {len(function_to_test_map)} unique functions with {total_tests_before_dedup} total tests before deduplication")

     # Deduplicate results
     deduped_function_to_test_map = {}
     for function, tests in function_to_test_map.items():
-        deduped_function_to_test_map[function] = list(set(tests))
+        # Deduplicate via a set of hashable keys, since the custom
+        # objects must be compared field by field
+        unique_tests = []
+        seen = set()
+
+        for test in tests:
+            # Create a hashable representation of the test
+            test_hash = (
+                str(test.tests_in_file.test_file),
+                test.tests_in_file.test_class,
+                test.tests_in_file.test_function,
+                test.tests_in_file.test_type,
+                test.position.line_no,
+                test.position.col_no
+            )
+
+            if test_hash not in seen:
+                seen.add(test_hash)
+                unique_tests.append(test)
+
+        deduped_function_to_test_map[function] = unique_tests
+
+    # Log metrics after deduplication
+    total_tests_after_dedup = sum(len(tests) for tests in deduped_function_to_test_map.values())
+    logger.info(
+        f"After deduplication: {len(deduped_function_to_test_map)} unique functions with {total_tests_after_dedup} total tests")

     return deduped_function_to_test_map
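
Note: the new scaffolding follows the standard multiprocessing pattern: a module-level worker, pickle-friendly inputs and outputs (plain strings and dicts rather than jedi or Path objects), imap_unordered so results are consumed as soon as any worker finishes, and a small explicit chunksize so one slow file does not stall a worker's whole task batch. A self-contained sketch of that pattern under those assumptions; the names (worker, run) are illustrative, not the codeflash API:

    from multiprocessing import Pool, cpu_count

    def worker(args_tuple):
        """Module-level worker: returns only pickle-friendly plain dicts."""
        path, payload = args_tuple
        try:
            return {"status": "success", "file": path, "results": {path: len(payload)}}
        except Exception as e:
            return {"status": "error", "file": path, "error_message": str(e), "results": {}}

    def run(inputs):
        max_processes = min(cpu_count() * 2, len(inputs), 16)
        # Small chunks improve load balancing across uneven file sizes.
        chunk_size = max(1, len(inputs) // (max_processes * 4))
        merged = {}
        with Pool(processes=max_processes) as pool:
            for result in pool.imap_unordered(worker, inputs, chunk_size):
                if result["status"] == "error":
                    continue  # a real caller would log result["error_message"]
                merged.update(result["results"])
        return merged

    if __name__ == "__main__":
        print(run([("a.py", "xx"), ("b.py", "yyy")]))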

codeflash/version.py

Lines changed: 2 additions & 2 deletions
@@ -1,3 +1,3 @@
 # These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`.
-__version__ = "0.10.3.post6.dev0+a035af84"
-__version_tuple__ = (0, 10, 3, "post6", "dev0", "a035af84")
+__version__ = "0.10.3.post7.dev0+86aa1cd6"
+__version_tuple__ = (0, 10, 3, "post7", "dev0", "86aa1cd6")
