Skip to content

Commit 5938547

Browse files
authored
Merge branch 'main' into mihika_pr_631
2 parents 69a46d7 + 3dcf7a3 commit 5938547

File tree

12 files changed

+1250
-839
lines changed

12 files changed

+1250
-839
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -335,12 +335,12 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
335335
return updated_node
336336

337337

338-
def extract_global_statements(source_code: str) -> list[cst.SimpleStatementLine]:
338+
def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]:
339339
"""Extract global statements from source code."""
340340
module = cst.parse_module(source_code)
341341
collector = GlobalStatementCollector()
342342
module.visit(collector)
343-
return collector.global_statements
343+
return module, collector.global_statements
344344

345345

346346
def find_last_import_line(target_code: str) -> int:
@@ -373,30 +373,41 @@ def delete___future___aliased_imports(module_code: str) -> str:
373373

374374

375375
def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
376-
non_assignment_global_statements = extract_global_statements(src_module_code)
376+
src_module, new_added_global_statements = extract_global_statements(src_module_code)
377+
dst_module, existing_global_statements = extract_global_statements(dst_module_code)
377378

378-
# Find the last import line in target
379-
last_import_line = find_last_import_line(dst_module_code)
380-
381-
# Parse the target code
382-
target_module = cst.parse_module(dst_module_code)
383-
384-
# Create transformer to insert non_assignment_global_statements
385-
transformer = ImportInserter(non_assignment_global_statements, last_import_line)
386-
#
387-
# # Apply transformation
388-
modified_module = target_module.visit(transformer)
389-
dst_module_code = modified_module.code
390-
391-
# Parse the code
392-
original_module = cst.parse_module(dst_module_code)
393-
new_module = cst.parse_module(src_module_code)
379+
unique_global_statements = []
380+
for stmt in new_added_global_statements:
381+
if any(
382+
stmt is existing_stmt or stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements
383+
):
384+
continue
385+
unique_global_statements.append(stmt)
386+
387+
mod_dst_code = dst_module_code
388+
# Insert unique global statements if any
389+
if unique_global_statements:
390+
last_import_line = find_last_import_line(dst_module_code)
391+
# Reuse already-parsed dst_module
392+
transformer = ImportInserter(unique_global_statements, last_import_line)
393+
# Use visit inplace, don't parse again
394+
modified_module = dst_module.visit(transformer)
395+
mod_dst_code = modified_module.code
396+
# Parse the code after insertion
397+
original_module = cst.parse_module(mod_dst_code)
398+
else:
399+
# No new statements to insert, reuse already-parsed dst_module
400+
original_module = dst_module
394401

402+
# Parse the src_module_code once only (already done above: src_module)
395403
# Collect assignments from the new file
396404
new_collector = GlobalAssignmentCollector()
397-
new_module.visit(new_collector)
405+
src_module.visit(new_collector)
406+
# Only create transformer if there are assignments to insert/transform
407+
if not new_collector.assignments: # nothing to transform
408+
return mod_dst_code
398409

399-
# Transform the original file
410+
# Transform the original destination module
400411
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
401412
transformed_module = original_module.visit(transformer)
402413

