From 86aa1cd6fe637d10e711d4e75e601fde5ff0ea58 Mon Sep 17 00:00:00 2001 From: Archan Das Date: Wed, 19 Mar 2025 21:28:35 +0000 Subject: [PATCH 1/4] execute test discovery in threads --- codeflash/discovery/discover_unit_tests.py | 51 +++++++++++++++++----- codeflash/version.py | 4 +- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 02ae2e4c1..e952567da 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -178,14 +178,19 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N def process_test_files( - file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig + file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig ) -> dict[str, list[FunctionCalledInTest]]: + from concurrent.futures import ThreadPoolExecutor + import os + project_root_path = cfg.project_root_path test_framework = cfg.test_framework function_to_test_map = defaultdict(list) jedi_project = jedi.Project(path=project_root_path) - for test_file, functions in file_to_test_map.items(): + # Define a function to process a single test file + def process_single_file(test_file, functions): + local_results = defaultdict(list) try: script = jedi.Script(path=test_file, project=jedi_project) test_functions = set() @@ -198,7 +203,7 @@ def process_test_files( top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} except Exception as e: logger.debug(f"Failed to get jedi script for {test_file}: {e}") - continue + return local_results if test_framework == "pytest": for function in functions: @@ -229,15 +234,15 @@ def process_test_files( elif test_framework == "unittest": functions_to_search = [elem.test_function for elem in functions] - test_suites = [elem.test_class for elem in functions] + test_suites = {elem.test_class for elem in functions} matching_names = test_suites & top_level_classes.keys() for matched_name in matching_names: for def_name in all_defs: if ( - def_name.type == "function" - and def_name.full_name is not None - and f".{matched_name}." in def_name.full_name + def_name.type == "function" + and def_name.full_name is not None + and f".{matched_name}." in def_name.full_name ): for function in functions_to_search: (is_parameterized, new_function, parameters) = discover_parameters_unittest(function) @@ -286,9 +291,9 @@ def process_test_files( definition_path = str(definition[0].module_path) # The definition is part of this project and not defined within the original function if ( - definition_path.startswith(str(project_root_path) + os.sep) - and definition[0].module_name != name.module_name - and definition[0].full_name is not None + definition_path.startswith(str(project_root_path) + os.sep) + and definition[0].module_name != name.module_name + and definition[0].full_name is not None ): if scope_parameters is not None: if test_framework == "pytest": @@ -299,7 +304,7 @@ def process_test_files( definition[0].module_name + ".", "", 1 ) qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" - function_to_test_map[qualified_name_with_modules_from_root].append( + local_results[qualified_name_with_modules_from_root].append( FunctionCalledInTest( tests_in_file=TestsInFile( test_file=test_file, @@ -310,7 +315,31 @@ def process_test_files( position=CodePosition(line_no=name.line, col_no=name.column), ) ) + return local_results + + # Determine number of workers (threads) - use fewer than processes since these are I/O bound + max_workers = min(os.cpu_count() * 2 or 8, len(file_to_test_map), 16) + + # Process files in parallel using threads (shared memory) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(process_single_file, test_file, functions): test_file + for test_file, functions in file_to_test_map.items() + } + + # Collect results + for future in futures: + try: + file_results = future.result() + # Merge results + for function, tests in file_results.items(): + function_to_test_map[function].extend(tests) + except Exception as e: + logger.warning(f"Error processing file {futures[future]}: {e}") + + # Deduplicate results deduped_function_to_test_map = {} for function, tests in function_to_test_map.items(): deduped_function_to_test_map[function] = list(set(tests)) + return deduped_function_to_test_map diff --git a/codeflash/version.py b/codeflash/version.py index 075de2843..7387724b7 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,3 +1,3 @@ # These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`. -__version__ = "0.10.3" -__version_tuple__ = (0, 10, 3) +__version__ = "0.10.3.post6.dev0+a035af84" +__version_tuple__ = (0, 10, 3, "post6", "dev0", "a035af84") From c6d9d710615f8861f674915d02ddfa925e287f0e Mon Sep 17 00:00:00 2001 From: Archan Das Date: Wed, 19 Mar 2025 22:12:04 +0000 Subject: [PATCH 2/4] working multiproc --- codeflash/discovery/discover_unit_tests.py | 245 ++++++++++++++++----- codeflash/version.py | 4 +- 2 files changed, 192 insertions(+), 57 deletions(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index e952567da..676dbe96a 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -177,24 +177,33 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N return False, function_name, None -def process_test_files( - file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig -) -> dict[str, list[FunctionCalledInTest]]: - from concurrent.futures import ThreadPoolExecutor +# Add this worker function at the module level (outside any other function) +def process_file_worker(args_tuple): + """Worker function for processing a single test file in a separate process. + + This must be at the module level (not nested) for multiprocessing to work. + """ + import jedi + import re import os + from collections import defaultdict + from pathlib import Path - project_root_path = cfg.project_root_path - test_framework = cfg.test_framework - function_to_test_map = defaultdict(list) - jedi_project = jedi.Project(path=project_root_path) + # Unpack the arguments + test_file, functions, config = args_tuple + + try: + # Each process creates its own Jedi project + jedi_project = jedi.Project(path=config['project_root_path']) - # Define a function to process a single test file - def process_single_file(test_file, functions): local_results = defaultdict(list) + tests_found_in_file = 0 + + # Convert test_file back to Path if necessary + test_file_path = test_file if isinstance(test_file, Path) else Path(test_file) + try: script = jedi.Script(path=test_file, project=jedi_project) - test_functions = set() - all_names = script.get_names(all_scopes=True, references=True) all_defs = script.get_names(all_scopes=True, definitions=True) all_names_top = script.get_names(all_scopes=True) @@ -202,10 +211,17 @@ def process_single_file(test_file, functions): top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} except Exception as e: - logger.debug(f"Failed to get jedi script for {test_file}: {e}") - return local_results + return { + 'status': 'error', + 'error_type': 'jedi_script_error', + 'error_message': str(e), + 'test_file': test_file, + 'results': {} + } + + test_functions = set() - if test_framework == "pytest": + if config['test_framework'] == "pytest": for function in functions: if "[" in function.test_function: function_name = re.split(r"[\[\]]", function.test_function)[0] @@ -219,8 +235,7 @@ def process_single_file(test_file, functions): TestFunction(function.test_function, function.test_class, None, function.test_type) ) elif re.match(r"^test_\w+_\d+(?:_\w+)*", function.test_function): - # Try to match parameterized unittest functions here, although we can't get the parameters. - # Extract base name by removing the numbered suffix and any additional descriptions + # Try to match parameterized unittest functions here base_name = re.sub(r"_\d+(?:_\w+)*$", "", function.test_function) if base_name in top_level_functions: test_functions.add( @@ -232,11 +247,11 @@ def process_single_file(test_file, functions): ) ) - elif test_framework == "unittest": + elif config['test_framework'] == "unittest": functions_to_search = [elem.test_function for elem in functions] test_suites = {elem.test_class for elem in functions} - matching_names = test_suites & top_level_classes.keys() + matching_names = set(test_suites) & set(top_level_classes.keys()) for matched_name in matching_names: for def_name in all_defs: if ( @@ -254,7 +269,7 @@ def process_single_file(test_file, functions): test_class=matched_name, parameters=parameters, test_type=functions[0].test_type, - ) # A test file must not have more than one test type + ) ) elif function == def_name.name: test_functions.add( @@ -285,61 +300,181 @@ def process_single_file(test_file, functions): try: definition = name.goto(follow_imports=True, follow_builtin_imports=False) except Exception as e: - logger.debug(str(e)) continue if definition and definition[0].type == "function": definition_path = str(definition[0].module_path) # The definition is part of this project and not defined within the original function if ( - definition_path.startswith(str(project_root_path) + os.sep) + definition_path.startswith(config['project_root_path'] + os.sep) and definition[0].module_name != name.module_name and definition[0].full_name is not None ): if scope_parameters is not None: - if test_framework == "pytest": + if config['test_framework'] == "pytest": scope_test_function += "[" + scope_parameters + "]" - if test_framework == "unittest": + if config['test_framework'] == "unittest": scope_test_function += "_" + scope_parameters + + # Get module name relative to project root + module_name = module_name_from_file_path(definition[0].module_path, config['project_root_path']) + full_name_without_module_prefix = definition[0].full_name.replace( definition[0].module_name + ".", "", 1 ) - qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" - local_results[qualified_name_with_modules_from_root].append( - FunctionCalledInTest( - tests_in_file=TestsInFile( - test_file=test_file, - test_class=scope_test_class, - test_function=scope_test_function, - test_type=test_type, - ), - position=CodePosition(line_no=name.line, col_no=name.column), - ) - ) - return local_results - - # Determine number of workers (threads) - use fewer than processes since these are I/O bound - max_workers = min(os.cpu_count() * 2 or 8, len(file_to_test_map), 16) + qualified_name_with_modules_from_root = f"{module_name}.{full_name_without_module_prefix}" + + # Create a serializable representation of the result + result_entry = { + 'test_file': str(test_file), + 'test_class': scope_test_class, + 'test_function': scope_test_function, + 'test_type': test_type, + 'line_no': name.line, + 'col_no': name.column + } + + # Add to local results + if qualified_name_with_modules_from_root not in local_results: + local_results[qualified_name_with_modules_from_root] = [] + local_results[qualified_name_with_modules_from_root].append(result_entry) + tests_found_in_file += 1 + + return { + 'status': 'success', + 'test_file': test_file, + 'tests_found': tests_found_in_file, + 'results': dict(local_results) # Convert defaultdict to dict for serialization + } - # Process files in parallel using threads (shared memory) - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = { - executor.submit(process_single_file, test_file, functions): test_file - for test_file, functions in file_to_test_map.items() + except Exception as e: + import traceback + return { + 'status': 'error', + 'error_type': 'general_error', + 'error_message': str(e), + 'traceback': traceback.format_exc(), + 'test_file': test_file, + 'results': {} } - # Collect results - for future in futures: - try: - file_results = future.result() - # Merge results - for function, tests in file_results.items(): - function_to_test_map[function].extend(tests) - except Exception as e: - logger.warning(f"Error processing file {futures[future]}: {e}") + +def process_test_files( + file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig +) -> dict[str, list[FunctionCalledInTest]]: + from multiprocessing import Pool, cpu_count + import os + import pickle + + project_root_path = cfg.project_root_path + test_framework = cfg.test_framework + + logger.info(f"Starting to process {len(file_to_test_map)} test files with multiprocessing") + + # Create a configuration dictionary to pass to worker processes + config_dict = { + 'project_root_path': str(project_root_path), + 'test_framework': test_framework + } + + # Prepare data for processing - create a list of (test_file, functions, config) tuples + process_inputs = [] + for test_file, functions in file_to_test_map.items(): + # Convert TestsInFile objects to serializable form if needed + serializable_functions = [] + for func in functions: + # Ensure test_file is a string (needed for pickling) + if hasattr(func, 'test_file') and not isinstance(func.test_file, str): + func_dict = func._asdict() if hasattr(func, '_asdict') else func.__dict__.copy() + func_dict['test_file'] = str(func_dict['test_file']) + serializable_functions.append(TestsInFile(**func_dict)) + else: + serializable_functions.append(func) + process_inputs.append((str(test_file), serializable_functions, config_dict)) + + # Determine optimal number of processes + max_processes = min(cpu_count() * 2, len(process_inputs), 16) + logger.info(f"Using {max_processes} processes for parallel test file processing") + + # Create a Pool and process the files + processed_files = 0 + error_count = 0 + function_to_test_map = defaultdict(list) + + # Use smaller chunk size for better load balancing + chunk_size = max(1, len(process_inputs) // (max_processes * 4)) + + with Pool(processes=max_processes) as pool: + # Use imap_unordered for better performance (we don't care about order) + for i, result in enumerate(pool.imap_unordered(process_file_worker, process_inputs, chunk_size)): + processed_files += 1 + + # Log progress + if processed_files % 100 == 0 or processed_files == len(process_inputs): + logger.info(f"Processed {processed_files}/{len(process_inputs)} files") + + if result['status'] == 'error': + error_count += 1 + logger.warning(f"Error processing file {result['test_file']}: {result['error_message']}") + if 'traceback' in result: + logger.debug(f"Traceback: {result['traceback']}") + continue + + # Process results from this file + for qualified_name, test_entries in result['results'].items(): + for entry in test_entries: + # Reconstruct FunctionCalledInTest from the serialized data + test_in_file = TestsInFile( + test_file=entry['test_file'], + test_class=entry['test_class'], + test_function=entry['test_function'], + test_type=entry['test_type'] + ) + + position = CodePosition(line_no=entry['line_no'], col_no=entry['col_no']) + + function_to_test_map[qualified_name].append( + FunctionCalledInTest( + tests_in_file=test_in_file, + position=position + ) + ) + + logger.info(f"Processing complete. Processed {processed_files}/{len(process_inputs)} files") + logger.info(f"Files with errors: {error_count}") + + # Log metrics before deduplication + total_tests_before_dedup = sum(len(tests) for tests in function_to_test_map.values()) + logger.info( + f"Found {len(function_to_test_map)} unique functions with {total_tests_before_dedup} total tests before deduplication") # Deduplicate results deduped_function_to_test_map = {} for function, tests in function_to_test_map.items(): - deduped_function_to_test_map[function] = list(set(tests)) + # Convert to set and back to list to remove duplicates + # We need to handle custom objects properly + unique_tests = [] + seen = set() + + for test in tests: + # Create a hashable representation of the test + test_hash = ( + str(test.tests_in_file.test_file), + test.tests_in_file.test_class, + test.tests_in_file.test_function, + test.tests_in_file.test_type, + test.position.line_no, + test.position.col_no + ) + + if test_hash not in seen: + seen.add(test_hash) + unique_tests.append(test) + + deduped_function_to_test_map[function] = unique_tests + + # Log metrics after deduplication + total_tests_after_dedup = sum(len(tests) for tests in deduped_function_to_test_map.values()) + logger.info( + f"After deduplication: {len(deduped_function_to_test_map)} unique functions with {total_tests_after_dedup} total tests") return deduped_function_to_test_map diff --git a/codeflash/version.py b/codeflash/version.py index 7387724b7..05cc8998f 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,3 +1,3 @@ # These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`. -__version__ = "0.10.3.post6.dev0+a035af84" -__version_tuple__ = (0, 10, 3, "post6", "dev0", "a035af84") +__version__ = "0.10.3.post7.dev0+86aa1cd6" +__version_tuple__ = (0, 10, 3, "post7", "dev0", "86aa1cd6") From 0000c290009c27a54bf8dc0f3ff09a2198012e96 Mon Sep 17 00:00:00 2001 From: Archan Das Date: Wed, 19 Mar 2025 22:14:58 +0000 Subject: [PATCH 3/4] revert to single-threaded if <25 files --- codeflash/discovery/discover_unit_tests.py | 143 +++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 676dbe96a..52f08ca20 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -109,6 +109,8 @@ def discover_tests_pytest( continue file_to_test_map[test_obj.test_file].append(test_obj) # Within these test files, find the project functions they are referring to and return their names/locations + if len(file_to_test_map) < 25: #default to single-threaded if there aren't that many files + return process_test_files_single_threaded(file_to_test_map, cfg) return process_test_files(file_to_test_map, cfg) @@ -166,6 +168,9 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: details = get_test_details(test) if details is not None: file_to_test_map[str(details.test_file)].append(details) + + if len(file_to_test_map) < 25: #default to single-threaded if there aren't that many files + return process_test_files_single_threaded(file_to_test_map, cfg) return process_test_files(file_to_test_map, cfg) @@ -357,6 +362,144 @@ def process_file_worker(args_tuple): 'results': {} } +def process_test_files_single_threaded( + file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig +) -> dict[str, list[FunctionCalledInTest]]: + project_root_path = cfg.project_root_path + test_framework = cfg.test_framework + function_to_test_map = defaultdict(list) + jedi_project = jedi.Project(path=project_root_path) + + for test_file, functions in file_to_test_map.items(): + try: + script = jedi.Script(path=test_file, project=jedi_project) + test_functions = set() + + all_names = script.get_names(all_scopes=True, references=True) + all_defs = script.get_names(all_scopes=True, definitions=True) + all_names_top = script.get_names(all_scopes=True) + + top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} + top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} + except Exception as e: + logger.debug(f"Failed to get jedi script for {test_file}: {e}") + continue + + if test_framework == "pytest": + for function in functions: + if "[" in function.test_function: + function_name = re.split(r"[\[\]]", function.test_function)[0] + parameters = re.split(r"[\[\]]", function.test_function)[1] + if function_name in top_level_functions: + test_functions.add( + TestFunction(function_name, function.test_class, parameters, function.test_type) + ) + elif function.test_function in top_level_functions: + test_functions.add( + TestFunction(function.test_function, function.test_class, None, function.test_type) + ) + elif re.match(r"^test_\w+_\d+(?:_\w+)*", function.test_function): + # Try to match parameterized unittest functions here, although we can't get the parameters. + # Extract base name by removing the numbered suffix and any additional descriptions + base_name = re.sub(r"_\d+(?:_\w+)*$", "", function.test_function) + if base_name in top_level_functions: + test_functions.add( + TestFunction( + function_name=base_name, + test_class=function.test_class, + parameters=function.test_function, + test_type=function.test_type, + ) + ) + + elif test_framework == "unittest": + functions_to_search = [elem.test_function for elem in functions] + test_suites = [elem.test_class for elem in functions] + + matching_names = test_suites & top_level_classes.keys() + for matched_name in matching_names: + for def_name in all_defs: + if ( + def_name.type == "function" + and def_name.full_name is not None + and f".{matched_name}." in def_name.full_name + ): + for function in functions_to_search: + (is_parameterized, new_function, parameters) = discover_parameters_unittest(function) + + if is_parameterized and new_function == def_name.name: + test_functions.add( + TestFunction( + function_name=def_name.name, + test_class=matched_name, + parameters=parameters, + test_type=functions[0].test_type, + ) # A test file must not have more than one test type + ) + elif function == def_name.name: + test_functions.add( + TestFunction( + function_name=def_name.name, + test_class=matched_name, + parameters=None, + test_type=functions[0].test_type, + ) + ) + + test_functions_list = list(test_functions) + test_functions_raw = [elem.function_name for elem in test_functions_list] + + for name in all_names: + if name.full_name is None: + continue + m = re.search(r"([^.]+)\." + f"{name.name}$", name.full_name) + if not m: + continue + scope = m.group(1) + indices = [i for i, x in enumerate(test_functions_raw) if x == scope] + for index in indices: + scope_test_function = test_functions_list[index].function_name + scope_test_class = test_functions_list[index].test_class + scope_parameters = test_functions_list[index].parameters + test_type = test_functions_list[index].test_type + try: + definition = name.goto(follow_imports=True, follow_builtin_imports=False) + except Exception as e: + logger.debug(str(e)) + continue + if definition and definition[0].type == "function": + definition_path = str(definition[0].module_path) + # The definition is part of this project and not defined within the original function + if ( + definition_path.startswith(str(project_root_path) + os.sep) + and definition[0].module_name != name.module_name + and definition[0].full_name is not None + ): + if scope_parameters is not None: + if test_framework == "pytest": + scope_test_function += "[" + scope_parameters + "]" + if test_framework == "unittest": + scope_test_function += "_" + scope_parameters + full_name_without_module_prefix = definition[0].full_name.replace( + definition[0].module_name + ".", "", 1 + ) + qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" + function_to_test_map[qualified_name_with_modules_from_root].append( + FunctionCalledInTest( + tests_in_file=TestsInFile( + test_file=test_file, + test_class=scope_test_class, + test_function=scope_test_function, + test_type=test_type, + ), + position=CodePosition(line_no=name.line, col_no=name.column), + ) + ) + deduped_function_to_test_map = {} + for function, tests in function_to_test_map.items(): + deduped_function_to_test_map[function] = list(set(tests)) + return deduped_function_to_test_map + def process_test_files( file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig From d4bb93ab0be2efe39c777ca171aa091b1a093dea Mon Sep 17 00:00:00 2001 From: Archan Das Date: Wed, 19 Mar 2025 22:15:50 +0000 Subject: [PATCH 4/4] increase ceiling for processes --- codeflash/discovery/discover_unit_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 52f08ca20..b9d73e5a4 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -535,7 +535,7 @@ def process_test_files( process_inputs.append((str(test_file), serializable_functions, config_dict)) # Determine optimal number of processes - max_processes = min(cpu_count() * 2, len(process_inputs), 16) + max_processes = min(cpu_count() * 2, len(process_inputs), 32) logger.info(f"Using {max_processes} processes for parallel test file processing") # Create a Pool and process the files