Skip to content

Commit 207612c

Browse files
authored
Merge branch 'main' into standalone-fto-async
2 parents 982914f + 24fb636 commit 207612c

File tree

7 files changed

+340
-47
lines changed

7 files changed

+340
-47
lines changed

codeflash/api/aiservice.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,19 @@ def make_ai_service_request(
8181
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
8282
return response
8383

84+
def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]:
85+
candidates: list[OptimizedCandidate] = []
86+
for opt in optimizations_json:
87+
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
88+
if not code.code_strings:
89+
continue
90+
candidates.append(
91+
OptimizedCandidate(
92+
source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"]
93+
)
94+
)
95+
return candidates
96+
8497
def optimize_python_code( # noqa: D417
8598
self,
8699
source_code: str,
@@ -135,14 +148,7 @@ def optimize_python_code( # noqa: D417
135148
console.rule()
136149
end_time = time.perf_counter()
137150
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
138-
return [
139-
OptimizedCandidate(
140-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
141-
explanation=opt["explanation"],
142-
optimization_id=opt["optimization_id"],
143-
)
144-
for opt in optimizations_json
145-
]
151+
return self._get_valid_candidates(optimizations_json)
146152
try:
147153
error = response.json()["error"]
148154
except Exception:
@@ -205,14 +211,7 @@ def optimize_python_code_line_profiler( # noqa: D417
205211
optimizations_json = response.json()["optimizations"]
206212
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
207213
console.rule()
208-
return [
209-
OptimizedCandidate(
210-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
211-
explanation=opt["explanation"],
212-
optimization_id=opt["optimization_id"],
213-
)
214-
for opt in optimizations_json
215-
]
214+
return self._get_valid_candidates(optimizations_json)
216215
try:
217216
error = response.json()["error"]
218217
except Exception:
@@ -262,14 +261,17 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
262261
refined_optimizations = response.json()["refinements"]
263262
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
264263
console.rule()
264+
265+
refinements = self._get_valid_candidates(refined_optimizations)
265266
return [
266267
OptimizedCandidate(
267-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
268-
explanation=opt["explanation"],
269-
optimization_id=opt["optimization_id"][:-4] + "refi",
268+
source_code=c.source_code,
269+
explanation=c.explanation,
270+
optimization_id=c.optimization_id[:-4] + "refi",
270271
)
271-
for opt in refined_optimizations
272+
for c in refinements
272273
]
274+
273275
try:
274276
error = response.json()["error"]
275277
except Exception:

codeflash/code_utils/code_extractor.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import ast
5+
from itertools import chain
56
from typing import TYPE_CHECKING, Optional
67

78
import libcst as cst
@@ -119,6 +120,32 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c
119120

120121
return updated_node
121122

123+
def _find_insertion_index(self, updated_node: cst.Module) -> int:
124+
"""Find the position of the last import statement in the top-level of the module."""
125+
insert_index = 0
126+
for i, stmt in enumerate(updated_node.body):
127+
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
128+
isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
129+
)
130+
131+
is_conditional_import = isinstance(stmt, cst.If) and all(
132+
isinstance(inner, cst.SimpleStatementLine)
133+
and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
134+
for inner in stmt.body.body
135+
)
136+
137+
if is_top_level_import or is_conditional_import:
138+
insert_index = i + 1
139+
140+
# Stop scanning once we reach a class or function definition.
141+
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
142+
# Without this check, a stray import later in the file
143+
# would incorrectly shift our insertion index below actual code definitions.
144+
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
145+
break
146+
147+
return insert_index
148+
122149
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
123150
# Add any new assignments that weren't in the original file
124151
new_statements = list(updated_node.body)
@@ -131,18 +158,26 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
131158
]
132159

133160
if assignments_to_append:
134-
# Add a blank line before appending new assignments if needed
135-
if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
136-
new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
137-
new_statements.pop() # Remove the Pass statement but keep the empty line
138-
139-
# Add the new assignments
140-
new_statements.extend(
141-
[
142-
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
143-
for assignment in assignments_to_append
144-
]
145-
)
161+
# after last top-level imports
162+
insert_index = self._find_insertion_index(updated_node)
163+
164+
assignment_lines = [
165+
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
166+
for assignment in assignments_to_append
167+
]
168+
169+
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
170+
171+
# Add a blank line after the last assignment if needed
172+
after_index = insert_index + len(assignment_lines)
173+
if after_index < len(new_statements):
174+
next_stmt = new_statements[after_index]
175+
# If there's no empty line, add one
176+
has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
177+
if not has_empty:
178+
new_statements[after_index] = next_stmt.with_changes(
179+
leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
180+
)
146181

147182
return updated_node.with_changes(body=new_statements)
148183

codeflash/models/models.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Annotated, Optional, cast
2020

2121
from jedi.api.classes import Name
22-
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr
22+
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError
2323
from pydantic.dataclasses import dataclass
2424

2525
from codeflash.cli_cmds.console import console, logger
@@ -239,10 +239,14 @@ def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown:
239239
"""
240240
matches = markdown_pattern.findall(markdown_code)
241241
results = CodeStringsMarkdown()
242-
for file_path, code in matches:
243-
path = file_path.strip()
244-
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
245-
return results
242+
try:
243+
for file_path, code in matches:
244+
path = file_path.strip()
245+
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
246+
return results # noqa: TRY300
247+
except ValidationError:
248+
# if any file is invalid, return an empty CodeStringsMarkdown for the entire context
249+
return CodeStringsMarkdown()
246250

247251

248252
class CodeOptimizationContext(BaseModel):

codeflash/optimization/function_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,7 @@ def process_review(
13451345
return
13461346

13471347
def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None:
1348+
logger.info("Reverting code and helpers...")
13481349
self.write_code_and_helpers(
13491350
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
13501351
)

0 commit comments

Comments
 (0)