codeflash/code_utils/code_replacer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,11 +412,17 @@ def replace_function_definitions_in_module(
412412
module_abspath: Path,
413413
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
414414
project_root_path: Path,
415+
should_add_global_assignments: bool = True, # noqa: FBT001, FBT002
415416
) -> bool:
416417
source_code: str = module_abspath.read_text(encoding="utf8")
417418
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
419+
418420
new_code: str = replace_functions_and_add_imports(
419-
add_global_assignments(code_to_apply, source_code),
421+
# adding the global assignments before replacing the code, not after
422+
# because of an "edge case" where the optimized code introduced a new import and a global assignment using that import
423+
# and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet)
424+
# this was added at https://github.com/codeflash-ai/codeflash/pull/448
425+
add_global_assignments(code_to_apply, source_code) if should_add_global_assignments else source_code,
420426
function_names,
421427
code_to_apply,
422428
module_abspath,

codeflash/code_utils/shell_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
1616
SHELL_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
1717
else:
18-
SHELL_RC_EXPORT_PATTERN = re.compile(r'^(?!#)export CODEFLASH_API_KEY=[\'"]?(cf-[^\s"]+)[\'"]$', re.MULTILINE)
18+
SHELL_RC_EXPORT_PATTERN = re.compile(
19+
r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE
20+
)
1921
SHELL_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
2022

2123

codeflash/context/unused_definition_remover.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def revert_unused_helper_functions(
537537
module_abspath=file_path,
538538
preexisting_objects=set(), # Empty set since we're reverting
539539
project_root_path=project_root,
540+
should_add_global_assignments=False, # since we revert helper functions after applying the optimization, we know that the file already has global assignments added; otherwise they would be added twice.
540541
)
541542

542543
if reverted_code:

codeflash/optimization/function_optimizer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
)
5454
from codeflash.code_utils.env_utils import get_pr_number
5555
from codeflash.code_utils.formatter import format_code, sort_imports
56+
from codeflash.code_utils.git_utils import git_root_dir
5657
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
5758
from codeflash.code_utils.line_profile_utils import add_decorator_imports
5859
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
@@ -820,7 +821,10 @@ def reformat_code_and_helpers(
820821
return new_code, new_helper_code
821822

822823
def replace_function_and_helpers_with_optimized_code(
823-
self, code_context: CodeOptimizationContext, optimized_code: CodeStringsMarkdown, original_helper_code: str
824+
self,
825+
code_context: CodeOptimizationContext,
826+
optimized_code: CodeStringsMarkdown,
827+
original_helper_code: dict[Path, str],
824828
) -> bool:
825829
did_update = False
826830
read_writable_functions_by_file_path = defaultdict(set)
@@ -1298,11 +1302,13 @@ def process_review(
12981302
"coverage_message": coverage_message,
12991303
"replay_tests": replay_tests,
13001304
"concolic_tests": concolic_tests,
1301-
"root_dir": self.project_root,
13021305
}
13031306

13041307
raise_pr = not self.args.no_pr
13051308

1309+
if raise_pr or self.args.staging_review:
1310+
data["root_dir"] = git_root_dir()
1311+
13061312
if raise_pr and not self.args.staging_review:
13071313
data["git_remote"] = self.args.git_remote
13081314
check_create_pr(**data)

codeflash/tracer.py

Lines changed: 87 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from codeflash.code_utils.code_utils import get_run_tmp_file
2525
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
2626
from codeflash.code_utils.config_parser import parse_config_file
27+
from codeflash.tracing.pytest_parallelization import pytest_split
2728

2829
if TYPE_CHECKING:
2930
from argparse import Namespace
@@ -86,51 +87,97 @@ def main(args: Namespace | None = None) -> ArgumentParser:
8687
config, found_config_path = parse_config_file(parsed_args.codeflash_config)
8788
project_root = project_root_from_module_root(Path(config["module_root"]), found_config_path)
8889
if len(unknown_args) > 0:
90+
args_dict = {
91+
"functions": parsed_args.only_functions,
92+
"disable": False,
93+
"project_root": str(project_root),
94+
"max_function_count": parsed_args.max_function_count,
95+
"timeout": parsed_args.tracer_timeout,
96+
"progname": unknown_args[0],
97+
"config": config,
98+
"module": parsed_args.module,
99+
}
89100
try:
90-
result_pickle_file_path = get_run_tmp_file("tracer_results_file.pkl")
91-
args_dict = {
92-
"result_pickle_file_path": str(result_pickle_file_path),
93-
"output": str(parsed_args.outfile),
94-
"functions": parsed_args.only_functions,
95-
"disable": False,
96-
"project_root": str(project_root),
97-
"max_function_count": parsed_args.max_function_count,
98-
"timeout": parsed_args.tracer_timeout,
99-
"command": " ".join(sys.argv),
100-
"progname": unknown_args[0],
101-
"config": config,
102-
"module": parsed_args.module,
103-
}
104-
105-
subprocess.run(
106-
[
107-
SAFE_SYS_EXECUTABLE,
108-
Path(__file__).parent / "tracing" / "tracing_new_process.py",
109-
*sys.argv,
110-
json.dumps(args_dict),
111-
],
112-
cwd=Path.cwd(),
113-
check=False,
114-
)
115-
try:
116-
with result_pickle_file_path.open(mode="rb") as f:
117-
data = pickle.load(f)
118-
except Exception:
119-
console.print("❌ Failed to trace. Exiting...")
120-
sys.exit(1)
121-
finally:
122-
result_pickle_file_path.unlink(missing_ok=True)
123-
124-
replay_test_path = data["replay_test_file_path"]
125-
if not parsed_args.trace_only and replay_test_path is not None:
101+
pytest_splits = []
102+
test_paths = []
103+
replay_test_paths = []
104+
if parsed_args.module and unknown_args[0] == "pytest":
105+
pytest_splits, test_paths = pytest_split(unknown_args[1:])
106+
107+
if len(pytest_splits) > 1:
108+
processes = []
109+
test_paths_set = set(test_paths)
110+
result_pickle_file_paths = []
111+
for i, test_split in enumerate(pytest_splits, start=1):
112+
result_pickle_file_path = get_run_tmp_file(Path(f"tracer_results_file_{i}.pkl"))
113+
result_pickle_file_paths.append(result_pickle_file_path)
114+
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
115+
outpath = parsed_args.outfile
116+
outpath = outpath.parent / f"{outpath.stem}_{i}{outpath.suffix}"
117+
args_dict["output"] = str(outpath)
118+
updated_sys_argv = []
119+
for elem in sys.argv:
120+
if elem in test_paths_set:
121+
updated_sys_argv.extend(test_split)
122+
else:
123+
updated_sys_argv.append(elem)
124+
args_dict["command"] = " ".join(updated_sys_argv)
125+
processes.append(
126+
subprocess.Popen(
127+
[
128+
SAFE_SYS_EXECUTABLE,
129+
Path(__file__).parent / "tracing" / "tracing_new_process.py",
130+
*updated_sys_argv,
131+
json.dumps(args_dict),
132+
],
133+
cwd=Path.cwd(),
134+
)
135+
)
136+
for process in processes:
137+
process.wait()
138+
for result_pickle_file_path in result_pickle_file_paths:
139+
try:
140+
with result_pickle_file_path.open(mode="rb") as f:
141+
data = pickle.load(f)
142+
replay_test_paths.append(str(data["replay_test_file_path"]))
143+
except Exception:
144+
console.print("❌ Failed to trace. Exiting...")
145+
sys.exit(1)
146+
finally:
147+
result_pickle_file_path.unlink(missing_ok=True)
148+
else:
149+
result_pickle_file_path = get_run_tmp_file(Path("tracer_results_file.pkl"))
150+
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
151+
args_dict["output"] = str(parsed_args.outfile)
152+
args_dict["command"] = " ".join(sys.argv)
153+
154+
subprocess.run(
155+
[
156+
SAFE_SYS_EXECUTABLE,
157+
Path(__file__).parent / "tracing" / "tracing_new_process.py",
158+
*sys.argv,
159+
json.dumps(args_dict),
160+
],
161+
cwd=Path.cwd(),
162+
check=False,
163+
)
164+
try:
165+
with result_pickle_file_path.open(mode="rb") as f:
166+
data = pickle.load(f)
167+
replay_test_paths.append(str(data["replay_test_file_path"]))
168+
except Exception:
169+
console.print("❌ Failed to trace. Exiting...")
170+
sys.exit(1)
171+
finally:
172+
result_pickle_file_path.unlink(missing_ok=True)
173+
if not parsed_args.trace_only and replay_test_paths:
126174
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
127175
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
128176
from codeflash.cli_cmds.console import paneled_text
129177
from codeflash.telemetry import posthog_cf
130178
from codeflash.telemetry.sentry import init_sentry
131179

132-
sys.argv = ["codeflash", "--replay-test", str(replay_test_path)]
133-
180+
sys.argv = ["codeflash", "--replay-test", *replay_test_paths]
134181
args = parse_args()
135182
paneled_text(
136183
CODEFLASH_LOGO,
@@ -150,8 +197,8 @@ def main(args: Namespace | None = None) -> ArgumentParser:
150197
# Delete the trace file and the replay test file if they exist
151198
if outfile:
152199
outfile.unlink(missing_ok=True)
153-
if replay_test_path:
154-
replay_test_path.unlink(missing_ok=True)
200+
for replay_test_path in replay_test_paths:
201+
Path(replay_test_path).unlink(missing_ok=True)
155202

156203
except BrokenPipeError as exc:
157204
# Prevent "Exception ignored" during interpreter shutdown.
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from math import ceil
5+
from pathlib import Path
6+
from random import shuffle
7+
8+
9+
def pytest_split(
10+
arguments: list[str], num_splits: int | None = None
11+
) -> tuple[list[list[str]] | None, list[str] | None]:
12+
"""Split pytest test files from a directory into N roughly equal groups for parallel execution.
13+
14+
Args:
15+
arguments: List of arguments passed to pytest
16+
test_directory: Path to directory containing test files
17+
num_splits: Number of groups to split tests into. If None, uses CPU count.
18+
19+
Returns:
20+
List of lists, where each inner list contains test file paths for one group.
21+
Returns single list with all tests if number of test files < CPU cores.
22+
23+
"""
24+
try:
25+
import pytest
26+
27+
parser = pytest.Parser()
28+
29+
pytest_args = parser.parse_known_args(arguments)
30+
test_paths = getattr(pytest_args, "file_or_dir", None)
31+
if not test_paths:
32+
return None, None
33+
34+
except ImportError:
35+
return None, None
36+
test_files = set()
37+
38+
# Find all test_*.py files recursively in the directory
39+
for test_path in test_paths:
40+
_test_path = Path(test_path)
41+
if not _test_path.exists():
42+
return None, None
43+
if _test_path.is_dir():
44+
# Find all test files matching the pattern test_*.py
45+
test_files.update(map(str, _test_path.rglob("test_*.py")))
46+
test_files.update(map(str, _test_path.rglob("*_test.py")))
47+
elif _test_path.is_file():
48+
test_files.add(str(_test_path))
49+
50+
if not test_files:
51+
return [[]], None
52+
53+
# Determine number of splits
54+
if num_splits is None:
55+
num_splits = os.cpu_count() or 4
56+
57+
# randomize to increase chances of all splits being balanced
58+
test_files = list(test_files)
59+
shuffle(test_files)
60+
61+
# Ensure each split has at least 4 test files
62+
# If we have fewer test files than 4 * num_splits, reduce num_splits
63+
max_possible_splits = len(test_files) // 4
64+
if max_possible_splits == 0:
65+
return test_files, test_paths
66+
67+
num_splits = min(num_splits, max_possible_splits)
68+
69+
# Calculate chunk size (round up to ensure all files are included)
70+
total_files = len(test_files)
71+
chunk_size = ceil(total_files / num_splits)
72+
73+
# Initialize result groups
74+
result_groups = [[] for _ in range(num_splits)]
75+
76+
# Distribute files across groups
77+
for i, test_file in enumerate(test_files):
78+
group_index = i // chunk_size
79+
# Ensure we don't exceed the number of groups (edge case handling)
80+
if group_index >= num_splits:
81+
group_index = num_splits - 1
82+
result_groups[group_index].append(test_file)
83+
84+
return result_groups, test_paths

codeflash/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# These version placeholders will be replaced by uv-dynamic-versioning during build.
2-
__version__ = "0.16.6"
2+
__version__ = "0.16.7"

0 commit comments

Comments
 (0)