Skip to content

Commit 95114bd

Browse files
authored
Merge branch 'main' into ranker
2 parents dec3c1a + 4bcd81c commit 95114bd

File tree

13 files changed

+1534
-840
lines changed

13 files changed

+1534
-840
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -335,12 +335,12 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
335335
return updated_node
336336

337337

338-
def extract_global_statements(source_code: str) -> list[cst.SimpleStatementLine]:
338+
def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]:
339339
"""Extract global statements from source code."""
340340
module = cst.parse_module(source_code)
341341
collector = GlobalStatementCollector()
342342
module.visit(collector)
343-
return collector.global_statements
343+
return module, collector.global_statements
344344

345345

346346
def find_last_import_line(target_code: str) -> int:
@@ -373,30 +373,41 @@ def delete___future___aliased_imports(module_code: str) -> str:
373373

374374

375375
def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
376-
non_assignment_global_statements = extract_global_statements(src_module_code)
376+
src_module, new_added_global_statements = extract_global_statements(src_module_code)
377+
dst_module, existing_global_statements = extract_global_statements(dst_module_code)
377378

378-
# Find the last import line in target
379-
last_import_line = find_last_import_line(dst_module_code)
380-
381-
# Parse the target code
382-
target_module = cst.parse_module(dst_module_code)
383-
384-
# Create transformer to insert non_assignment_global_statements
385-
transformer = ImportInserter(non_assignment_global_statements, last_import_line)
386-
#
387-
# # Apply transformation
388-
modified_module = target_module.visit(transformer)
389-
dst_module_code = modified_module.code
390-
391-
# Parse the code
392-
original_module = cst.parse_module(dst_module_code)
393-
new_module = cst.parse_module(src_module_code)
379+
unique_global_statements = []
380+
for stmt in new_added_global_statements:
381+
if any(
382+
stmt is existing_stmt or stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements
383+
):
384+
continue
385+
unique_global_statements.append(stmt)
386+
387+
mod_dst_code = dst_module_code
388+
# Insert unique global statements if any
389+
if unique_global_statements:
390+
last_import_line = find_last_import_line(dst_module_code)
391+
# Reuse already-parsed dst_module
392+
transformer = ImportInserter(unique_global_statements, last_import_line)
393+
# Use visit inplace, don't parse again
394+
modified_module = dst_module.visit(transformer)
395+
mod_dst_code = modified_module.code
396+
# Parse the code after insertion
397+
original_module = cst.parse_module(mod_dst_code)
398+
else:
399+
# No new statements to insert, reuse already-parsed dst_module
400+
original_module = dst_module
394401

402+
# Parse the src_module_code once only (already done above: src_module)
395403
# Collect assignments from the new file
396404
new_collector = GlobalAssignmentCollector()
397-
new_module.visit(new_collector)
405+
src_module.visit(new_collector)
406+
# Only create transformer if there are assignments to insert/transform
407+
if not new_collector.assignments: # nothing to transform
408+
return mod_dst_code
398409

399-
# Transform the original file
410+
# Transform the original destination module
400411
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
401412
transformed_module = original_module.visit(transformer)
402413

