Skip to content

Commit 07e2243

Browse files
committed
Merge remote-tracking branch 'origin/main' into aseembits93/mcp
2 parents d23c739 + 674e69e commit 07e2243

32 files changed

+820
-330
lines changed

codeflash/api/aiservice.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,19 @@ def make_ai_service_request(
8181
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
8282
return response
8383

84+
def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]:
85+
candidates: list[OptimizedCandidate] = []
86+
for opt in optimizations_json:
87+
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
88+
if not code.code_strings:
89+
continue
90+
candidates.append(
91+
OptimizedCandidate(
92+
source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"]
93+
)
94+
)
95+
return candidates
96+
8497
def optimize_python_code( # noqa: D417
8598
self,
8699
source_code: str,
@@ -135,14 +148,7 @@ def optimize_python_code( # noqa: D417
135148
console.rule()
136149
end_time = time.perf_counter()
137150
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
138-
return [
139-
OptimizedCandidate(
140-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
141-
explanation=opt["explanation"],
142-
optimization_id=opt["optimization_id"],
143-
)
144-
for opt in optimizations_json
145-
]
151+
return self._get_valid_candidates(optimizations_json)
146152
try:
147153
error = response.json()["error"]
148154
except Exception:
@@ -205,14 +211,7 @@ def optimize_python_code_line_profiler( # noqa: D417
205211
optimizations_json = response.json()["optimizations"]
206212
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
207213
console.rule()
208-
return [
209-
OptimizedCandidate(
210-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
211-
explanation=opt["explanation"],
212-
optimization_id=opt["optimization_id"],
213-
)
214-
for opt in optimizations_json
215-
]
214+
return self._get_valid_candidates(optimizations_json)
216215
try:
217216
error = response.json()["error"]
218217
except Exception:
@@ -262,14 +261,17 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
262261
refined_optimizations = response.json()["refinements"]
263262
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
264263
console.rule()
264+
265+
refinements = self._get_valid_candidates(refined_optimizations)
265266
return [
266267
OptimizedCandidate(
267-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
268-
explanation=opt["explanation"],
269-
optimization_id=opt["optimization_id"][:-4] + "refi",
268+
source_code=c.source_code,
269+
explanation=c.explanation,
270+
optimization_id=c.optimization_id[:-4] + "refi",
270271
)
271-
for opt in refined_optimizations
272+
for c in refinements
272273
]
274+
273275
try:
274276
error = response.json()["error"]
275277
except Exception:

codeflash/cli_cmds/cli.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from codeflash.code_utils import env_utils
1111
from codeflash.code_utils.code_utils import exit_with_message
1212
from codeflash.code_utils.config_parser import parse_config_file
13+
from codeflash.lsp.helpers import is_LSP_enabled
1314
from codeflash.version import __version__ as version
1415

1516

@@ -214,6 +215,9 @@ def process_pyproject_config(args: Namespace) -> Namespace:
214215
if args.benchmarks_root:
215216
args.benchmarks_root = Path(args.benchmarks_root).resolve()
216217
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path)
218+
if is_LSP_enabled():
219+
args.all = None
220+
return args
217221
return handle_optimize_all_arg_parsing(args)
218222

219223

codeflash/cli_cmds/cmd_init.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,22 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
165165
run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)
166166

167167

168+
def is_valid_pyproject_toml(pyproject_toml_path: Path) -> dict[str, Any] | None:
169+
if not pyproject_toml_path.exists():
170+
return None
171+
try:
172+
config, _ = parse_config_file(pyproject_toml_path)
173+
except Exception:
174+
return None
175+
176+
if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
177+
return None
178+
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
179+
return None
180+
181+
return config
182+
183+
168184
def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
169185
"""Check if the current directory contains a valid pyproject.toml file with codeflash config.
170186
@@ -173,16 +189,9 @@ def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
173189
from rich.prompt import Confirm
174190

175191
pyproject_toml_path = Path.cwd() / "pyproject.toml"
176-
if not pyproject_toml_path.exists():
177-
return True, None
178-
try:
179-
config, config_file_path = parse_config_file(pyproject_toml_path)
180-
except Exception:
181-
return True, None
182192

183-
if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
184-
return True, None
185-
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
193+
config = is_valid_pyproject_toml(pyproject_toml_path)
194+
if config is None:
186195
return True, None
187196

188197
return Confirm.ask(
@@ -978,6 +987,11 @@ def install_github_app(git_remote: str) -> None:
978987
except git.InvalidGitRepositoryError:
979988
click.echo("Skipping GitHub app installation because you're not in a git repository.")
980989
return
990+
991+
if git_remote not in get_git_remotes(git_repo):
992+
click.echo(f"Skipping GitHub app installation, remote ({git_remote}) does not exist in this repository.")
993+
return
994+
981995
owner, repo = get_repo_owner_and_name(git_repo, git_remote)
982996

983997
if is_github_app_installed_on_repo(owner, repo, suppress_errors=True):

codeflash/code_utils/code_extractor.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import ast
5+
from itertools import chain
56
from typing import TYPE_CHECKING, Optional
67

78
import libcst as cst
@@ -119,6 +120,32 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c
119120

120121
return updated_node
121122

123+
def _find_insertion_index(self, updated_node: cst.Module) -> int:
124+
"""Find the position of the last import statement in the top-level of the module."""
125+
insert_index = 0
126+
for i, stmt in enumerate(updated_node.body):
127+
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
128+
isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
129+
)
130+
131+
is_conditional_import = isinstance(stmt, cst.If) and all(
132+
isinstance(inner, cst.SimpleStatementLine)
133+
and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
134+
for inner in stmt.body.body
135+
)
136+
137+
if is_top_level_import or is_conditional_import:
138+
insert_index = i + 1
139+
140+
# Stop scanning once we reach a class or function definition.
141+
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
142+
# Without this check, a stray import later in the file
143+
# would incorrectly shift our insertion index below actual code definitions.
144+
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
145+
break
146+
147+
return insert_index
148+
122149
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
123150
# Add any new assignments that weren't in the original file
124151
new_statements = list(updated_node.body)
@@ -131,18 +158,26 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
131158
]
132159

133160
if assignments_to_append:
134-
# Add a blank line before appending new assignments if needed
135-
if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
136-
new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
137-
new_statements.pop() # Remove the Pass statement but keep the empty line
138-
139-
# Add the new assignments
140-
new_statements.extend(
141-
[
142-
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
143-
for assignment in assignments_to_append
144-
]
145-
)
161+
# after last top-level imports
162+
insert_index = self._find_insertion_index(updated_node)
163+
164+
assignment_lines = [
165+
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
166+
for assignment in assignments_to_append
167+
]
168+
169+
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
170+
171+
# Add a blank line after the last assignment if needed
172+
after_index = insert_index + len(assignment_lines)
173+
if after_index < len(new_statements):
174+
next_stmt = new_statements[after_index]
175+
# If there's no empty line, add one
176+
has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
177+
if not has_empty:
178+
new_statements[after_index] = next_stmt.with_changes(
179+
leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
180+
)
146181

147182
return updated_node.with_changes(body=new_statements)
148183

codeflash/code_utils/git_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,8 @@ def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None:
201201

202202
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
203203
repository = git.Repo(worktree_dir, search_parent_directories=True)
204-
repository.git.commit("-am", commit_message, "--no-verify")
204+
repository.git.add(".")
205+
repository.git.commit("-m", commit_message, "--no-verify")
205206

206207

207208
def create_detached_worktree(module_root: Path) -> Optional[Path]:
@@ -218,7 +219,10 @@ def create_detached_worktree(module_root: Path) -> Optional[Path]:
218219

219220
# Get uncommitted diff from the original repo
220221
repository.git.add("-N", ".") # add the index for untracked files to be included in the diff
221-
uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)
222+
exclude_binary_files = [":!*.pyc", ":!*.pyo", ":!*.pyd", ":!*.so", ":!*.dll", ":!*.whl", ":!*.egg", ":!*.egg-info", ":!*.pyz", ":!*.pkl", ":!*.pickle", ":!*.joblib", ":!*.npy", ":!*.npz", ":!*.h5", ":!*.hdf5", ":!*.pth", ":!*.pt", ":!*.pb", ":!*.onnx", ":!*.db", ":!*.sqlite", ":!*.sqlite3", ":!*.feather", ":!*.parquet", ":!*.jpg", ":!*.jpeg", ":!*.png", ":!*.gif", ":!*.bmp", ":!*.tiff", ":!*.webp", ":!*.wav", ":!*.mp3", ":!*.ogg", ":!*.flac", ":!*.mp4", ":!*.avi", ":!*.mov", ":!*.mkv", ":!*.pdf", ":!*.doc", ":!*.docx", ":!*.xls", ":!*.xlsx", ":!*.ppt", ":!*.pptx", ":!*.zip", ":!*.rar", ":!*.tar", ":!*.tar.gz", ":!*.tgz", ":!*.bz2", ":!*.xz"] # fmt: off
223+
uni_diff_text = repository.git.diff(
224+
None, "HEAD", "--", *exclude_binary_files, ignore_blank_lines=True, ignore_space_at_eol=True
225+
)
222226

223227
if not uni_diff_text.strip():
224228
logger.info("No uncommitted changes to copy to worktree.")
@@ -234,7 +238,7 @@ def create_detached_worktree(module_root: Path) -> Optional[Path]:
234238
# Apply the patch inside the worktree
235239
try:
236240
subprocess.run(
237-
["git", "apply", "--ignore-space-change", "--ignore-whitespace", patch_path],
241+
["git", "apply", "--ignore-space-change", "--ignore-whitespace", "--whitespace=nowarn", patch_path],
238242
cwd=worktree_dir,
239243
check=True,
240244
)

codeflash/discovery/functions_to_optimize.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name
2828
from codeflash.code_utils.time_utils import humanize_runtime
2929
from codeflash.discovery.discover_unit_tests import discover_unit_tests
30+
from codeflash.lsp.helpers import is_LSP_enabled
3031
from codeflash.models.models import FunctionParent
3132
from codeflash.telemetry.posthog_cf import ph
3233

@@ -168,6 +169,7 @@ def get_functions_to_optimize(
168169
)
169170
functions: dict[str, list[FunctionToOptimize]]
170171
trace_file_path: Path | None = None
172+
is_lsp = is_LSP_enabled()
171173
with warnings.catch_warnings():
172174
warnings.simplefilter(action="ignore", category=SyntaxWarning)
173175
if optimize_all:
@@ -185,6 +187,8 @@ def get_functions_to_optimize(
185187
if only_get_this_function is not None:
186188
split_function = only_get_this_function.split(".")
187189
if len(split_function) > 2:
190+
if is_lsp:
191+
return functions, 0, None
188192
exit_with_message(
189193
"Function name should be in the format 'function_name' or 'class_name.function_name'"
190194
)
@@ -200,6 +204,8 @@ def get_functions_to_optimize(
200204
):
201205
found_function = fn
202206
if found_function is None:
207+
if is_lsp:
208+
return functions, 0, None
203209
exit_with_message(
204210
f"Function {only_function_name} not found in file {file}\nor the function does not have a 'return' statement or is a property"
205211
)
@@ -470,6 +476,10 @@ def was_function_previously_optimized(
470476
Tuple of (filtered_functions_dict, remaining_count)
471477
472478
"""
479+
if is_LSP_enabled():
480+
# was_function_previously_optimized is for the checking the optimization duplicates in the github action, no need to do this in the LSP mode
481+
return False
482+
473483
# Check optimization status if repository info is provided
474484
# already_optimized_count = 0
475485
try:

0 commit comments

Comments
 (0)