From 936fa0a9602493a2a055d8535909844818f97bbb Mon Sep 17 00:00:00 2001
From: Saurabh Misra
Date: Wed, 27 Aug 2025 17:53:28 -0700
Subject: [PATCH 1/4] wip

Signed-off-by: Saurabh Misra
---
 codeflash/tracer.py                         | 130 ++++++++++++++------
 codeflash/tracing/pytest_parallelization.py |  81 ++++++++++++
 2 files changed, 172 insertions(+), 39 deletions(-)
 create mode 100644 codeflash/tracing/pytest_parallelization.py

diff --git a/codeflash/tracer.py b/codeflash/tracer.py
index 46c73f819..a755aa06d 100644
--- a/codeflash/tracer.py
+++ b/codeflash/tracer.py
@@ -24,6 +24,7 @@
 from codeflash.code_utils.code_utils import get_run_tmp_file
 from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
 from codeflash.code_utils.config_parser import parse_config_file
+from codeflash.tracing.pytest_parallelization import pytest_split
 
 if TYPE_CHECKING:
     from argparse import Namespace
@@ -86,51 +87,102 @@ def main(args: Namespace | None = None) -> ArgumentParser:
         config, found_config_path = parse_config_file(parsed_args.codeflash_config)
         project_root = project_root_from_module_root(Path(config["module_root"]), found_config_path)
         if len(unknown_args) > 0:
+            args_dict = {
+                "functions": parsed_args.only_functions,
+                "disable": False,
+                "project_root": str(project_root),
+                "max_function_count": parsed_args.max_function_count,
+                "timeout": parsed_args.tracer_timeout,
+                "progname": unknown_args[0],
+                "config": config,
+                "module": parsed_args.module,
+            }
             try:
-                result_pickle_file_path = get_run_tmp_file("tracer_results_file.pkl")
-                args_dict = {
-                    "result_pickle_file_path": str(result_pickle_file_path),
-                    "output": str(parsed_args.outfile),
-                    "functions": parsed_args.only_functions,
-                    "disable": False,
-                    "project_root": str(project_root),
-                    "max_function_count": parsed_args.max_function_count,
-                    "timeout": parsed_args.tracer_timeout,
-                    "command": " ".join(sys.argv),
-                    "progname": unknown_args[0],
-                    "config": config,
-                    "module": parsed_args.module,
-                }
-
-                subprocess.run(
-                    [
-                        SAFE_SYS_EXECUTABLE,
-                        Path(__file__).parent / "tracing" / "tracing_new_process.py",
-                        *sys.argv,
-                        json.dumps(args_dict),
-                    ],
-                    cwd=Path.cwd(),
-                    check=False,
-                )
-                try:
-                    with result_pickle_file_path.open(mode="rb") as f:
-                        data = pickle.load(f)
-                except Exception:
-                    console.print("❌ Failed to trace. Exiting...")
-                    sys.exit(1)
-                finally:
-                    result_pickle_file_path.unlink(missing_ok=True)
-
-                replay_test_path = data["replay_test_file_path"]
-                if not parsed_args.trace_only and replay_test_path is not None:
+                pytest_splits = []
+                test_paths = []
+                replay_test_paths = []
+                if parsed_args.module and unknown_args[0] == "pytest":
+                    pytest_splits, test_paths = pytest_split(unknown_args[1:])
+                    print(pytest_splits)
+
+                if len(pytest_splits) > 1:
+                    processes = []
+                    test_paths_set = set(test_paths)
+                    result_pickle_file_paths = []
+                    for i, test_split in enumerate(pytest_splits, start=1):
+                        result_pickle_file_path = get_run_tmp_file(f"tracer_results_file_{i}.pkl")
+                        result_pickle_file_paths.append(result_pickle_file_path)
+                        args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
+                        outpath = parsed_args.outfile
+                        outpath = outpath.parent / f"{outpath.stem}_{i}{outpath.suffix}"
+                        args_dict["output"] = str(outpath)
+                        added_paths = False
+                        updated_sys_argv = []
+                        for elem in sys.argv:
+                            if elem in test_paths_set:
+                                if not added_paths:
+                                    updated_sys_argv.extend(test_split)
+                            else:
+                                updated_sys_argv.append(elem)
+                        args_dict["command"] = " ".join(updated_sys_argv)
+                        processes.append(
+                            subprocess.Popen(
+                                [
+                                    SAFE_SYS_EXECUTABLE,
+                                    Path(__file__).parent / "tracing" / "tracing_new_process.py",
+                                    *updated_sys_argv,
+                                    json.dumps(args_dict),
+                                ],
+                                cwd=Path.cwd(),
+                            )
+                        )
+                    for process in processes:
+                        process.wait()
+                    for result_pickle_file_path in result_pickle_file_paths:
+                        try:
+                            with result_pickle_file_path.open(mode="rb") as f:
+                                data = pickle.load(f)
+                            replay_test_paths.append(str(data["replay_test_file_path"]))
+                        except Exception:
+                            console.print("❌ Failed to trace. Exiting...")
+                            sys.exit(1)
+                        finally:
+                            result_pickle_file_path.unlink(missing_ok=True)
+                else:
+                    result_pickle_file_path = get_run_tmp_file("tracer_results_file.pkl")
+                    args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
+                    args_dict["output"] = str(parsed_args.outfile)
+                    args_dict["command"] = " ".join(sys.argv)
+
+                    subprocess.run(
+                        [
+                            SAFE_SYS_EXECUTABLE,
+                            Path(__file__).parent / "tracing" / "tracing_new_process.py",
+                            *sys.argv,
+                            json.dumps(args_dict),
+                        ],
+                        cwd=Path.cwd(),
+                        check=False,
+                    )
+                    try:
+                        with result_pickle_file_path.open(mode="rb") as f:
+                            data = pickle.load(f)
+                        replay_test_paths.append(str(data["replay_test_file_path"]))
+                    except Exception:
+                        console.print("❌ Failed to trace. Exiting...")
+                        sys.exit(1)
+                    finally:
+                        result_pickle_file_path.unlink(missing_ok=True)
+
+                if not parsed_args.trace_only and replay_test_paths:
                     from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
                     from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
                     from codeflash.cli_cmds.console import paneled_text
                     from codeflash.telemetry import posthog_cf
                     from codeflash.telemetry.sentry import init_sentry
 
-                    sys.argv = ["codeflash", "--replay-test", str(replay_test_path)]
-
+                    sys.argv = ["codeflash", "--replay-test", *replay_test_paths]
+                    print(sys.argv)
                     args = parse_args()
                     paneled_text(
                         CODEFLASH_LOGO,
@@ -150,7 +202,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
             # Delete the trace file and the replay test file if they exist
             if outfile:
                 outfile.unlink(missing_ok=True)
-            if replay_test_path:
+            for replay_test_path in replay_test_paths:
                 replay_test_path.unlink(missing_ok=True)
 
     except BrokenPipeError as exc:
diff --git a/codeflash/tracing/pytest_parallelization.py b/codeflash/tracing/pytest_parallelization.py
new file mode 100644
index 000000000..38bf04aa2
--- /dev/null
+++ b/codeflash/tracing/pytest_parallelization.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import os
+from math import ceil
+from pathlib import Path
+
+
+def pytest_split(
+    arguments: list[str], num_splits: int | None = None
+) -> tuple[list[list[str]] | None, list[str] | None]:
+    """Split pytest test files from a directory into N roughly equal groups for parallel execution.
+
+    Args:
+        test_directory: Path to directory containing test files
+        num_splits: Number of groups to split tests into. If None, uses CPU count.
+
+    Returns:
+        List of lists, where each inner list contains test file paths for one group.
+        Returns single list with all tests if number of test files < CPU cores.
+
+    """
+    try:
+        import pytest
+
+        parser = pytest.Parser()
+
+        pytest_args = parser.parse_known_args(arguments)
+        test_paths = getattr(pytest_args, "file_or_dir", None)
+        if not test_paths:
+            return None, None
+
+    except ImportError:
+        return None, None
+    test_files = []
+
+    # Find all test_*.py files recursively in the directory
+    for test_path in test_paths:
+        _test_path = Path(test_path)
+        if not _test_path.exists():
+            return None, None
+        if _test_path.is_dir():
+            # Find all test files matching the pattern test_*.py
+            for test_file in _test_path.rglob("test_*.py"):
+                test_files.append(str(test_file))
+        elif _test_path.is_file():
+            test_files.append(str(_test_path))
+
+    # Sort files for consistent ordering
+    test_files.sort()
+
+    if not test_files:
+        return [[]], None
+
+    # Determine number of splits
+    if num_splits is None:
+        num_splits = os.cpu_count() or 4
+
+    # Ensure each split has at least 4 test files
+    # If we have fewer test files than 4 * num_splits, reduce num_splits
+    max_possible_splits = len(test_files) // 4
+    if max_possible_splits == 0:
+        return [test_files], test_paths
+
+    num_splits = min(num_splits, max_possible_splits)
+
+    # Calculate chunk size (round up to ensure all files are included)
+    total_files = len(test_files)
+    chunk_size = ceil(total_files / num_splits)
+
+    # Initialize result groups
+    result_groups = [[] for _ in range(num_splits)]
+
+    # Distribute files across groups
+    for i, test_file in enumerate(test_files):
+        group_index = i // chunk_size
+        # Ensure we don't exceed the number of groups (edge case handling)
+        if group_index >= num_splits:
+            group_index = num_splits - 1
+        result_groups[group_index].append(test_file)
+
+    return result_groups, test_paths

From 78519eeb3183945cfd7064f04419a74a88d2cdf1 Mon Sep 17 00:00:00 2001
From: aseembits93
Date: Fri, 29 Aug 2025 14:20:22 -0700
Subject: [PATCH 2/4] bugfix

---
 codeflash/tracer.py                         | 2 +-
 codeflash/tracing/pytest_parallelization.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/codeflash/tracer.py b/codeflash/tracer.py
index a755aa06d..ed66163ed 100644
--- a/codeflash/tracer.py
+++ b/codeflash/tracer.py
@@ -203,7 +203,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
             if outfile:
                 outfile.unlink(missing_ok=True)
             for replay_test_path in replay_test_paths:
-                replay_test_path.unlink(missing_ok=True)
+                Path(replay_test_path).unlink(missing_ok=True)
 
     except BrokenPipeError as exc:
         # Prevent "Exception ignored" during interpreter shutdown.
diff --git a/codeflash/tracing/pytest_parallelization.py b/codeflash/tracing/pytest_parallelization.py
index 38bf04aa2..b88255e0a 100644
--- a/codeflash/tracing/pytest_parallelization.py
+++ b/codeflash/tracing/pytest_parallelization.py
@@ -11,6 +11,7 @@ def pytest_split(
     """Split pytest test files from a directory into N roughly equal groups for parallel execution.
 
     Args:
+        arguments: List of arguments passed to pytest
         test_directory: Path to directory containing test files
         num_splits: Number of groups to split tests into. If None, uses CPU count.
 
@@ -40,8 +41,7 @@ def pytest_split(
             return None, None
         if _test_path.is_dir():
             # Find all test files matching the pattern test_*.py
-            for test_file in _test_path.rglob("test_*.py"):
-                test_files.append(str(test_file))
+            test_files.extend(map(str, _test_path.rglob("test_*.py")))
         elif _test_path.is_file():
             test_files.append(str(_test_path))
 

From d4788b928fd82f43e14d4dfbb387678d6611628e Mon Sep 17 00:00:00 2001
From: aseembits93
Date: Tue, 2 Sep 2025 16:24:32 -0700
Subject: [PATCH 3/4] debug measure time

---
 codeflash/tracer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/codeflash/tracer.py b/codeflash/tracer.py
index ed66163ed..8c32f9521 100644
--- a/codeflash/tracer.py
+++ b/codeflash/tracer.py
@@ -14,6 +14,8 @@
 import json
 import pickle
 import subprocess
+import time
+
 import sys
 from argparse import ArgumentParser
 from pathlib import Path
@@ -31,6 +33,7 @@
 
 
 def main(args: Namespace | None = None) -> ArgumentParser:
+    start = time.time()
     parser = ArgumentParser(allow_abbrev=False)
     parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", default="codeflash.trace")
     parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None)
@@ -173,7 +176,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
                         sys.exit(1)
                     finally:
                         result_pickle_file_path.unlink(missing_ok=True)
-
+                print(f"Took {time.time() - start}")
                 if not parsed_args.trace_only and replay_test_paths:
                     from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
                     from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO

From 0dc325af2b8244c553498ffb3789f5cddd353aa8 Mon Sep 17 00:00:00 2001
From: Aseem Saxena
Date: Wed, 3 Sep 2025 19:58:55 +0000
Subject: [PATCH 4/4] sets instead of lists

---
 codeflash/tracer.py                         | 12 +++---------
 codeflash/tracing/pytest_parallelization.py | 18 ++++++++++--------
 2 files changed, 13 insertions(+), 17 deletions(-)

diff --git a/codeflash/tracer.py b/codeflash/tracer.py
index 8c32f9521..cb5f7f58a 100644
--- a/codeflash/tracer.py
+++ b/codeflash/tracer.py
@@ -33,7 +33,6 @@
 
 
 def main(args: Namespace | None = None) -> ArgumentParser:
-    start = time.time()
     parser = ArgumentParser(allow_abbrev=False)
     parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", default="codeflash.trace")
     parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None)
@@ -106,25 +105,22 @@ def main(args: Namespace | None = None) -> ArgumentParser:
                 replay_test_paths = []
                 if parsed_args.module and unknown_args[0] == "pytest":
                     pytest_splits, test_paths = pytest_split(unknown_args[1:])
-                    print(pytest_splits)
 
                 if len(pytest_splits) > 1:
                     processes = []
                     test_paths_set = set(test_paths)
                     result_pickle_file_paths = []
                     for i, test_split in enumerate(pytest_splits, start=1):
-                        result_pickle_file_path = get_run_tmp_file(f"tracer_results_file_{i}.pkl")
+                        result_pickle_file_path = get_run_tmp_file(Path(f"tracer_results_file_{i}.pkl"))
                         result_pickle_file_paths.append(result_pickle_file_path)
                         args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
                         outpath = parsed_args.outfile
                         outpath = outpath.parent / f"{outpath.stem}_{i}{outpath.suffix}"
                         args_dict["output"] = str(outpath)
-                        added_paths = False
                         updated_sys_argv = []
                         for elem in sys.argv:
                             if elem in test_paths_set:
-                                if not added_paths:
-                                    updated_sys_argv.extend(test_split)
+                                updated_sys_argv.extend(test_split)
                             else:
                                 updated_sys_argv.append(elem)
                         args_dict["command"] = " ".join(updated_sys_argv)
@@ -152,7 +148,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
                         finally:
                             result_pickle_file_path.unlink(missing_ok=True)
                 else:
-                    result_pickle_file_path = get_run_tmp_file("tracer_results_file.pkl")
+                    result_pickle_file_path = get_run_tmp_file(Path("tracer_results_file.pkl"))
                     args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
                     args_dict["output"] = str(parsed_args.outfile)
                     args_dict["command"] = " ".join(sys.argv)
@@ -176,7 +172,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
                         sys.exit(1)
                     finally:
                         result_pickle_file_path.unlink(missing_ok=True)
-                print(f"Took {time.time() - start}")
                 if not parsed_args.trace_only and replay_test_paths:
                     from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
                     from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
@@ -185,7 +180,6 @@ def main(args: Namespace | None = None) -> ArgumentParser:
                     from codeflash.telemetry.sentry import init_sentry
 
                     sys.argv = ["codeflash", "--replay-test", *replay_test_paths]
-                    print(sys.argv)
                     args = parse_args()
                     paneled_text(
                         CODEFLASH_LOGO,
diff --git a/codeflash/tracing/pytest_parallelization.py b/codeflash/tracing/pytest_parallelization.py
index b88255e0a..b28187174 100644
--- a/codeflash/tracing/pytest_parallelization.py
+++ b/codeflash/tracing/pytest_parallelization.py
@@ -3,7 +3,7 @@
 import os
 from math import ceil
 from pathlib import Path
-
+from random import shuffle
 
 def pytest_split(
     arguments: list[str], num_splits: int | None = None
@@ -32,7 +32,7 @@ def pytest_split(
 
     except ImportError:
         return None, None
-    test_files = []
+    test_files = set()
 
     # Find all test_*.py files recursively in the directory
     for test_path in test_paths:
@@ -41,12 +41,10 @@ def pytest_split(
             return None, None
         if _test_path.is_dir():
             # Find all test files matching the pattern test_*.py
-            test_files.extend(map(str, _test_path.rglob("test_*.py")))
+            test_files.update(map(str, _test_path.rglob("test_*.py")))
+            test_files.update(map(str, _test_path.rglob("*_test.py")))
         elif _test_path.is_file():
-            test_files.append(str(_test_path))
-
-    # Sort files for consistent ordering
-    test_files.sort()
+            test_files.add(str(_test_path))
 
     if not test_files:
         return [[]], None
@@ -55,11 +53,15 @@ def pytest_split(
     if num_splits is None:
        num_splits = os.cpu_count() or 4
 
+    #randomize to increase chances of all splits being balanced
+    test_files = list(test_files)
+    shuffle(test_files)
+
     # Ensure each split has at least 4 test files
     # If we have fewer test files than 4 * num_splits, reduce num_splits
     max_possible_splits = len(test_files) // 4
     if max_possible_splits == 0:
-        return [test_files], test_paths
+        return test_files, test_paths
 
     num_splits = min(num_splits, max_possible_splits)
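
Note (reviewer sketch, not part of the patches above): pytest_split assigns file i to group i // chunk_size, with chunk_size = ceil(total_files / num_splits), so the groups come out roughly equal and only the last one can run short. A minimal standalone illustration of that arithmetic, using a hypothetical chunk_files helper and made-up file names:

    from math import ceil

    def chunk_files(test_files: list[str], num_splits: int) -> list[list[str]]:
        # Round the chunk size up so every file is assigned to some group.
        chunk_size = ceil(len(test_files) / num_splits)
        groups: list[list[str]] = [[] for _ in range(num_splits)]
        for i, test_file in enumerate(test_files):
            # Clamp the index so the final short chunk cannot overflow the group list.
            groups[min(i // chunk_size, num_splits - 1)].append(test_file)
        return groups

    # 10 files over 3 splits -> group sizes 4, 4, 2
    print([len(g) for g in chunk_files([f"test_{n}.py" for n in range(10)], 3)])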
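
A similar sketch of the fan-out pattern the tracer.py changes rely on: one subprocess.Popen per split, wait on all of them, then read one result file per split. The worker command below is a stand-in (python -c) and the file names are made up; the real code launches tracing_new_process.py with the split's pytest paths substituted into sys.argv:

    import subprocess
    import sys
    from pathlib import Path

    splits = [["tests/test_a.py"], ["tests/test_b.py"]]  # hypothetical splits
    processes = []
    for i, split in enumerate(splits, start=1):
        result_file = Path(f"tracer_results_file_{i}.pkl")  # one result file per split
        # Stand-in worker command, not the real tracer invocation.
        cmd = [sys.executable, "-c", "import sys; print('traced', sys.argv[1:])", *split, str(result_file)]
        processes.append(subprocess.Popen(cmd, cwd=Path.cwd()))
    for p in processes:
        p.wait()  # read the per-split results only after every subprocess has exited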