codeflash/code_utils/code_replacer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,11 +412,17 @@ def replace_function_definitions_in_module(
412412
module_abspath: Path,
413413
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
414414
project_root_path: Path,
415+
should_add_global_assignments: bool = True, # noqa: FBT001, FBT002
415416
) -> bool:
416417
source_code: str = module_abspath.read_text(encoding="utf8")
417418
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
419+
418420
new_code: str = replace_functions_and_add_imports(
419-
add_global_assignments(code_to_apply, source_code),
421+
# adding the global assignments before replacing the code, not after
422+
# because of an "edge case" where the optimized code introduced a new import and a global assignment using that import
423+
# and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet)
424+
# this was added at https://github.com/codeflash-ai/codeflash/pull/448
425+
add_global_assignments(code_to_apply, source_code) if should_add_global_assignments else source_code,
420426
function_names,
421427
code_to_apply,
422428
module_abspath,
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Version checking utilities for codeflash."""
2+
3+
from __future__ import annotations
4+
5+
import time
6+
7+
import requests
8+
from packaging import version
9+
10+
from codeflash.cli_cmds.console import console, logger
11+
from codeflash.version import __version__
12+
13+
# Simple cache to avoid checking too frequently
14+
_version_cache = {"version": '0.0.0', "timestamp": float(0)}
15+
_cache_duration = 3600 # 1 hour cache
16+
17+
18+
def get_latest_version_from_pypi() -> str | None:
    """Look up the most recent codeflash release published on PyPI.

    The module-level ``_version_cache`` is consulted first: when its stored
    timestamp is less than ``_cache_duration`` seconds old, the cached version
    is returned without touching the network. Otherwise the PyPI JSON endpoint
    is queried (2 second timeout) and, on success, the cache is refreshed.

    Returns:
        The version string reported by PyPI (or the cached one), or None when
        the request fails, returns a non-200 status, or has an unexpected
        payload.

    """
    now = time.time()

    # Serve from the cache while the stored entry is still fresh.
    cached = _version_cache["version"]
    if cached is not None and now - _version_cache["timestamp"] < _cache_duration:
        return cached

    try:
        response = requests.get("https://pypi.org/pypi/codeflash/json", timeout=2)
        if response.status_code != 200:
            logger.debug(f"Failed to fetch version from PyPI: {response.status_code}")
            return None

        fetched = response.json()["info"]["version"]

        # Remember the answer together with the time this lookup started.
        _version_cache.update(version=fetched, timestamp=now)
        return fetched  # noqa: TRY300
    except requests.RequestException as e:
        logger.debug(f"Network error fetching version from PyPI: {e}")
        return None
    except (KeyError, ValueError) as e:
        logger.debug(f"Invalid response format from PyPI: {e}")
        return None
    except Exception as e:
        # Catch-all kept from the original; presumably so a failed version
        # lookup can never abort the calling command - confirm before narrowing.
        logger.debug(f"Unexpected error fetching version from PyPI: {e}")
        return None
52+
53+
54+
def check_for_newer_minor_version() -> None:
55+
"""Check if a newer minor version is available on PyPI and notify the user.
56+
57+
This function compares the current version with the latest version on PyPI.
58+
If a newer minor version is available, it prints an informational message
59+
suggesting the user upgrade.
60+
"""
61+
latest_version = get_latest_version_from_pypi()
62+
63+
if not latest_version:
64+
return
65+
66+
try:
67+
current_parsed = version.parse(__version__)
68+
latest_parsed = version.parse(latest_version)
69+
70+
# Check if there's a newer minor version available
71+
# We only notify for minor version updates, not patch updates
72+
if latest_parsed > current_parsed: # < > == operators can be directly applied on version objects
73+
logger.warning(
74+
f"A newer version({latest_version}) of Codeflash is available, please update soon!"
75+
)
76+
77+
except version.InvalidVersion as e:
78+
logger.debug(f"Invalid version format: {e}")
79+
return

codeflash/context/unused_definition_remover.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def revert_unused_helper_functions(
537537
module_abspath=file_path,
538538
preexisting_objects=set(), # Empty set since we're reverting
539539
project_root_path=project_root,
540+
should_add_global_assignments=False, # since we revert helpers functions after applying the optimization, we know that the file already has global assignments added, otherwise they would be added twice.
540541
)
541542

542543
if reverted_code:

codeflash/main.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from codeflash.cli_cmds.console import paneled_text
1212
from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions
1313
from codeflash.code_utils.config_parser import parse_config_file
14+
from codeflash.code_utils.version_check import check_for_newer_minor_version
1415
from codeflash.telemetry import posthog_cf
1516
from codeflash.telemetry.sentry import init_sentry
1617

@@ -21,12 +22,15 @@ def main() -> None:
2122
CODEFLASH_LOGO, panel_args={"title": "https://codeflash.ai", "expand": False}, text_args={"style": "bold gold3"}
2223
)
2324
args = parse_args()
25+
26+
# Check for newer version for all commands
27+
check_for_newer_minor_version()
28+
2429
if args.command:
30+
disable_telemetry = False
2531
if args.config_file and Path.exists(args.config_file):
2632
pyproject_config, _ = parse_config_file(args.config_file)
2733
disable_telemetry = pyproject_config.get("disable_telemetry", False)
28-
else:
29-
disable_telemetry = False
3034
init_sentry(not disable_telemetry, exclude_errors=True)
3135
posthog_cf.initialize_posthog(not disable_telemetry)
3236
args.func()

codeflash/optimization/function_optimizer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
)
5555
from codeflash.code_utils.env_utils import get_pr_number
5656
from codeflash.code_utils.formatter import format_code, sort_imports
57+
from codeflash.code_utils.git_utils import git_root_dir
5758
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
5859
from codeflash.code_utils.line_profile_utils import add_decorator_imports
5960
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
@@ -848,7 +849,10 @@ def reformat_code_and_helpers(
848849
return new_code, new_helper_code
849850

850851
def replace_function_and_helpers_with_optimized_code(
851-
self, code_context: CodeOptimizationContext, optimized_code: CodeStringsMarkdown, original_helper_code: str
852+
self,
853+
code_context: CodeOptimizationContext,
854+
optimized_code: CodeStringsMarkdown,
855+
original_helper_code: dict[Path, str],
852856
) -> bool:
853857
did_update = False
854858
read_writable_functions_by_file_path = defaultdict(set)
@@ -1326,11 +1330,13 @@ def process_review(
13261330
"coverage_message": coverage_message,
13271331
"replay_tests": replay_tests,
13281332
"concolic_tests": concolic_tests,
1329-
"root_dir": self.project_root,
13301333
}
13311334

13321335
raise_pr = not self.args.no_pr
13331336

1337+
if raise_pr or self.args.staging_review:
1338+
data["root_dir"] = git_root_dir()
1339+
13341340
if raise_pr and not self.args.staging_review:
13351341
data["git_remote"] = self.args.git_remote
13361342
check_create_pr(**data)

codeflash/tracer.py

Lines changed: 87 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from codeflash.code_utils.code_utils import get_run_tmp_file
2525
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
2626
from codeflash.code_utils.config_parser import parse_config_file
27+
from codeflash.tracing.pytest_parallelization import pytest_split
2728

2829
if TYPE_CHECKING:
2930
from argparse import Namespace
@@ -86,51 +87,97 @@ def main(args: Namespace | None = None) -> ArgumentParser:
8687
config, found_config_path = parse_config_file(parsed_args.codeflash_config)
8788
project_root = project_root_from_module_root(Path(config["module_root"]), found_config_path)
8889
if len(unknown_args) > 0:
90+
args_dict = {
91+
"functions": parsed_args.only_functions,
92+
"disable": False,
93+
"project_root": str(project_root),
94+
"max_function_count": parsed_args.max_function_count,
95+
"timeout": parsed_args.tracer_timeout,
96+
"progname": unknown_args[0],
97+
"config": config,
98+
"module": parsed_args.module,
99+
}
89100
try:
90-
result_pickle_file_path = get_run_tmp_file("tracer_results_file.pkl")
91-
args_dict = {
92-
"result_pickle_file_path": str(result_pickle_file_path),
93-
"output": str(parsed_args.outfile),
94-
"functions": parsed_args.only_functions,
95-
"disable": False,
96-
"project_root": str(project_root),
97-
"max_function_count": parsed_args.max_function_count,
98-
"timeout": parsed_args.tracer_timeout,
99-
"command": " ".join(sys.argv),
100-
"progname": unknown_args[0],
101-
"config": config,
102-
"module": parsed_args.module,
103-
}
104-
105-
subprocess.run(
106-
[
107-
SAFE_SYS_EXECUTABLE,
108-
Path(__file__).parent / "tracing" / "tracing_new_process.py",
109-
*sys.argv,
110-
json.dumps(args_dict),
111-
],
112-
cwd=Path.cwd(),
113-
check=False,
114-
)
115-
try:
116-
with result_pickle_file_path.open(mode="rb") as f:
117-
data = pickle.load(f)
118-
except Exception:
119-
console.print("❌ Failed to trace. Exiting...")
120-
sys.exit(1)
121-
finally:
122-
result_pickle_file_path.unlink(missing_ok=True)
123-
124-
replay_test_path = data["replay_test_file_path"]
125-
if not parsed_args.trace_only and replay_test_path is not None:
101+
pytest_splits = []
102+
test_paths = []
103+
replay_test_paths = []
104+
if parsed_args.module and unknown_args[0] == "pytest":
105+
pytest_splits, test_paths = pytest_split(unknown_args[1:])
106+
107+
if len(pytest_splits) > 1:
108+
processes = []
109+
test_paths_set = set(test_paths)
110+
result_pickle_file_paths = []
111+
for i, test_split in enumerate(pytest_splits, start=1):
112+
result_pickle_file_path = get_run_tmp_file(Path(f"tracer_results_file_{i}.pkl"))
113+
result_pickle_file_paths.append(result_pickle_file_path)
114+
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
115+
outpath = parsed_args.outfile
116+
outpath = outpath.parent / f"{outpath.stem}_{i}{outpath.suffix}"
117+
args_dict["output"] = str(outpath)
118+
updated_sys_argv = []
119+
for elem in sys.argv:
120+
if elem in test_paths_set:
121+
updated_sys_argv.extend(test_split)
122+
else:
123+
updated_sys_argv.append(elem)
124+
args_dict["command"] = " ".join(updated_sys_argv)
125+
processes.append(
126+
subprocess.Popen(
127+
[
128+
SAFE_SYS_EXECUTABLE,
129+
Path(__file__).parent / "tracing" / "tracing_new_process.py",
130+
*updated_sys_argv,
131+
json.dumps(args_dict),
132+
],
133+
cwd=Path.cwd(),
134+
)
135+
)
136+
for process in processes:
137+
process.wait()
138+
for result_pickle_file_path in result_pickle_file_paths:
139+
try:
140+
with result_pickle_file_path.open(mode="rb") as f:
141+
data = pickle.load(f)
142+
replay_test_paths.append(str(data["replay_test_file_path"]))
143+
except Exception:
144+
console.print("❌ Failed to trace. Exiting...")
145+
sys.exit(1)
146+
finally:
147+
result_pickle_file_path.unlink(missing_ok=True)
148+
else:
149+
result_pickle_file_path = get_run_tmp_file(Path("tracer_results_file.pkl"))
150+
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
151+
args_dict["output"] = str(parsed_args.outfile)
152+
args_dict["command"] = " ".join(sys.argv)
153+
154+
subprocess.run(
155+
[
156+
SAFE_SYS_EXECUTABLE,
157+
Path(__file__).parent / "tracing" / "tracing_new_process.py",
158+
*sys.argv,
159+
json.dumps(args_dict),
160+
],
161+
cwd=Path.cwd(),
162+
check=False,
163+
)
164+
try:
165+
with result_pickle_file_path.open(mode="rb") as f:
166+
data = pickle.load(f)
167+
replay_test_paths.append(str(data["replay_test_file_path"]))
168+
except Exception:
169+
console.print("❌ Failed to trace. Exiting...")
170+
sys.exit(1)
171+
finally:
172+
result_pickle_file_path.unlink(missing_ok=True)
173+
if not parsed_args.trace_only and replay_test_paths:
126174
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
127175
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
128176
from codeflash.cli_cmds.console import paneled_text
129177
from codeflash.telemetry import posthog_cf
130178
from codeflash.telemetry.sentry import init_sentry
131179

132-
sys.argv = ["codeflash", "--replay-test", str(replay_test_path)]
133-
180+
sys.argv = ["codeflash", "--replay-test", *replay_test_paths]
134181
args = parse_args()
135182
paneled_text(
136183
CODEFLASH_LOGO,
@@ -150,8 +197,8 @@ def main(args: Namespace | None = None) -> ArgumentParser:
150197
# Delete the trace file and the replay test file if they exist
151198
if outfile:
152199
outfile.unlink(missing_ok=True)
153-
if replay_test_path:
154-
replay_test_path.unlink(missing_ok=True)
200+
for replay_test_path in replay_test_paths:
201+
Path(replay_test_path).unlink(missing_ok=True)
155202

156203
except BrokenPipeError as exc:
157204
# Prevent "Exception ignored" during interpreter shutdown.

0 commit comments

Comments
 (0)