Skip to content

Commit 936fa0a

Browse files
committed
wip
Signed-off-by: Saurabh Misra <[email protected]>
1 parent 217ced2 commit 936fa0a

File tree

2 files changed

+172
-39
lines changed

2 files changed

+172
-39
lines changed

codeflash/tracer.py

Lines changed: 91 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from codeflash.code_utils.code_utils import get_run_tmp_file
2525
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
2626
from codeflash.code_utils.config_parser import parse_config_file
27+
from codeflash.tracing.pytest_parallelization import pytest_split
2728

2829
if TYPE_CHECKING:
2930
from argparse import Namespace
@@ -86,51 +87,102 @@ def main(args: Namespace | None = None) -> ArgumentParser:
8687
config, found_config_path = parse_config_file(parsed_args.codeflash_config)
8788
project_root = project_root_from_module_root(Path(config["module_root"]), found_config_path)
8889
if len(unknown_args) > 0:
90+
args_dict = {
91+
"functions": parsed_args.only_functions,
92+
"disable": False,
93+
"project_root": str(project_root),
94+
"max_function_count": parsed_args.max_function_count,
95+
"timeout": parsed_args.tracer_timeout,
96+
"progname": unknown_args[0],
97+
"config": config,
98+
"module": parsed_args.module,
99+
}
89100
try:
90-
result_pickle_file_path = get_run_tmp_file("tracer_results_file.pkl")
91-
args_dict = {
92-
"result_pickle_file_path": str(result_pickle_file_path),
93-
"output": str(parsed_args.outfile),
94-
"functions": parsed_args.only_functions,
95-
"disable": False,
96-
"project_root": str(project_root),
97-
"max_function_count": parsed_args.max_function_count,
98-
"timeout": parsed_args.tracer_timeout,
99-
"command": " ".join(sys.argv),
100-
"progname": unknown_args[0],
101-
"config": config,
102-
"module": parsed_args.module,
103-
}
104-
105-
subprocess.run(
106-
[
107-
SAFE_SYS_EXECUTABLE,
108-
Path(__file__).parent / "tracing" / "tracing_new_process.py",
109-
*sys.argv,
110-
json.dumps(args_dict),
111-
],
112-
cwd=Path.cwd(),
113-
check=False,
114-
)
115-
try:
116-
with result_pickle_file_path.open(mode="rb") as f:
117-
data = pickle.load(f)
118-
except Exception:
119-
console.print("❌ Failed to trace. Exiting...")
120-
sys.exit(1)
121-
finally:
122-
result_pickle_file_path.unlink(missing_ok=True)
123-
124-
replay_test_path = data["replay_test_file_path"]
125-
if not parsed_args.trace_only and replay_test_path is not None:
101+
pytest_splits = []
102+
test_paths = []
103+
replay_test_paths = []
104+
if parsed_args.module and unknown_args[0] == "pytest":
105+
pytest_splits, test_paths = pytest_split(unknown_args[1:])
106+
print(pytest_splits)
107+
108+
if len(pytest_splits) > 1:
109+
processes = []
110+
test_paths_set = set(test_paths)
111+
result_pickle_file_paths = []
112+
for i, test_split in enumerate(pytest_splits, start=1):
113+
result_pickle_file_path = get_run_tmp_file(f"tracer_results_file_{i}.pkl")
114+
result_pickle_file_paths.append(result_pickle_file_path)
115+
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
116+
outpath = parsed_args.outfile
117+
outpath = outpath.parent / f"{outpath.stem}_{i}{outpath.suffix}"
118+
args_dict["output"] = str(outpath)
119+
added_paths = False
120+
updated_sys_argv = []
121+
for elem in sys.argv:
122+
if elem in test_paths_set:
123+
if not added_paths:
124+
updated_sys_argv.extend(test_split)
125+
else:
126+
updated_sys_argv.append(elem)
127+
args_dict["command"] = " ".join(updated_sys_argv)
128+
processes.append(
129+
subprocess.Popen(
130+
[
131+
SAFE_SYS_EXECUTABLE,
132+
Path(__file__).parent / "tracing" / "tracing_new_process.py",
133+
*updated_sys_argv,
134+
json.dumps(args_dict),
135+
],
136+
cwd=Path.cwd(),
137+
)
138+
)
139+
for process in processes:
140+
process.wait()
141+
for result_pickle_file_path in result_pickle_file_paths:
142+
try:
143+
with result_pickle_file_path.open(mode="rb") as f:
144+
data = pickle.load(f)
145+
replay_test_paths.append(str(data["replay_test_file_path"]))
146+
except Exception:
147+
console.print("❌ Failed to trace. Exiting...")
148+
sys.exit(1)
149+
finally:
150+
result_pickle_file_path.unlink(missing_ok=True)
151+
else:
152+
result_pickle_file_path = get_run_tmp_file("tracer_results_file.pkl")
153+
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
154+
args_dict["output"] = str(parsed_args.outfile)
155+
args_dict["command"] = " ".join(sys.argv)
156+
157+
subprocess.run(
158+
[
159+
SAFE_SYS_EXECUTABLE,
160+
Path(__file__).parent / "tracing" / "tracing_new_process.py",
161+
*sys.argv,
162+
json.dumps(args_dict),
163+
],
164+
cwd=Path.cwd(),
165+
check=False,
166+
)
167+
try:
168+
with result_pickle_file_path.open(mode="rb") as f:
169+
data = pickle.load(f)
170+
replay_test_paths.append(str(data["replay_test_file_path"]))
171+
except Exception:
172+
console.print("❌ Failed to trace. Exiting...")
173+
sys.exit(1)
174+
finally:
175+
result_pickle_file_path.unlink(missing_ok=True)
176+
177+
if not parsed_args.trace_only and replay_test_paths:
126178
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
127179
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
128180
from codeflash.cli_cmds.console import paneled_text
129181
from codeflash.telemetry import posthog_cf
130182
from codeflash.telemetry.sentry import init_sentry
131183

