Skip to content

Commit c5fbe09

Browse files
Merge branch 'main' into lsp/threaded-optimizer-cleanup
2 parents 85ccaaa + 24fb636 commit c5fbe09

File tree

7 files changed

+435
-91
lines changed

7 files changed

+435
-91
lines changed

codeflash/api/aiservice.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,19 @@ def make_ai_service_request(
8181
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
8282
return response
8383

84+
def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]:
    """Build candidates from the service response, dropping entries with no parsed code.

    Args:
        optimizations_json: Response entries; each is read for the keys
            ``"source_code"``, ``"explanation"`` and ``"optimization_id"``.

    Returns:
        One ``OptimizedCandidate`` per entry whose ``"source_code"`` parses to a
        ``CodeStringsMarkdown`` with a non-empty ``code_strings``, in input order.

    """
    # Parse lazily so each entry is parsed, checked, and only then read for
    # its remaining keys -- entries that are skipped are never touched further.
    parsed_entries = (
        (entry, CodeStringsMarkdown.parse_markdown_code(entry["source_code"])) for entry in optimizations_json
    )
    return [
        OptimizedCandidate(
            source_code=parsed, explanation=entry["explanation"], optimization_id=entry["optimization_id"]
        )
        for entry, parsed in parsed_entries
        if parsed.code_strings
    ]
