309 changes: 308 additions & 1 deletion codeflash/discovery/discover_unit_tests.py
@@ -109,6 +109,8 @@ def discover_tests_pytest(
continue
file_to_test_map[test_obj.test_file].append(test_obj)
# Within these test files, find the project functions they are referring to and return their names/locations
if len(file_to_test_map) < 25: #default to single-threaded if there aren't that many files
return process_test_files_single_threaded(file_to_test_map, cfg)
return process_test_files(file_to_test_map, cfg)


@@ -166,6 +168,9 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
details = get_test_details(test)
if details is not None:
file_to_test_map[str(details.test_file)].append(details)

if len(file_to_test_map) < 25: #default to single-threaded if there aren't that many files
return process_test_files_single_threaded(file_to_test_map, cfg)
return process_test_files(file_to_test_map, cfg)


@@ -177,7 +182,187 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]:
return False, function_name, None


def process_test_files(
# Worker function defined at module level (outside any other function) so multiprocessing can pickle it
def process_file_worker(args_tuple):
"""Worker function for processing a single test file in a separate process.

This must be at the module level (not nested) for multiprocessing to work.
"""
import jedi
import re
import os
from collections import defaultdict
from pathlib import Path

# Unpack the arguments
test_file, functions, config = args_tuple

try:
# Each process creates its own Jedi project
jedi_project = jedi.Project(path=config['project_root_path'])

local_results = defaultdict(list)
tests_found_in_file = 0

# Convert test_file back to Path if necessary
test_file_path = test_file if isinstance(test_file, Path) else Path(test_file)

try:
script = jedi.Script(path=test_file, project=jedi_project)
all_names = script.get_names(all_scopes=True, references=True)
Contributor: if we're going for performance here, I wonder if we can do some caching. For example, if two tests use the same function (in this case it will be a jedi Name in all_names), we shouldn't have to process it twice.
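A minimal sketch of that caching idea, not part of the PR: memoize the resolved definition per referenced name so repeated references to the same function in a file are resolved only once. `definition_cache` and `resolve_definition` are hypothetical names.

```python
# Hypothetical sketch of the reviewer's caching suggestion (not the PR's code).
# Resolve each referenced name at most once per file by memoizing on full_name.
definition_cache: dict[str, list] = {}

def resolve_definition(name):
    key = name.full_name
    if key in definition_cache:
        return definition_cache[key]
    try:
        resolved = name.goto(follow_imports=True, follow_builtin_imports=False)
    except Exception:
        resolved = []
    definition_cache[key] = resolved
    return resolved
```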

all_defs = script.get_names(all_scopes=True, definitions=True)
all_names_top = script.get_names(all_scopes=True)

top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
except Exception as e:
return {
'status': 'error',
'error_type': 'jedi_script_error',
'error_message': str(e),
'test_file': test_file,
'results': {}
}

test_functions = set()

if config['test_framework'] == "pytest":
for function in functions:
if "[" in function.test_function:
function_name = re.split(r"[\[\]]", function.test_function)[0]
parameters = re.split(r"[\[\]]", function.test_function)[1]
if function_name in top_level_functions:
test_functions.add(
TestFunction(function_name, function.test_class, parameters, function.test_type)
)
elif function.test_function in top_level_functions:
test_functions.add(
TestFunction(function.test_function, function.test_class, None, function.test_type)
)
elif re.match(r"^test_\w+_\d+(?:_\w+)*", function.test_function):
# Try to match parameterized unittest functions here
base_name = re.sub(r"_\d+(?:_\w+)*$", "", function.test_function)
if base_name in top_level_functions:
test_functions.add(
TestFunction(
function_name=base_name,
test_class=function.test_class,
parameters=function.test_function,
test_type=function.test_type,
)
)

elif config['test_framework'] == "unittest":
functions_to_search = [elem.test_function for elem in functions]
test_suites = {elem.test_class for elem in functions}

matching_names = set(test_suites) & set(top_level_classes.keys())
for matched_name in matching_names:
for def_name in all_defs:
if (
def_name.type == "function"
and def_name.full_name is not None
and f".{matched_name}." in def_name.full_name
):
for function in functions_to_search:
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)

if is_parameterized and new_function == def_name.name:
test_functions.add(
TestFunction(
function_name=def_name.name,
test_class=matched_name,
parameters=parameters,
test_type=functions[0].test_type,
)
)
elif function == def_name.name:
test_functions.add(
TestFunction(
function_name=def_name.name,
test_class=matched_name,
parameters=None,
test_type=functions[0].test_type,
)
)

test_functions_list = list(test_functions)
test_functions_raw = [elem.function_name for elem in test_functions_list]

for name in all_names:
if name.full_name is None:
continue
m = re.search(r"([^.]+)\." + f"{name.name}$", name.full_name)
if not m:
continue
scope = m.group(1)
indices = [i for i, x in enumerate(test_functions_raw) if x == scope]
for index in indices:
Contributor: this seems inefficient, calling name.goto multiple times. Am I missing something here? @misrasaurabh1

scope_test_function = test_functions_list[index].function_name
scope_test_class = test_functions_list[index].test_class
scope_parameters = test_functions_list[index].parameters
test_type = test_functions_list[index].test_type
try:
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
except Exception:
continue
if definition and definition[0].type == "function":
definition_path = str(definition[0].module_path)
# The definition is part of this project and not defined within the original function
if (
definition_path.startswith(config['project_root_path'] + os.sep)
and definition[0].module_name != name.module_name
and definition[0].full_name is not None
):
if scope_parameters is not None:
if config['test_framework'] == "pytest":
scope_test_function += "[" + scope_parameters + "]"
if config['test_framework'] == "unittest":
scope_test_function += "_" + scope_parameters

# Get module name relative to project root
module_name = module_name_from_file_path(definition[0].module_path, config['project_root_path'])

full_name_without_module_prefix = definition[0].full_name.replace(
definition[0].module_name + ".", "", 1
)
qualified_name_with_modules_from_root = f"{module_name}.{full_name_without_module_prefix}"

# Create a serializable representation of the result
result_entry = {
'test_file': str(test_file),
'test_class': scope_test_class,
'test_function': scope_test_function,
'test_type': test_type,
'line_no': name.line,
'col_no': name.column
}

# Add to local results
if qualified_name_with_modules_from_root not in local_results:
local_results[qualified_name_with_modules_from_root] = []
local_results[qualified_name_with_modules_from_root].append(result_entry)
tests_found_in_file += 1

return {
'status': 'success',
'test_file': test_file,
'tests_found': tests_found_in_file,
'results': dict(local_results) # Convert defaultdict to dict for serialization
}

except Exception as e:
import traceback
return {
'status': 'error',
'error_type': 'general_error',
'error_message': str(e),
'traceback': traceback.format_exc(),
'test_file': test_file,
'results': {}
}
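On the repeated name.goto concern raised in the review above, one possible restructuring is to resolve the definition once per name and only then fan out over the matching test functions. This is a sketch under the assumption that the resolved definition does not depend on the matched index (which the original loop suggests); it is not the PR's code.

```python
# Sketch (assumption, not the PR's code): resolve each name once, then reuse
# the definition for every matching entry in test_functions_list.
for name in all_names:
    if name.full_name is None:
        continue
    m = re.search(r"([^.]+)\." + f"{name.name}$", name.full_name)
    if not m:
        continue
    scope = m.group(1)
    indices = [i for i, x in enumerate(test_functions_raw) if x == scope]
    if not indices:
        continue
    try:
        definition = name.goto(follow_imports=True, follow_builtin_imports=False)
    except Exception:
        continue
    if not definition or definition[0].type != "function":
        continue
    for index in indices:
        # ...build the result entry exactly as above, reusing `definition`
        pass
```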

def process_test_files_single_threaded(
file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
) -> dict[str, list[FunctionCalledInTest]]:
project_root_path = cfg.project_root_path
@@ -314,3 +499,125 @@ def process_test_files(
for function, tests in function_to_test_map.items():
deduped_function_to_test_map[function] = list(set(tests))
return deduped_function_to_test_map


def process_test_files(
file_to_test_map: dict[str, list[TestsInFile]], cfg: TestConfig
) -> dict[str, list[FunctionCalledInTest]]:
from multiprocessing import Pool, cpu_count
import os
import pickle

project_root_path = cfg.project_root_path
test_framework = cfg.test_framework

logger.info(f"Starting to process {len(file_to_test_map)} test files with multiprocessing")

# Create a configuration dictionary to pass to worker processes
config_dict = {
'project_root_path': str(project_root_path),
'test_framework': test_framework
}

# Prepare data for processing - create a list of (test_file, functions, config) tuples
process_inputs = []
for test_file, functions in file_to_test_map.items():
# Convert TestsInFile objects to serializable form if needed
serializable_functions = []
for func in functions:
# Ensure test_file is a string (needed for pickling)
if hasattr(func, 'test_file') and not isinstance(func.test_file, str):
func_dict = func._asdict() if hasattr(func, '_asdict') else func.__dict__.copy()
func_dict['test_file'] = str(func_dict['test_file'])
serializable_functions.append(TestsInFile(**func_dict))
else:
serializable_functions.append(func)
process_inputs.append((str(test_file), serializable_functions, config_dict))

# Determine optimal number of processes
max_processes = min(cpu_count() * 2, len(process_inputs), 32)
Contributor: why cpu_count * 2? You don't want to create more processes than the number of CPUs.

Contributor (author): I was under the impression this is largely a disk-I/O-bound task, so multiple processes per core would help.
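For reference, a hedged sketch of the sizing the reviewer implies; whether this workload is CPU-bound (jedi parsing, regex matching) or disk-bound is an assumption either way, and neither variant is part of the PR.

```python
# Sketch: cap processes at the CPU count if the work is CPU-bound, since extra
# processes beyond that mostly add scheduling and pickling overhead.
import os

max_processes = min(os.cpu_count() or 1, len(process_inputs), 32)

# If the work really is dominated by disk I/O, threads sidestep process start-up
# and pickling costs entirely (hypothetical alternative, not in the PR):
# from concurrent.futures import ThreadPoolExecutor
# with ThreadPoolExecutor(max_workers=(os.cpu_count() or 1) * 2) as executor:
#     results = list(executor.map(process_file_worker, process_inputs))
```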

logger.info(f"Using {max_processes} processes for parallel test file processing")

# Create a Pool and process the files
processed_files = 0
error_count = 0
function_to_test_map = defaultdict(list)

# Use smaller chunk size for better load balancing
chunk_size = max(1, len(process_inputs) // (max_processes * 4))

with Pool(processes=max_processes) as pool:
# Use imap_unordered for better performance (we don't care about order)
for i, result in enumerate(pool.imap_unordered(process_file_worker, process_inputs, chunk_size)):
processed_files += 1

# Log progress
if processed_files % 100 == 0 or processed_files == len(process_inputs):
Contributor: maybe we can do something with a progress bar.

logger.info(f"Processed {processed_files}/{len(process_inputs)} files")

if result['status'] == 'error':
error_count += 1
logger.warning(f"Error processing file {result['test_file']}: {result['error_message']}")
if 'traceback' in result:
logger.debug(f"Traceback: {result['traceback']}")
continue

# Process results from this file
for qualified_name, test_entries in result['results'].items():
for entry in test_entries:
# Reconstruct FunctionCalledInTest from the serialized data
test_in_file = TestsInFile(
test_file=entry['test_file'],
test_class=entry['test_class'],
test_function=entry['test_function'],
test_type=entry['test_type']
)

position = CodePosition(line_no=entry['line_no'], col_no=entry['col_no'])

function_to_test_map[qualified_name].append(
Contributor: in the initial code, function_to_test_map is a dictionary of lists, but I don't see why we can't make it a dictionary of sets so the entries are immediately deduped.

FunctionCalledInTest(
tests_in_file=test_in_file,
position=position
)
)

logger.info(f"Processing complete. Processed {processed_files}/{len(process_inputs)} files")
logger.info(f"Files with errors: {error_count}")
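Picking up the progress-bar suggestion in the review above, a minimal sketch (assuming a tqdm dependency is acceptable, which the PR does not add) would wrap the imap_unordered iterator:

```python
# Sketch of the progress-bar idea (assumes tqdm is available).
from multiprocessing import Pool
from tqdm import tqdm

with Pool(processes=max_processes) as pool:
    iterator = pool.imap_unordered(process_file_worker, process_inputs, chunk_size)
    for result in tqdm(iterator, total=len(process_inputs), desc="Discovering tests"):
        ...  # handle `result` exactly as in the loop above
```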

# Log metrics before deduplication
total_tests_before_dedup = sum(len(tests) for tests in function_to_test_map.values())
logger.info(
f"Found {len(function_to_test_map)} unique functions with {total_tests_before_dedup} total tests before deduplication")

# Deduplicate results
deduped_function_to_test_map = {}
for function, tests in function_to_test_map.items():
# Convert to set and back to list to remove duplicates
# We need to handle custom objects properly
unique_tests = []
seen = set()

for test in tests:
# Create a hashable representation of the test
test_hash = (
Contributor: unnecessary, these should be hashable already; the pydantic models are frozen.

str(test.tests_in_file.test_file),
test.tests_in_file.test_class,
test.tests_in_file.test_function,
test.tests_in_file.test_type,
test.position.line_no,
test.position.col_no
)

if test_hash not in seen:
seen.add(test_hash)
unique_tests.append(test)

deduped_function_to_test_map[function] = unique_tests

# Log metrics after deduplication
total_tests_after_dedup = sum(len(tests) for tests in deduped_function_to_test_map.values())
logger.info(
f"After deduplication: {len(deduped_function_to_test_map)} unique functions with {total_tests_after_dedup} total tests")

return deduped_function_to_test_map
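If, as the reviewers note above, TestsInFile, CodePosition, and FunctionCalledInTest are frozen (hashable) pydantic models, both the accumulation and the deduplication collapse to built-in set handling. A sketch, assuming that hashability holds; not part of the PR:

```python
# Sketch assuming FunctionCalledInTest and its nested models are frozen/hashable.
from collections import defaultdict

function_to_test_map: defaultdict[str, set] = defaultdict(set)

# ...inside the result loop, add instead of append:
# function_to_test_map[qualified_name].add(
#     FunctionCalledInTest(tests_in_file=test_in_file, position=position)
# )

# Deduplication then needs no manual hashing:
deduped_function_to_test_map = {
    function: list(tests) for function, tests in function_to_test_map.items()
}
```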
4 changes: 2 additions & 2 deletions codeflash/version.py
@@ -1,3 +1,3 @@
# These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`.
__version__ = "0.10.3"
__version_tuple__ = (0, 10, 3)
__version__ = "0.10.3.post7.dev0+86aa1cd6"
Contributor: don't change this version.

__version_tuple__ = (0, 10, 3, "post7", "dev0", "86aa1cd6")