Skip to content

Commit 65a5b6a

Browse files
authored
Merge branch 'main' into saga4/misc_lsp_12_sept
2 parents 856fe60 + 4bcd81c commit 65a5b6a

File tree

15 files changed

+1543
-841
lines changed

15 files changed

+1543
-841
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -335,12 +335,12 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
335335
return updated_node
336336

337337

338-
def extract_global_statements(source_code: str) -> list[cst.SimpleStatementLine]:
338+
def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]:
339339
"""Extract global statements from source code."""
340340
module = cst.parse_module(source_code)
341341
collector = GlobalStatementCollector()
342342
module.visit(collector)
343-
return collector.global_statements
343+
return module, collector.global_statements
344344

345345

346346
def find_last_import_line(target_code: str) -> int:
@@ -373,30 +373,41 @@ def delete___future___aliased_imports(module_code: str) -> str:
373373

374374

375375
def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
376-
non_assignment_global_statements = extract_global_statements(src_module_code)
376+
src_module, new_added_global_statements = extract_global_statements(src_module_code)
377+
dst_module, existing_global_statements = extract_global_statements(dst_module_code)
377378

378-
# Find the last import line in target
379-
last_import_line = find_last_import_line(dst_module_code)
380-
381-
# Parse the target code
382-
target_module = cst.parse_module(dst_module_code)
383-
384-
# Create transformer to insert non_assignment_global_statements
385-
transformer = ImportInserter(non_assignment_global_statements, last_import_line)
386-
#
387-
# # Apply transformation
388-
modified_module = target_module.visit(transformer)
389-
dst_module_code = modified_module.code
390-
391-
# Parse the code
392-
original_module = cst.parse_module(dst_module_code)
393-
new_module = cst.parse_module(src_module_code)
379+
unique_global_statements = []
380+
for stmt in new_added_global_statements:
381+
if any(
382+
stmt is existing_stmt or stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements
383+
):
384+
continue
385+
unique_global_statements.append(stmt)
386+
387+
mod_dst_code = dst_module_code
388+
# Insert unique global statements if any
389+
if unique_global_statements:
390+
last_import_line = find_last_import_line(dst_module_code)
391+
# Reuse already-parsed dst_module
392+
transformer = ImportInserter(unique_global_statements, last_import_line)
393+
# Use visit inplace, don't parse again
394+
modified_module = dst_module.visit(transformer)
395+
mod_dst_code = modified_module.code
396+
# Parse the code after insertion
397+
original_module = cst.parse_module(mod_dst_code)
398+
else:
399+
# No new statements to insert, reuse already-parsed dst_module
400+
original_module = dst_module
394401

402+
# Parse the src_module_code once only (already done above: src_module)
395403
# Collect assignments from the new file
396404
new_collector = GlobalAssignmentCollector()
397-
new_module.visit(new_collector)
405+
src_module.visit(new_collector)
406+
# Only create transformer if there are assignments to insert/transform
407+
if not new_collector.assignments: # nothing to transform
408+
return mod_dst_code
398409

399-
# Transform the original file
410+
# Transform the original destination module
400411
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
401412
transformed_module = original_module.visit(transformer)
402413