96+
8497
def optimize_python_code( # noqa: D417
8598
self,
8699
source_code: str,
@@ -135,14 +148,7 @@ def optimize_python_code( # noqa: D417
135148
console.rule()
136149
end_time = time.perf_counter()
137150
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
138-
return [
139-
OptimizedCandidate(
140-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
141-
explanation=opt["explanation"],
142-
optimization_id=opt["optimization_id"],
143-
)
144-
for opt in optimizations_json
145-
]
151+
return self._get_valid_candidates(optimizations_json)
146152
try:
147153
error = response.json()["error"]
148154
except Exception:
@@ -205,14 +211,7 @@ def optimize_python_code_line_profiler( # noqa: D417
205211
optimizations_json = response.json()["optimizations"]
206212
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
207213
console.rule()
208-
return [
209-
OptimizedCandidate(
210-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
211-
explanation=opt["explanation"],
212-
optimization_id=opt["optimization_id"],
213-
)
214-
for opt in optimizations_json
215-
]
214+
return self._get_valid_candidates(optimizations_json)
216215
try:
217216
error = response.json()["error"]
218217
except Exception:
@@ -262,14 +261,17 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
262261
refined_optimizations = response.json()["refinements"]
263262
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
264263
console.rule()
264+
265+
refinements = self._get_valid_candidates(refined_optimizations)
265266
return [
266267
OptimizedCandidate(
267-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
268-
explanation=opt["explanation"],
269-
optimization_id=opt["optimization_id"][:-4] + "refi",
268+
source_code=c.source_code,
269+
explanation=c.explanation,
270+
optimization_id=c.optimization_id[:-4] + "refi",
270271
)
271-
for opt in refined_optimizations
272+
for c in refinements
272273
]
274+
273275
try:
274276
error = response.json()["error"]
275277
except Exception:

codeflash/code_utils/code_extractor.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import ast
5+
from itertools import chain
56
from typing import TYPE_CHECKING, Optional
67

78
import libcst as cst
@@ -119,6 +120,32 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c
119120

120121
return updated_node
121122

123+
def _find_insertion_index(self, updated_node: cst.Module) -> int:
    """Return the index in ``updated_node.body`` just past the last leading import.

    A statement counts as an import when it is either a simple statement line
    containing at least one ``import`` / ``from ... import``, or an ``if`` whose
    body consists solely of simple statement lines made up only of imports.
    Returns 0 when no such statement is seen before scanning stops.
    """
    import_types = (cst.Import, cst.ImportFrom)
    position = 0

    for idx, node in enumerate(updated_node.body):
        # Imports may legally appear anywhere, even at the very bottom of a file.
        # Scanning ends at the first class or function definition so that a stray
        # late import cannot push the insertion point below real definitions.
        if isinstance(node, (cst.ClassDef, cst.FunctionDef)):
            break

        if isinstance(node, cst.SimpleStatementLine):
            counts_as_import = any(isinstance(part, import_types) for part in node.body)
        elif isinstance(node, cst.If):
            counts_as_import = all(
                isinstance(line, cst.SimpleStatementLine)
                and all(isinstance(part, import_types) for part in line.body)
                for line in node.body.body
            )
        else:
            counts_as_import = False

        if counts_as_import:
            position = idx + 1

    return position
148+
122149
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
123150
# Add any new assignments that weren't in the original file
124151
new_statements = list(updated_node.body)
@@ -131,18 +158,26 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
131158
]
132159

133160
if assignments_to_append:
134-
# Add a blank line before appending new assignments if needed
135-
if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
136-
new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
137-
new_statements.pop() # Remove the Pass statement but keep the empty line
138-
139-
# Add the new assignments
140-
new_statements.extend(
141-
[
142-
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
143-
for assignment in assignments_to_append
144-
]
145-
)
161+
# after last top-level imports
162+
insert_index = self._find_insertion_index(updated_node)
163+
164+
assignment_lines = [
165+
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
166+
for assignment in assignments_to_append
167+
]
168+
169+
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
170+
171+
# Add a blank line after the last assignment if needed
172+
after_index = insert_index + len(assignment_lines)
173+
if after_index < len(new_statements):
174+
next_stmt = new_statements[after_index]
175+
# If there's no empty line, add one
176+
has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
177+
if not has_empty:
178+
new_statements[after_index] = next_stmt.with_changes(
179+
leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
180+
)
146181

147182
return updated_node.with_changes(body=new_statements)
148183

codeflash/models/models.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Annotated, Optional, cast
2020

2121
from jedi.api.classes import Name
22-
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr
22+
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError
2323
from pydantic.dataclasses import dataclass
2424

2525
from codeflash.cli_cmds.console import console, logger
@@ -239,10 +239,14 @@ def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown:
239239
"""
240240
matches = markdown_pattern.findall(markdown_code)
241241
results = CodeStringsMarkdown()
242-
for file_path, code in matches:
243-
path = file_path.strip()
244-
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
245-
return results
242+
try:
243+
for file_path, code in matches:
244+
path = file_path.strip()
245+
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
246+
return results # noqa: TRY300
247+
except ValidationError:
248+
# if any file is invalid, return an empty CodeStringsMarkdown for the entire context
249+
return CodeStringsMarkdown()
246250

247251

248252
class CodeOptimizationContext(BaseModel):

codeflash/optimization/function_optimizer.py

Lines changed: 96 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import ast
44
import concurrent.futures
55
import os
6+
import queue
67
import random
78
import subprocess
89
import time
910
import uuid
10-
from collections import defaultdict, deque
11+
from collections import defaultdict
1112
from pathlib import Path
1213
from typing import TYPE_CHECKING
1314

@@ -104,6 +105,83 @@
104105
from codeflash.verification.verification_utils import TestConfig
105106

106107

108+
class CandidateProcessor:
    """Serve optimization candidates one at a time from a FIFO queue.

    The queue starts with the initial candidates. Each time it runs dry, the
    next pending source is awaited and its results are enqueued: first the
    line-profiler future, then the refinement futures. Once both sources have
    been consumed and the queue is empty, no more candidates are produced.
    """

    def __init__(
        self,
        initial_candidates: list,
        future_line_profile_results: concurrent.futures.Future,
        future_all_refinements: list,
    ) -> None:
        # Pending asynchronous sources. The refinement list is stored as-is (not
        # copied); presumably the caller keeps appending futures to it -- confirm
        # against the candidate loop.
        self.future_line_profile_results = future_line_profile_results
        self.future_all_refinements = future_all_refinements

        # Flags recording which asynchronous sources have already been drained.
        self.line_profiler_done = False
        self.refinement_done = False

        # Running total of candidates ever enqueued (used for progress logging).
        self.candidate_len = len(initial_candidates)

        self.candidate_queue = queue.Queue()
        for item in initial_candidates:
            self.candidate_queue.put(item)

    def get_next_candidate(self) -> OptimizedCandidate | None:
        """Return the next queued candidate, refilling from pending futures if the queue is empty.

        Returns ``None`` when the queue is empty and both sources are consumed.
        """
        try:
            upcoming = self.candidate_queue.get_nowait()
        except queue.Empty:
            upcoming = self._handle_empty_queue()
        return upcoming

    def _handle_empty_queue(self) -> OptimizedCandidate | None:
        """Refill the queue from the next unconsumed source, or return ``None`` if none remain."""
        if not self.line_profiler_done:
            return self._process_line_profiler_results()
        # Reaching this point means the line-profiler results are already in.
        if not self.refinement_done:
            return self._process_refinement_results()
        return None  # All done

    def _process_line_profiler_results(self) -> OptimizedCandidate | None:
        """Block on the line-profiler future, enqueue its candidates, then return the next candidate."""
        logger.debug("all candidates processed, await candidates from line profiler")
        concurrent.futures.wait([self.future_line_profile_results])
        profiled = self.future_line_profile_results.result()

        for item in profiled:
            self.candidate_queue.put(item)
        self.candidate_len += len(profiled)

        logger.info(f"Added results from line profiler to candidates, total candidates now: {self.candidate_len}")
        # Only flagged after the results were fetched and enqueued.
        self.line_profiler_done = True

        return self.get_next_candidate()

    def _process_refinement_results(self) -> OptimizedCandidate | None:
        """Block on all refinement futures, enqueue the first result of each, then return the next candidate."""
        concurrent.futures.wait(self.future_all_refinements)

        # Keep only the first element of every non-empty refinement response.
        refined = [
            outcome[0]
            for outcome in (pending.result() for pending in self.future_all_refinements)
            if len(outcome) > 0
        ]

        for item in refined:
            self.candidate_queue.put(item)
        self.candidate_len += len(refined)

        logger.info(
            f"Added {len(refined)} candidates from refinement, total candidates now: {self.candidate_len}"
        )
        self.refinement_done = True

        return self.get_next_candidate()

    def is_done(self) -> bool:
        """Return True once both sources are consumed and the queue is empty."""
        sources_consumed = self.line_profiler_done and self.refinement_done
        return sources_consumed and self.candidate_queue.empty()
183+
184+
107185
class FunctionOptimizer:
108186
def __init__(
109187
self,
@@ -378,15 +456,13 @@ def determine_best_candidate(
378456
f"{self.function_to_optimize.qualified_name}…"
379457
)
380458
console.rule()
381-
candidates = deque(candidates)
382-
refinement_done = False
383-
line_profiler_done = False
459+
384460
future_all_refinements: list[concurrent.futures.Future] = []
385461
ast_code_to_id = {}
386462
valid_optimizations = []
387463
optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
388-
# Start a new thread for AI service request, start loop in main thread
389-
# check if aiservice request is complete, when it is complete, append result to the candidates list
464+
465+
# Start a new thread for AI service request
390466
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
391467
future_line_profile_results = self.executor.submit(
392468
ai_service_client.optimize_python_code_line_profiler,
@@ -401,48 +477,23 @@ def determine_best_candidate(
401477
if self.experiment_id
402478
else None,
403479
)
480+
481+
# Initialize candidate processor
482+
processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements)
404483
candidate_index = 0
405-
original_len = len(candidates)
406-
# TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
407-
# TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
408-
while True:
409-
try:
410-
if len(candidates) > 0:
411-
candidate = candidates.popleft()
412-
else:
413-
if not line_profiler_done:
414-
logger.debug("all candidates processed, await candidates from line profiler")
415-
concurrent.futures.wait([future_line_profile_results])
416-
line_profile_results = future_line_profile_results.result()
417-
candidates.extend(line_profile_results)
418-
original_len += len(line_profile_results)
419-
logger.info(
420-
f"Added results from line profiler to candidates, total candidates now: {original_len}"
421-
)
422-
line_profiler_done = True
423-
continue
424-
if line_profiler_done and not refinement_done:
425-
concurrent.futures.wait(future_all_refinements)
426-
refinement_response = []
427-
for future_refinement in future_all_refinements:
428-
possible_refinement = future_refinement.result()
429-
if len(possible_refinement) > 0: # if the api returns a valid response
430-
refinement_response.append(possible_refinement[0])
431-
candidates.extend(refinement_response)
432-
original_len += len(refinement_response)
433-
logger.info(
434-
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
435-
)
436-
refinement_done = True
437-
continue
438-
if line_profiler_done and refinement_done:
439-
logger.debug("everything done, exiting")
440-
break
441484

485+
# Process candidates using queue-based approach
486+
while not processor.is_done():
487+
candidate = processor.get_next_candidate()
488+
if candidate is None:
489+
logger.debug("everything done, exiting")
490+
break
491+
492+
try:
442493
candidate_index += 1
443494
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
444495
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
445-
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
496+
logger.info(f"Optimization candidate {candidate_index}/{processor.candidate_len}:")
446497
code_print(candidate.source_code.flat)
447498
# map ast normalized code to diff len, unnormalized code
448499
# map opt id to the shortest unnormalized code
@@ -467,7 +518,7 @@ def determine_best_candidate(
467518
# check if this code has been evaluated before by checking the ast normalized code string
468519
normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
469520
if normalized_code in ast_code_to_id:
470-
logger.warning(
521+
logger.info(
471522
"Current candidate has been encountered before in testing, Skipping optimization candidate."
472523
)
473524
past_opt_id = ast_code_to_id[normalized_code]["optimization_id"]
@@ -1300,6 +1351,7 @@ def process_review(
13001351
return
13011352

13021353
def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None:
1354+
logger.info("Reverting code and helpers...")
13031355
self.write_code_and_helpers(
13041356
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
13051357
)

0 commit comments

Comments
 (0)