diff --git a/codeflash/tracer.py b/codeflash/tracer.py
index 46c73f819..cb5f7f58a 100644
--- a/codeflash/tracer.py
+++ b/codeflash/tracer.py
@@ -14,6 +14,8 @@
 import json
 import pickle
 import subprocess
+import time
+
 import sys
 from argparse import ArgumentParser
 from pathlib import Path
@@ -24,6 +26,7 @@
 from codeflash.code_utils.code_utils import get_run_tmp_file
 from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
 from codeflash.code_utils.config_parser import parse_config_file
+from codeflash.tracing.pytest_parallelization import pytest_split
 
 if TYPE_CHECKING:
     from argparse import Namespace
@@ -86,51 +89,97 @@ def main(args: Namespace | None = None) -> ArgumentParser:
     config, found_config_path = parse_config_file(parsed_args.codeflash_config)
     project_root = project_root_from_module_root(Path(config["module_root"]), found_config_path)
     if len(unknown_args) > 0:
+        args_dict = {
+            "functions": parsed_args.only_functions,
+            "disable": False,
+            "project_root": str(project_root),
+            "max_function_count": parsed_args.max_function_count,
+            "timeout": parsed_args.tracer_timeout,
+            "progname": unknown_args[0],
+            "config": config,
+            "module": parsed_args.module,
+        }
         try:
-            result_pickle_file_path = get_run_tmp_file("tracer_results_file.pkl")
-            args_dict = {
-                "result_pickle_file_path": str(result_pickle_file_path),
-                "output": str(parsed_args.outfile),
-                "functions": parsed_args.only_functions,
-                "disable": False,
-                "project_root": str(project_root),
-                "max_function_count": parsed_args.max_function_count,
-                "timeout": parsed_args.tracer_timeout,
-                "command": " ".join(sys.argv),
-                "progname": unknown_args[0],
-                "config": config,
-                "module": parsed_args.module,
-            }
-
-            subprocess.run(
-                [
-                    SAFE_SYS_EXECUTABLE,
-                    Path(__file__).parent / "tracing" / "tracing_new_process.py",
-                    *sys.argv,
-                    json.dumps(args_dict),
-                ],
-                cwd=Path.cwd(),
-                check=False,
-            )
-            try:
-                with result_pickle_file_path.open(mode="rb") as f:
-                    data = pickle.load(f)
-            except Exception:
-                console.print("❌ Failed to trace. Exiting...")
-                sys.exit(1)
-            finally:
-                result_pickle_file_path.unlink(missing_ok=True)
-
-            replay_test_path = data["replay_test_file_path"]
-            if not parsed_args.trace_only and replay_test_path is not None:
+            pytest_splits = []
+            test_paths = []
+            replay_test_paths = []
+            if parsed_args.module and unknown_args[0] == "pytest":
+                pytest_splits, test_paths = pytest_split(unknown_args[1:])
+
+            if len(pytest_splits) > 1:
+                processes = []
+                test_paths_set = set(test_paths)
+                result_pickle_file_paths = []
+                for i, test_split in enumerate(pytest_splits, start=1):
+                    result_pickle_file_path = get_run_tmp_file(Path(f"tracer_results_file_{i}.pkl"))
+                    result_pickle_file_paths.append(result_pickle_file_path)
+                    args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
+                    outpath = parsed_args.outfile
+                    outpath = outpath.parent / f"{outpath.stem}_{i}{outpath.suffix}"
+                    args_dict["output"] = str(outpath)
+                    updated_sys_argv = []
+                    for elem in sys.argv:
+                        if elem in test_paths_set:
+                            updated_sys_argv.extend(test_split)
+                        else:
+                            updated_sys_argv.append(elem)
+                    args_dict["command"] = " ".join(updated_sys_argv)
+                    processes.append(
+                        subprocess.Popen(
+                            [
+                                SAFE_SYS_EXECUTABLE,
+                                Path(__file__).parent / "tracing" / "tracing_new_process.py",
+                                *updated_sys_argv,
+                                json.dumps(args_dict),
+                            ],
+                            cwd=Path.cwd(),
+                        )
+                    )
+                for process in processes:
+                    process.wait()
+                for result_pickle_file_path in result_pickle_file_paths:
+                    try:
+                        with result_pickle_file_path.open(mode="rb") as f:
+                            data = pickle.load(f)
+                        replay_test_paths.append(str(data["replay_test_file_path"]))
+                    except Exception:
+                        console.print("❌ Failed to trace. Exiting...")
+                        sys.exit(1)
+                    finally:
+                        result_pickle_file_path.unlink(missing_ok=True)
+            else:
+                result_pickle_file_path = get_run_tmp_file(Path("tracer_results_file.pkl"))
+                args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
+                args_dict["output"] = str(parsed_args.outfile)
+                args_dict["command"] = " ".join(sys.argv)
+
+                subprocess.run(
+                    [
+                        SAFE_SYS_EXECUTABLE,
+                        Path(__file__).parent / "tracing" / "tracing_new_process.py",
+                        *sys.argv,
+                        json.dumps(args_dict),
+                    ],
+                    cwd=Path.cwd(),
+                    check=False,
+                )
+                try:
+                    with result_pickle_file_path.open(mode="rb") as f:
+                        data = pickle.load(f)
+                    replay_test_paths.append(str(data["replay_test_file_path"]))
+                except Exception:
+                    console.print("❌ Failed to trace. Exiting...")
+                    sys.exit(1)
+                finally:
+                    result_pickle_file_path.unlink(missing_ok=True)
+            if not parsed_args.trace_only and replay_test_paths:
                 from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
                 from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
                 from codeflash.cli_cmds.console import paneled_text
                 from codeflash.telemetry import posthog_cf
                 from codeflash.telemetry.sentry import init_sentry
 
-                sys.argv = ["codeflash", "--replay-test", str(replay_test_path)]
-
+                sys.argv = ["codeflash", "--replay-test", *replay_test_paths]
                 args = parse_args()
                 paneled_text(
                     CODEFLASH_LOGO,
@@ -150,8 +199,8 @@ def main(args: Namespace | None = None) -> ArgumentParser:
                 # Delete the trace file and the replay test file if they exist
                 if outfile:
                     outfile.unlink(missing_ok=True)
-                if replay_test_path:
-                    replay_test_path.unlink(missing_ok=True)
+                for replay_test_path in replay_test_paths:
+                    Path(replay_test_path).unlink(missing_ok=True)
     except BrokenPipeError as exc:
         # Prevent "Exception ignored" during interpreter shutdown.
diff --git a/codeflash/tracing/pytest_parallelization.py b/codeflash/tracing/pytest_parallelization.py
new file mode 100644
index 000000000..b28187174
--- /dev/null
+++ b/codeflash/tracing/pytest_parallelization.py
@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+import os
+from math import ceil
+from pathlib import Path
+from random import shuffle
+
+
+def pytest_split(
+    arguments: list[str], num_splits: int | None = None
+) -> tuple[list[list[str]], list[str]]:
+    """Split pytest test files into roughly equal groups for parallel execution.
+
+    Args:
+        arguments: The arguments passed to pytest (everything after the ``pytest`` program name).
+        num_splits: Number of groups to split the tests into. If None, uses the CPU count.
+
+    Returns:
+        A tuple ``(groups, test_paths)``: ``groups`` is a list of lists, where each inner list
+        contains the test file paths for one group, and ``test_paths`` is the list of positional
+        ``file_or_dir`` arguments parsed from ``arguments``. Both are empty lists when splitting
+        is not possible (pytest is unavailable, no test paths were given, or a path does not exist).
+
+    """
+    try:
+        import pytest
+
+        parser = pytest.Parser()
+
+        pytest_args = parser.parse_known_args(arguments)
+        test_paths = getattr(pytest_args, "file_or_dir", None)
+        if not test_paths:
+            return [], []
+
+    except ImportError:
+        return [], []
+    test_files = set()
+
+    # Collect test files recursively from every directory argument
+    for test_path in test_paths:
+        _test_path = Path(test_path)
+        if not _test_path.exists():
+            return [], []
+        if _test_path.is_dir():
+            # Find all test files matching the patterns test_*.py and *_test.py
+            test_files.update(map(str, _test_path.rglob("test_*.py")))
+            test_files.update(map(str, _test_path.rglob("*_test.py")))
+        elif _test_path.is_file():
+            test_files.add(str(_test_path))
+
+    if not test_files:
+        return [], test_paths
+
+    # Determine the number of splits
+    if num_splits is None:
+        num_splits = os.cpu_count() or 4
+
+    # Randomize to increase the chance that the splits are balanced
+    test_files = list(test_files)
+    shuffle(test_files)
+
+    # Ensure each split has at least 4 test files;
+    # if there are fewer test files than 4 * num_splits, reduce num_splits
+    max_possible_splits = len(test_files) // 4
+    if max_possible_splits == 0:
+        # Too few files to split: return a single group so callers always get list[list[str]]
+        return [test_files], test_paths
+
+    num_splits = min(num_splits, max_possible_splits)
+
+    # Calculate the chunk size (round up so that all files are included)
+    total_files = len(test_files)
+    chunk_size = ceil(total_files / num_splits)
+
+    # Initialize the result groups
+    result_groups: list[list[str]] = [[] for _ in range(num_splits)]
+
+    # Distribute the files across the groups
+    for i, test_file in enumerate(test_files):
+        group_index = i // chunk_size
+        # Make sure we do not exceed the number of groups (edge-case handling)
+        if group_index >= num_splits:
+            group_index = num_splits - 1
+        result_groups[group_index].append(test_file)
+
+    return result_groups, test_paths
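For reference, a minimal usage sketch of the new helper. The `tests/` layout, the `-q` flag, and the split count are hypothetical assumptions, not part of this change:

```python
from codeflash.tracing.pytest_parallelization import pytest_split

# Assume a hypothetical tests/ directory containing sixteen test_*.py files.
splits, test_paths = pytest_split(["tests/", "-q"], num_splits=4)

# test_paths echoes pytest's positional file_or_dir arguments: ["tests/"]
# splits holds up to four groups of roughly four files each; tracer.py hands
# each group to its own tracing_new_process.py subprocess.
for group in splits:
    print(len(group), group[:2])
```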