codeflash/code_utils/code_replacer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,11 +412,17 @@ def replace_function_definitions_in_module(
412412
module_abspath: Path,
413413
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
414414
project_root_path: Path,
415+
should_add_global_assignments: bool = True, # noqa: FBT001, FBT002
415416
) -> bool:
416417
source_code: str = module_abspath.read_text(encoding="utf8")
417418
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
419+
418420
new_code: str = replace_functions_and_add_imports(
419-
add_global_assignments(code_to_apply, source_code),
421+
# adding the global assignments before replacing the code, not after
422+
# becuase of an "edge case" where the optimized code intoduced a new import and a global assignment using that import
423+
# and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet)
424+
# this was added at https://github.com/codeflash-ai/codeflash/pull/448
425+
add_global_assignments(code_to_apply, source_code) if should_add_global_assignments else source_code,
420426
function_names,
421427
code_to_apply,
422428
module_abspath,

codeflash/code_utils/shell_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
1616
SHELL_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
1717
else:
18-
SHELL_RC_EXPORT_PATTERN = re.compile(r'^(?!#)export CODEFLASH_API_KEY=[\'"]?(cf-[^\s"]+)[\'"]$', re.MULTILINE)
18+
SHELL_RC_EXPORT_PATTERN = re.compile(
19+
r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE
20+
)
1921
SHELL_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
2022

2123

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Version checking utilities for codeflash."""
2+
3+
from __future__ import annotations
4+
5+
import time
6+
7+
import requests
8+
from packaging import version
9+
10+
from codeflash.cli_cmds.console import console, logger
11+
from codeflash.version import __version__
12+
13+
# Simple cache to avoid checking too frequently
14+
_version_cache = {"version": '0.0.0', "timestamp": float(0)}
15+
_cache_duration = 3600 # 1 hour cache
16+
17+
18+
def get_latest_version_from_pypi() -> str | None:
19+
"""Get the latest version of codeflash from PyPI.
20+
21+
Returns:
22+
The latest version string from PyPI, or None if the request fails.
23+
24+
"""
25+
# Check cache first
26+
current_time = time.time()
27+
if _version_cache["version"] is not None and current_time - _version_cache["timestamp"] < _cache_duration:
28+
return _version_cache["version"]
29+
30+
try:
31+
response = requests.get("https://pypi.org/pypi/codeflash/json", timeout=2)
32+
if response.status_code == 200:
33+
data = response.json()
34+
latest_version = data["info"]["version"]
35+
36+
# Update cache
37+
_version_cache["version"] = latest_version
38+
_version_cache["timestamp"] = current_time
39+
40+
return latest_version
41+
logger.debug(f"Failed to fetch version from PyPI: {response.status_code}")
42+
return None # noqa: TRY300
43+
except requests.RequestException as e:
44+
logger.debug(f"Network error fetching version from PyPI: {e}")
45+
return None
46+
except (KeyError, ValueError) as e:
47+
logger.debug(f"Invalid response format from PyPI: {e}")
48+
return None
49+
except Exception as e:
50+
logger.debug(f"Unexpected error fetching version from PyPI: {e}")
51+
return None
52+
53+
54+
def check_for_newer_minor_version() -> None:
55+
"""Check if a newer minor version is available on PyPI and notify the user.
56+
57+
This function compares the current version with the latest version on PyPI.
58+
If a newer minor version is available, it prints an informational message
59+
suggesting the user upgrade.
60+
"""
61+
latest_version = get_latest_version_from_pypi()
62+
63+
if not latest_version:
64+
return
65+
66+
try:
67+
current_parsed = version.parse(__version__)
68+
latest_parsed = version.parse(latest_version)
69+
70+
# Check if there's a newer minor version available
71+
# We only notify for minor version updates, not patch updates
72+
if latest_parsed > current_parsed: # < > == operators can be directly applied on version objects
73+
logger.warning(
74+
f"A newer version({latest_version}) of Codeflash is available, please update soon!"
75+
)
76+
77+
except version.InvalidVersion as e:
78+
logger.debug(f"Invalid version format: {e}")
79+
return

codeflash/context/unused_definition_remover.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def revert_unused_helper_functions(
537537
module_abspath=file_path,
538538
preexisting_objects=set(), # Empty set since we're reverting
539539
project_root_path=project_root,
540+
should_add_global_assignments=False, # since we revert helpers functions after applying the optimization, we know that the file already has global assignments added, otherwise they would be added twice.
540541
)
541542

542543
if reverted_code:

codeflash/main.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from codeflash.cli_cmds.console import paneled_text
1212
from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions
1313
from codeflash.code_utils.config_parser import parse_config_file
14+
from codeflash.code_utils.version_check import check_for_newer_minor_version
1415
from codeflash.telemetry import posthog_cf
1516
from codeflash.telemetry.sentry import init_sentry
1617

@@ -21,12 +22,15 @@ def main() -> None:
2122
CODEFLASH_LOGO, panel_args={"title": "https://codeflash.ai", "expand": False}, text_args={"style": "bold gold3"}
2223
)
2324
args = parse_args()
25+
26+
# Check for newer version for all commands
27+
check_for_newer_minor_version()
28+
2429
if args.command:
30+
disable_telemetry = False
2531
if args.config_file and Path.exists(args.config_file):
2632
pyproject_config, _ = parse_config_file(args.config_file)
2733
disable_telemetry = pyproject_config.get("disable_telemetry", False)
28-
else:
29-
disable_telemetry = False
3034
init_sentry(not disable_telemetry, exclude_errors=True)
3135
posthog_cf.initialize_posthog(not disable_telemetry)
3236
args.func()

codeflash/optimization/function_optimizer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
)
5454
from codeflash.code_utils.env_utils import get_pr_number
5555
from codeflash.code_utils.formatter import format_code, sort_imports
56+
from codeflash.code_utils.git_utils import git_root_dir
5657
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
5758
from codeflash.code_utils.line_profile_utils import add_decorator_imports
5859
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
@@ -822,7 +823,10 @@ def reformat_code_and_helpers(
822823
return new_code, new_helper_code
823824

824825
def replace_function_and_helpers_with_optimized_code(
825-
self, code_context: CodeOptimizationContext, optimized_code: CodeStringsMarkdown, original_helper_code: str
826+
self,
827+
code_context: CodeOptimizationContext,
828+
optimized_code: CodeStringsMarkdown,
829+
original_helper_code: dict[Path, str],
826830
) -> bool:
827831
did_update = False
828832
read_writable_functions_by_file_path = defaultdict(set)
@@ -1302,11 +1306,13 @@ def process_review(
13021306
"coverage_message": coverage_message,
13031307
"replay_tests": replay_tests,
13041308
"concolic_tests": concolic_tests,
1305-
"root_dir": self.project_root,
13061309
}
13071310

13081311
raise_pr = not self.args.no_pr
13091312

1313+
if raise_pr or self.args.staging_review:
1314+
data["root_dir"] = git_root_dir()
1315+
13101316
if raise_pr and not self.args.staging_review:
13111317
data["git_remote"] = self.args.git_remote
13121318
check_create_pr(**data)

0 commit comments

Comments
 (0)