132-
sys.argv = ["codeflash", "--replay-test", str(replay_test_path)]
133-
184+
sys.argv = ["codeflash", "--replay-test", *replay_test_paths]
185+
print(sys.argv)
134186
args = parse_args()
135187
paneled_text(
136188
CODEFLASH_LOGO,
@@ -150,7 +202,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
150202
# Delete the trace file and the replay test file if they exist
151203
if outfile:
152204
outfile.unlink(missing_ok=True)
153-
if replay_test_path:
205+
for replay_test_path in replay_test_paths:
154206
replay_test_path.unlink(missing_ok=True)
155207

156208
except BrokenPipeError as exc:
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from math import ceil
5+
from pathlib import Path
6+
7+
8+
def pytest_split(
9+
arguments: list[str], num_splits: int | None = None
10+
) -> tuple[list[list[str]] | None, list[str] | None]:
11+
"""Split pytest test files from a directory into N roughly equal groups for parallel execution.
12+
13+
Args:
14+
test_directory: Path to directory containing test files
15+
num_splits: Number of groups to split tests into. If None, uses CPU count.
16+
17+
Returns:
18+
List of lists, where each inner list contains test file paths for one group.
19+
Returns single list with all tests if number of test files < CPU cores.
20+
21+
"""
22+
try:
23+
import pytest
24+
25+
parser = pytest.Parser()
26+
27+
pytest_args = parser.parse_known_args(arguments)
28+
test_paths = getattr(pytest_args, "file_or_dir", None)
29+
if not test_paths:
30+
return None, None
31+
32+
except ImportError:
33+
return None, None
34+
test_files = []
35+
36+
# Find all test_*.py files recursively in the directory
37+
for test_path in test_paths:
38+
_test_path = Path(test_path)
39+
if not _test_path.exists():
40+
return None, None
41+
if _test_path.is_dir():
42+
# Find all test files matching the pattern test_*.py
43+
for test_file in _test_path.rglob("test_*.py"):
44+
test_files.append(str(test_file))
45+
elif _test_path.is_file():
46+
test_files.append(str(_test_path))
47+
48+
# Sort files for consistent ordering
49+
test_files.sort()
50+
51+
if not test_files:
52+
return [[]], None
53+
54+
# Determine number of splits
55+
if num_splits is None:
56+
num_splits = os.cpu_count() or 4
57+
58+
# Ensure each split has at least 4 test files
59+
# If we have fewer test files than 4 * num_splits, reduce num_splits
60+
max_possible_splits = len(test_files) // 4
61+
if max_possible_splits == 0:
62+
return [test_files], test_paths
63+
64+
num_splits = min(num_splits, max_possible_splits)
65+
66+
# Calculate chunk size (round up to ensure all files are included)
67+
total_files = len(test_files)
68+
chunk_size = ceil(total_files / num_splits)
69+
70+
# Initialize result groups
71+
result_groups = [[] for _ in range(num_splits)]
72+
73+
# Distribute files across groups
74+
for i, test_file in enumerate(test_files):
75+
group_index = i // chunk_size
76+
# Ensure we don't exceed the number of groups (edge case handling)
77+
if group_index >= num_splits:
78+
group_index = num_splits - 1
79+
result_groups[group_index].append(test_file)
80+
81+
return result_groups, test_paths

0 commit comments

Comments
 (0)