Skip to content

Commit 628e004

Browse files
Merge branch 'main' into fix/global-assignments-after-imports
2 parents 49fc884 + cccca40 commit 628e004

File tree

4 files changed

+60
-25
lines changed

4 files changed

+60
-25
lines changed

codeflash/api/aiservice.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,19 @@ def make_ai_service_request(
8181
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
8282
return response
8383

84+
def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]:
85+
candidates: list[OptimizedCandidate] = []
86+
for opt in optimizations_json:
87+
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
88+
if not code.code_strings:
89+
continue
90+
candidates.append(
91+
OptimizedCandidate(
92+
source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"]
93+
)
94+
)
95+
return candidates
96+
8497
def optimize_python_code( # noqa: D417
8598
self,
8699
source_code: str,
@@ -135,14 +148,7 @@ def optimize_python_code( # noqa: D417
135148
console.rule()
136149
end_time = time.perf_counter()
137150
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
138-
return [
139-
OptimizedCandidate(
140-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
141-
explanation=opt["explanation"],
142-
optimization_id=opt["optimization_id"],
143-
)
144-
for opt in optimizations_json
145-
]
151+
return self._get_valid_candidates(optimizations_json)
146152
try:
147153
error = response.json()["error"]
148154
except Exception:
@@ -205,14 +211,7 @@ def optimize_python_code_line_profiler( # noqa: D417
205211
optimizations_json = response.json()["optimizations"]
206212
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
207213
console.rule()
208-
return [
209-
OptimizedCandidate(
210-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
211-
explanation=opt["explanation"],
212-
optimization_id=opt["optimization_id"],
213-
)
214-
for opt in optimizations_json
215-
]
214+
return self._get_valid_candidates(optimizations_json)
216215
try:
217216
error = response.json()["error"]
218217
except Exception:
@@ -262,14 +261,17 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
262261
refined_optimizations = response.json()["refinements"]
263262
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
264263
console.rule()
264+
265+
refinements = self._get_valid_candidates(refined_optimizations)
265266
return [
266267
OptimizedCandidate(
267-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
268-
explanation=opt["explanation"],
269-
optimization_id=opt["optimization_id"][:-4] + "refi",
268+
source_code=c.source_code,
269+
explanation=c.explanation,
270+
optimization_id=c.optimization_id[:-4] + "refi",
270271
)
271-
for opt in refined_optimizations
272+
for c in refinements
272273
]
274+
273275
try:
274276
error = response.json()["error"]
275277
except Exception:

codeflash/models/models.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Annotated, Optional, cast
2020

2121
from jedi.api.classes import Name
22-
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr
22+
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError
2323
from pydantic.dataclasses import dataclass
2424

2525
from codeflash.cli_cmds.console import console, logger
@@ -239,10 +239,14 @@ def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown:
239239
"""
240240
matches = markdown_pattern.findall(markdown_code)
241241
results = CodeStringsMarkdown()
242-
for file_path, code in matches:
243-
path = file_path.strip()
244-
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
245-
return results
242+
try:
243+
for file_path, code in matches:
244+
path = file_path.strip()
245+
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
246+
return results # noqa: TRY300
247+
except ValidationError:
248+
# if any file is invalid, return an empty CodeStringsMarkdown for the entire context
249+
return CodeStringsMarkdown()
246250

247251

248252
class CodeOptimizationContext(BaseModel):

codeflash/optimization/function_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,7 @@ def process_review(
13511351
return
13521352

13531353
def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None:
1354+
logger.info("Reverting code and helpers...")
13541355
self.write_code_and_helpers(
13551356
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
13561357
)

tests/test_validate_python_code.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from pydantic import ValidationError
33

4+
from codeflash.api.aiservice import AiServiceClient
45
from codeflash.models.models import CodeString
56

67

@@ -41,3 +42,30 @@ def test_whitespace_only():
4142
whitespace_code = " "
4243
cs = CodeString(code=whitespace_code)
4344
assert cs.code == whitespace_code
45+
46+
def test_generated_candidates_validation():
47+
ai_service = AiServiceClient()
48+
code = """```python:file.py
49+
print name
50+
```"""
51+
mock_generate_candidates = [
52+
{
53+
"source_code": code,
54+
"explanation": "",
55+
"optimization_id": ""
56+
}
57+
]
58+
candidates = ai_service._get_valid_candidates(mock_generate_candidates)
59+
assert len(candidates) == 0
60+
code = """```python:file.py
61+
print('Hello, World!')
62+
```"""
63+
mock_generate_candidates = [
64+
{
65+
"source_code": code,
66+
"explanation": "",
67+
"optimization_id": ""
68+
}
69+
]
70+
candidates = ai_service._get_valid_candidates(mock_generate_candidates)
71+
assert len(candidates) == 1

0 commit comments

Comments
 (0)