Skip to content

Commit 28f50cc

Browse files
Merge branch 'main' of github.com:codeflash-ai/codeflash into fix/duplicate-global-assignments-when-reverting-helpers
2 parents 9c8256a + 674e69e commit 28f50cc

File tree

11 files changed

+387
-102
lines changed

11 files changed

+387
-102
lines changed

codeflash/api/aiservice.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,19 @@ def make_ai_service_request(
8181
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
8282
return response
8383

84+
def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]:
85+
candidates: list[OptimizedCandidate] = []
86+
for opt in optimizations_json:
87+
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
88+
if not code.code_strings:
89+
continue
90+
candidates.append(
91+
OptimizedCandidate(
92+
source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"]
93+
)
94+
)
95+
return candidates
96+
8497
def optimize_python_code( # noqa: D417
8598
self,
8699
source_code: str,
@@ -135,14 +148,7 @@ def optimize_python_code( # noqa: D417
135148
console.rule()
136149
end_time = time.perf_counter()
137150
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
138-
return [
139-
OptimizedCandidate(
140-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
141-
explanation=opt["explanation"],
142-
optimization_id=opt["optimization_id"],
143-
)
144-
for opt in optimizations_json
145-
]
151+
return self._get_valid_candidates(optimizations_json)
146152
try:
147153
error = response.json()["error"]
148154
except Exception:
@@ -205,14 +211,7 @@ def optimize_python_code_line_profiler( # noqa: D417
205211
optimizations_json = response.json()["optimizations"]
206212
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
207213
console.rule()
208-
return [
209-
OptimizedCandidate(
210-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
211-
explanation=opt["explanation"],
212-
optimization_id=opt["optimization_id"],
213-
)
214-
for opt in optimizations_json
215-
]
214+
return self._get_valid_candidates(optimizations_json)
216215
try:
217216
error = response.json()["error"]
218217
except Exception:
@@ -262,14 +261,17 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
262261
refined_optimizations = response.json()["refinements"]
263262
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
264263
console.rule()
264+
265+
refinements = self._get_valid_candidates(refined_optimizations)
265266
return [
266267
OptimizedCandidate(
267-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
268-
explanation=opt["explanation"],
269-
optimization_id=opt["optimization_id"][:-4] + "refi",
268+
source_code=c.source_code,
269+
explanation=c.explanation,
270+
optimization_id=c.optimization_id[:-4] + "refi",
270271
)
271-
for opt in refined_optimizations
272+
for c in refinements
272273
]
274+
273275
try:
274276
error = response.json()["error"]
275277
except Exception:

codeflash/code_utils/code_extractor.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import ast
5+
from itertools import chain
56
from typing import TYPE_CHECKING, Optional
67

78
import libcst as cst
@@ -119,6 +120,32 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c
119120

120121
return updated_node
121122

123+
def _find_insertion_index(self, updated_node: cst.Module) -> int:
124+
"""Find the position of the last import statement in the top-level of the module."""
125+
insert_index = 0
126+
for i, stmt in enumerate(updated_node.body):
127+
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
128+
isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
129+
)
130+
131+
is_conditional_import = isinstance(stmt, cst.If) and all(
132+
isinstance(inner, cst.SimpleStatementLine)
133+
and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
134+
for inner in stmt.body.body
135+
)
136+
137+
if is_top_level_import or is_conditional_import:
138+
insert_index = i + 1
139+
140+
# Stop scanning once we reach a class or function definition.
141+
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
142+
# Without this check, a stray import later in the file
143+
# would incorrectly shift our insertion index below actual code definitions.
144+
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
145+
break
146+
147+
return insert_index
148+
122149
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
123150
# Add any new assignments that weren't in the original file
124151
new_statements = list(updated_node.body)
@@ -131,18 +158,26 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
131158
]
132159

133160
if assignments_to_append:
134-
# Add a blank line before appending new assignments if needed
135-
if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
136-
new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
137-
new_statements.pop() # Remove the Pass statement but keep the empty line
138-
139-
# Add the new assignments
140-
new_statements.extend(
141-
[
142-
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
143-
for assignment in assignments_to_append
144-
]
145-
)
161+
# after last top-level imports
162+
insert_index = self._find_insertion_index(updated_node)
163+
164+
assignment_lines = [
165+
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
166+
for assignment in assignments_to_append
167+
]
168+
169+
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
170+
171+
# Add a blank line after the last assignment if needed
172+
after_index = insert_index + len(assignment_lines)
173+
if after_index < len(new_statements):
174+
next_stmt = new_statements[after_index]
175+
# If there's no empty line, add one
176+
has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
177+
if not has_empty:
178+
new_statements[after_index] = next_stmt.with_changes(
179+
leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
180+
)
146181

147182
return updated_node.with_changes(body=new_statements)
148183

@@ -341,6 +376,7 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
341376
new_added_global_statements = extract_global_statements(src_module_code)
342377
existing_global_statements = extract_global_statements(dst_module_code)
343378

379+
# make sure we don't have any staments applited multiple times in the global level.
344380
unique_global_statements = [
345381
stmt
346382
for stmt in new_added_global_statements

codeflash/code_utils/code_replacer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ def replace_function_definitions_in_module(
412412
module_abspath: Path,
413413
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
414414
project_root_path: Path,
415+
global_assignments_added_before: bool = False, # noqa: FBT001, FBT002
415416
) -> bool:
416417
source_code: str = module_abspath.read_text(encoding="utf8")
417418
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
@@ -421,7 +422,7 @@ def replace_function_definitions_in_module(
421422
# becuase of an "edge case" where the optimized code intoduced a new import and a global assignment using that import
422423
# and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet)
423424
# this was added at https://github.com/codeflash-ai/codeflash/pull/448
424-
add_global_assignments(code_to_apply, source_code),
425+
add_global_assignments(code_to_apply, source_code) if not global_assignments_added_before else source_code,
425426
function_names,
426427
code_to_apply,
427428
module_abspath,

codeflash/context/unused_definition_remover.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def revert_unused_helper_functions(
537537
module_abspath=file_path,
538538
preexisting_objects=set(), # Empty set since we're reverting
539539
project_root_path=project_root,
540+
global_assignments_added_before=True, # since we revert helpers functions after applying the optimization, we know that the file already has global assignments added, otherwise they would be added twice.
540541
)
541542

542543
if reverted_code:

codeflash/lsp/beta.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def initialize_function_optimization(
110110

111111
if count == 0:
112112
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
113-
cleanup_the_optimizer(server)
113+
server.cleanup_the_optimizer()
114114
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
115115

116116
fto = optimizable_funcs.popitem()[1][0]
@@ -217,6 +217,7 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
217217

218218

219219
@server.feature("performFunctionOptimization")
220+
@server.thread()
220221
def perform_function_optimization( # noqa: PLR0911
221222
server: CodeflashLanguageServer, params: FunctionOptimizationParams
222223
) -> dict[str, str]:
@@ -337,14 +338,4 @@ def perform_function_optimization( # noqa: PLR0911
337338
"explanation": best_optimization.explanation_v2,
338339
}
339340
finally:
340-
cleanup_the_optimizer(server)
341-
342-
343-
def cleanup_the_optimizer(server: CodeflashLanguageServer) -> None:
344-
server.optimizer.cleanup_temporary_paths()
345-
# restore args and test cfg
346-
if server.optimizer.original_args_and_test_cfg:
347-
server.optimizer.args, server.optimizer.test_cfg = server.optimizer.original_args_and_test_cfg
348-
server.optimizer.args.function = None
349-
server.optimizer.current_worktree = None
350-
server.optimizer.current_function_optimizer = None
341+
server.cleanup_the_optimizer()

codeflash/lsp/server.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from __future__ import annotations
22

3+
import sys
34
from pathlib import Path
4-
from typing import TYPE_CHECKING, Any
5+
from threading import Event
6+
from typing import TYPE_CHECKING, Any, Optional, TextIO
57

68
from lsprotocol.types import INITIALIZE, LogMessageParams, MessageType
79
from pygls import uris
810
from pygls.protocol import LanguageServerProtocol, lsp_method
9-
from pygls.server import LanguageServer
11+
from pygls.server import LanguageServer, StdOutTransportAdapter, aio_readline
1012

1113
if TYPE_CHECKING:
1214
from lsprotocol.types import InitializeParams, InitializeResult
@@ -81,3 +83,39 @@ def show_message_log(self, message: str, message_type: str) -> None:
8183
# Send log message to client (appears in output channel)
8284
log_params = LogMessageParams(type=lsp_message_type, message=message)
8385
self.lsp.notify("window/logMessage", log_params)
86+
87+
def cleanup_the_optimizer(self) -> None:
88+
try:
89+
self.optimizer.cleanup_temporary_paths()
90+
# restore args and test cfg
91+
if self.optimizer.original_args_and_test_cfg:
92+
self.optimizer.args, self.optimizer.test_cfg = self.optimizer.original_args_and_test_cfg
93+
self.optimizer.args.function = None
94+
self.optimizer.current_worktree = None
95+
self.optimizer.current_function_optimizer = None
96+
except Exception:
97+
self.show_message_log("Failed to cleanup optimizer", "Error")
98+
99+
def start_io(self, stdin: Optional[TextIO] = None, stdout: Optional[TextIO] = None) -> None:
100+
self.show_message_log("Starting IO server", "Info")
101+
102+
self._stop_event = Event()
103+
transport = StdOutTransportAdapter(stdin or sys.stdin.buffer, stdout or sys.stdout.buffer)
104+
self.lsp.connection_made(transport)
105+
try:
106+
self.loop.run_until_complete(
107+
aio_readline(
108+
self.loop,
109+
self.thread_pool_executor,
110+
self._stop_event,
111+
stdin or sys.stdin.buffer,
112+
self.lsp.data_received,
113+
)
114+
)
115+
except BrokenPipeError:
116+
self.show_message_log("Connection to the client is lost! Shutting down the server.", "Error")
117+
except (KeyboardInterrupt, SystemExit):
118+
pass
119+
finally:
120+
self.cleanup_the_optimizer()
121+
self.shutdown()

codeflash/models/models.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Annotated, Optional, cast
2020

2121
from jedi.api.classes import Name
22-
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr
22+
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError
2323
from pydantic.dataclasses import dataclass
2424

2525
from codeflash.cli_cmds.console import console, logger
@@ -239,10 +239,14 @@ def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown:
239239
"""
240240
matches = markdown_pattern.findall(markdown_code)
241241
results = CodeStringsMarkdown()
242-
for file_path, code in matches:
243-
path = file_path.strip()
244-
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
245-
return results
242+
try:
243+
for file_path, code in matches:
244+
path = file_path.strip()
245+
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
246+
return results # noqa: TRY300
247+
except ValidationError:
248+
# if any file is invalid, return an empty CodeStringsMarkdown for the entire context
249+
return CodeStringsMarkdown()
246250

247251

248252
class CodeOptimizationContext(BaseModel):

codeflash/optimization/function_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,6 +1354,7 @@ def process_review(
13541354
return
13551355

13561356
def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None:
1357+
logger.info("Reverting code and helpers...")
13571358
self.write_code_and_helpers(
13581359
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
13591360
)

0 commit comments

Comments
 (0)