Skip to content

Commit 982914f

Browse files
authored
Merge branch 'main' into standalone-fto-async
2 parents 4759f07 + 02b4d65 commit 982914f

File tree

1 file changed

+95
-44
lines changed

1 file changed

+95
-44
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 95 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import ast
44
import concurrent.futures
55
import os
6+
import queue
67
import random
78
import subprocess
89
import time
910
import uuid
10-
from collections import defaultdict, deque
11+
from collections import defaultdict
1112
from pathlib import Path
1213
from typing import TYPE_CHECKING
1314

@@ -103,6 +104,83 @@
103104
from codeflash.verification.verification_utils import TestConfig
104105

105106

107+
class CandidateProcessor:
108+
"""Handles candidate processing using a queue-based approach."""
109+
110+
def __init__(
111+
self,
112+
initial_candidates: list,
113+
future_line_profile_results: concurrent.futures.Future,
114+
future_all_refinements: list,
115+
) -> None:
116+
self.candidate_queue = queue.Queue()
117+
self.line_profiler_done = False
118+
self.refinement_done = False
119+
self.candidate_len = len(initial_candidates)
120+
121+
# Initialize queue with initial candidates
122+
for candidate in initial_candidates:
123+
self.candidate_queue.put(candidate)
124+
125+
self.future_line_profile_results = future_line_profile_results
126+
self.future_all_refinements = future_all_refinements
127+
128+
def get_next_candidate(self) -> OptimizedCandidate | None:
129+
"""Get the next candidate from the queue, handling async results as needed."""
130+
try:
131+
return self.candidate_queue.get_nowait()
132+
except queue.Empty:
133+
return self._handle_empty_queue()
134+
135+
def _handle_empty_queue(self) -> OptimizedCandidate | None:
136+
"""Handle empty queue by checking for pending async results."""
137+
if not self.line_profiler_done:
138+
return self._process_line_profiler_results()
139+
if self.line_profiler_done and not self.refinement_done:
140+
return self._process_refinement_results()
141+
return None # All done
142+
143+
def _process_line_profiler_results(self) -> OptimizedCandidate | None:
144+
"""Process line profiler results and add to queue."""
145+
logger.debug("all candidates processed, await candidates from line profiler")
146+
concurrent.futures.wait([self.future_line_profile_results])
147+
line_profile_results = self.future_line_profile_results.result()
148+
149+
for candidate in line_profile_results:
150+
self.candidate_queue.put(candidate)
151+
152+
self.candidate_len += len(line_profile_results)
153+
logger.info(f"Added results from line profiler to candidates, total candidates now: {self.candidate_len}")
154+
self.line_profiler_done = True
155+
156+
return self.get_next_candidate()
157+
158+
def _process_refinement_results(self) -> OptimizedCandidate | None:
159+
"""Process refinement results and add to queue."""
160+
concurrent.futures.wait(self.future_all_refinements)
161+
refinement_response = []
162+
163+
for future_refinement in self.future_all_refinements:
164+
possible_refinement = future_refinement.result()
165+
if len(possible_refinement) > 0:
166+
refinement_response.append(possible_refinement[0])
167+
168+
for candidate in refinement_response:
169+
self.candidate_queue.put(candidate)
170+
171+
self.candidate_len += len(refinement_response)
172+
logger.info(
173+
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
174+
)
175+
self.refinement_done = True
176+
177+
return self.get_next_candidate()
178+
179+
def is_done(self) -> bool:
180+
"""Check if processing is complete."""
181+
return self.line_profiler_done and self.refinement_done and self.candidate_queue.empty()
182+
183+
106184
class FunctionOptimizer:
107185
def __init__(
108186
self,
@@ -372,15 +450,13 @@ def determine_best_candidate(
372450
f"{self.function_to_optimize.qualified_name}…"
373451
)
374452
console.rule()
375-
candidates = deque(candidates)
376-
refinement_done = False
377-
line_profiler_done = False
453+
378454
future_all_refinements: list[concurrent.futures.Future] = []
379455
ast_code_to_id = {}
380456
valid_optimizations = []
381457
optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
382-
# Start a new thread for AI service request, start loop in main thread
383-
# check if aiservice request is complete, when it is complete, append result to the candidates list
458+
459+
# Start a new thread for AI service request
384460
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
385461
future_line_profile_results = self.executor.submit(
386462
ai_service_client.optimize_python_code_line_profiler,
@@ -395,48 +471,23 @@ def determine_best_candidate(
395471
if self.experiment_id
396472
else None,
397473
)
474+
475+
# Initialize candidate processor
476+
processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements)
398477
candidate_index = 0
399-
original_len = len(candidates)
400-
# TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
401-
# TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
402-
while True:
403-
try:
404-
if len(candidates) > 0:
405-
candidate = candidates.popleft()
406-
else:
407-
if not line_profiler_done:
408-
logger.debug("all candidates processed, await candidates from line profiler")
409-
concurrent.futures.wait([future_line_profile_results])
410-
line_profile_results = future_line_profile_results.result()
411-
candidates.extend(line_profile_results)
412-
original_len += len(line_profile_results)
413-
logger.info(
414-
f"Added results from line profiler to candidates, total candidates now: {original_len}"
415-
)
416-
line_profiler_done = True
417-
continue
418-
if line_profiler_done and not refinement_done:
419-
concurrent.futures.wait(future_all_refinements)
420-
refinement_response = []
421-
for future_refinement in future_all_refinements:
422-
possible_refinement = future_refinement.result()
423-
if len(possible_refinement) > 0: # if the api returns a valid response
424-
refinement_response.append(possible_refinement[0])
425-
candidates.extend(refinement_response)
426-
original_len += len(refinement_response)
427-
logger.info(
428-
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
429-
)
430-
refinement_done = True
431-
continue
432-
if line_profiler_done and refinement_done:
433-
logger.debug("everything done, exiting")
434-
break
435478

479+
# Process candidates using queue-based approach
480+
while not processor.is_done():
481+
candidate = processor.get_next_candidate()
482+
if candidate is None:
483+
logger.debug("everything done, exiting")
484+
break
485+
486+
try:
436487
candidate_index += 1
437488
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
438489
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
439-
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
490+
logger.info(f"Optimization candidate {candidate_index}/{processor.candidate_len}:")
440491
code_print(candidate.source_code.flat)
441492
# map ast normalized code to diff len, unnormalized code
442493
# map opt id to the shortest unnormalized code
@@ -461,7 +512,7 @@ def determine_best_candidate(
461512
# check if this code has been evaluated before by checking the ast normalized code string
462513
normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
463514
if normalized_code in ast_code_to_id:
464-
logger.warning(
515+
logger.info(
465516
"Current candidate has been encountered before in testing, Skipping optimization candidate."
466517
)
467518
past_opt_id = ast_code_to_id[normalized_code]["optimization_id"]

0 commit comments

Comments
 (0)