129 changes: 89 additions & 40 deletions codeflash/tracer.py
@@ -14,6 +14,8 @@
import json
import pickle
import subprocess
import time

import sys
from argparse import ArgumentParser
from pathlib import Path
@@ -24,6 +26,7 @@
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.tracing.pytest_parallelization import pytest_split

if TYPE_CHECKING:
from argparse import Namespace
@@ -86,51 +89,97 @@ def main(args: Namespace | None = None) -> ArgumentParser:
config, found_config_path = parse_config_file(parsed_args.codeflash_config)
project_root = project_root_from_module_root(Path(config["module_root"]), found_config_path)
if len(unknown_args) > 0:
args_dict = {
"functions": parsed_args.only_functions,
"disable": False,
"project_root": str(project_root),
"max_function_count": parsed_args.max_function_count,
"timeout": parsed_args.tracer_timeout,
"progname": unknown_args[0],
"config": config,
"module": parsed_args.module,
}
try:
result_pickle_file_path = get_run_tmp_file("tracer_results_file.pkl")
args_dict = {
"result_pickle_file_path": str(result_pickle_file_path),
"output": str(parsed_args.outfile),
"functions": parsed_args.only_functions,
"disable": False,
"project_root": str(project_root),
"max_function_count": parsed_args.max_function_count,
"timeout": parsed_args.tracer_timeout,
"command": " ".join(sys.argv),
"progname": unknown_args[0],
"config": config,
"module": parsed_args.module,
}

subprocess.run(
[
SAFE_SYS_EXECUTABLE,
Path(__file__).parent / "tracing" / "tracing_new_process.py",
*sys.argv,
json.dumps(args_dict),
],
cwd=Path.cwd(),
check=False,
)
try:
with result_pickle_file_path.open(mode="rb") as f:
data = pickle.load(f)
except Exception:
console.print("❌ Failed to trace. Exiting...")
sys.exit(1)
finally:
result_pickle_file_path.unlink(missing_ok=True)

replay_test_path = data["replay_test_file_path"]
if not parsed_args.trace_only and replay_test_path is not None:
pytest_splits = []
test_paths = []
replay_test_paths = []
if parsed_args.module and unknown_args[0] == "pytest":
pytest_splits, test_paths = pytest_split(unknown_args[1:])

if len(pytest_splits) > 1:
processes = []
test_paths_set = set(test_paths)
result_pickle_file_paths = []
for i, test_split in enumerate(pytest_splits, start=1):
result_pickle_file_path = get_run_tmp_file(Path(f"tracer_results_file_{i}.pkl"))
result_pickle_file_paths.append(result_pickle_file_path)
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
outpath = parsed_args.outfile
outpath = outpath.parent / f"{outpath.stem}_{i}{outpath.suffix}"
args_dict["output"] = str(outpath)
updated_sys_argv = []
for elem in sys.argv:
if elem in test_paths_set:
updated_sys_argv.extend(test_split)
else:
updated_sys_argv.append(elem)
args_dict["command"] = " ".join(updated_sys_argv)
processes.append(
subprocess.Popen(
[
SAFE_SYS_EXECUTABLE,
Path(__file__).parent / "tracing" / "tracing_new_process.py",
*updated_sys_argv,
json.dumps(args_dict),
],
cwd=Path.cwd(),
)
)
for process in processes:
process.wait()
for result_pickle_file_path in result_pickle_file_paths:
try:
with result_pickle_file_path.open(mode="rb") as f:
data = pickle.load(f)
replay_test_paths.append(str(data["replay_test_file_path"]))
except Exception:
console.print("❌ Failed to trace. Exiting...")
sys.exit(1)
finally:
result_pickle_file_path.unlink(missing_ok=True)
else:
result_pickle_file_path = get_run_tmp_file(Path("tracer_results_file.pkl"))
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
args_dict["output"] = str(parsed_args.outfile)
args_dict["command"] = " ".join(sys.argv)

subprocess.run(
[
SAFE_SYS_EXECUTABLE,
Path(__file__).parent / "tracing" / "tracing_new_process.py",
*sys.argv,
json.dumps(args_dict),
],
cwd=Path.cwd(),
check=False,
)
try:
with result_pickle_file_path.open(mode="rb") as f:
data = pickle.load(f)
replay_test_paths.append(str(data["replay_test_file_path"]))
except Exception:
console.print("❌ Failed to trace. Exiting...")
sys.exit(1)
finally:
result_pickle_file_path.unlink(missing_ok=True)
if not parsed_args.trace_only and replay_test_paths:
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
from codeflash.cli_cmds.console import paneled_text
from codeflash.telemetry import posthog_cf
from codeflash.telemetry.sentry import init_sentry

sys.argv = ["codeflash", "--replay-test", str(replay_test_path)]

sys.argv = ["codeflash", "--replay-test", *replay_test_paths]
args = parse_args()
paneled_text(
CODEFLASH_LOGO,
@@ -150,8 +199,8 @@ def main(args: Namespace | None = None) -> ArgumentParser:
# Delete the trace file and the replay test file if they exist
if outfile:
outfile.unlink(missing_ok=True)
if replay_test_path:
replay_test_path.unlink(missing_ok=True)
for replay_test_path in replay_test_paths:
Path(replay_test_path).unlink(missing_ok=True)

except BrokenPipeError as exc:
# Prevent "Exception ignored" during interpreter shutdown.
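The per-split loop above rebuilds the traced command by swapping the original pytest path arguments for each split's files before launching a tracing subprocess per split. A minimal illustrative sketch of that substitution, with a hypothetical argv, path set, and splits (not the actual CLI invocation):

```python
# Illustrative sketch only: mirrors the argv-rewriting loop in tracer.py.
# The command, test paths, and splits below are hypothetical.
original_argv = ["codeflash", "trace", "-m", "pytest", "tests/"]
test_paths_set = {"tests/"}  # positional paths returned by pytest_split
pytest_splits = [
    ["tests/test_a.py", "tests/test_b.py"],
    ["tests/test_c.py", "tests/test_d.py"],
]

for i, test_split in enumerate(pytest_splits, start=1):
    updated_argv = []
    for elem in original_argv:
        if elem in test_paths_set:
            updated_argv.extend(test_split)  # replace the path argument with this split's files
        else:
            updated_argv.append(elem)
    print(f"split {i}: {' '.join(updated_argv)}")
```

Each rewritten argv would then be handed to its own tracing_new_process.py subprocess, with a numbered trace output file and result pickle per split.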
83 changes: 83 additions & 0 deletions codeflash/tracing/pytest_parallelization.py
@@ -0,0 +1,83 @@
from __future__ import annotations

import os
from math import ceil
from pathlib import Path
from random import shuffle

def pytest_split(
arguments: list[str], num_splits: int | None = None
) -> tuple[list[list[str]] | None, list[str] | None]:
"""Split pytest test files from a directory into N roughly equal groups for parallel execution.

Args:
arguments: List of arguments passed to pytest
num_splits: Number of groups to split tests into. If None, uses CPU count.

Returns:
A tuple of (groups, test_paths): groups is a list of lists, where each inner list contains the test file paths for one group.
Returns (None, None) if pytest is unavailable or no test paths were found, and a single flat list of all test files if there are too few files to split.

"""
try:
import pytest

parser = pytest.Parser()

pytest_args = parser.parse_known_args(arguments)
test_paths = getattr(pytest_args, "file_or_dir", None)
if not test_paths:
return None, None

except ImportError:
return None, None
test_files = set()

# Find all test_*.py files recursively in the directory
for test_path in test_paths:
_test_path = Path(test_path)
if not _test_path.exists():
return None, None
if _test_path.is_dir():
# Find all test files matching the pattern test_*.py
test_files.update(map(str, _test_path.rglob("test_*.py")))
test_files.update(map(str, _test_path.rglob("*_test.py")))
elif _test_path.is_file():
test_files.add(str(_test_path))

if not test_files:
return [[]], None

# Determine number of splits
if num_splits is None:
num_splits = os.cpu_count() or 4

# Randomize to increase the chances of all splits being balanced
test_files = list(test_files)
shuffle(test_files)

# Ensure each split has at least 4 test files
# If we have fewer test files than 4 * num_splits, reduce num_splits
max_possible_splits = len(test_files) // 4
if max_possible_splits == 0:
return test_files, test_paths

num_splits = min(num_splits, max_possible_splits)

# Calculate chunk size (round up to ensure all files are included)
total_files = len(test_files)
chunk_size = ceil(total_files / num_splits)

# Initialize result groups
result_groups = [[] for _ in range(num_splits)]

# Distribute files across groups
for i, test_file in enumerate(test_files):
group_index = i // chunk_size
# Ensure we don't exceed the number of groups (edge case handling)
if group_index >= num_splits:
group_index = num_splits - 1
result_groups[group_index].append(test_file)

return result_groups, test_paths
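A hedged usage sketch of the function above, roughly as tracer.py calls it (the argument list, split count, and resulting groups are illustrative; real groups depend on the shuffle and on `os.cpu_count()`):

```python
from codeflash.tracing.pytest_parallelization import pytest_split

# Hypothetical pytest arguments, i.e. everything after the "pytest" program name.
splits, test_paths = pytest_split(["tests/", "-q"], num_splits=2)

if splits is None:
    # pytest is unavailable or no test paths were found: trace in a single process.
    print("no split possible")
else:
    for group in splits:
        # Each group would be traced by its own tracing_new_process.py subprocess.
        print(group)
```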