Skip to content

Commit 49fc884

Browse files
Merge branch 'main' into fix/global-assignments-after-imports
2 parents 99cb908 + 02b4d65 commit 49fc884

File tree

1 file changed

+95
-44
lines changed

1 file changed

+95
-44
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 95 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import ast
44
import concurrent.futures
55
import os
6+
import queue
67
import random
78
import subprocess
89
import time
910
import uuid
10-
from collections import defaultdict, deque
11+
from collections import defaultdict
1112
from pathlib import Path
1213
from typing import TYPE_CHECKING
1314

@@ -104,6 +105,83 @@
104105
from codeflash.verification.verification_utils import TestConfig
105106

106107

108+
class CandidateProcessor:
    """Hand out optimization candidates one at a time from a FIFO queue.

    The queue starts with the initial candidates. Whenever it runs dry the
    processor blocks on the next pending asynchronous source and enqueues
    what that source produced: first the line-profiler future, then the
    refinement futures. When both sources have been drained and the queue is
    empty, ``get_next_candidate`` returns ``None``.
    """

    def __init__(
        self,
        initial_candidates: list,
        future_line_profile_results: concurrent.futures.Future,
        future_all_refinements: list,
    ) -> None:
        """Seed the queue and remember the pending asynchronous sources.

        Args:
            initial_candidates: Candidates that are available immediately;
                they are enqueued in order and the list itself is not mutated.
            future_line_profile_results: Future resolving to a sized
                collection of additional candidates (or ``None``).
            future_all_refinements: List of futures, each resolving to a
                sequence of refinements. The list is stored by reference and
                only read when the refinements are drained.
                NOTE(review): presumably the caller appends to it while
                candidates are being processed -- confirm against the caller.
        """
        self.candidate_queue = queue.Queue()
        # Each async source is drained at most once; these flags record that.
        self.line_profiler_done = False
        self.refinement_done = False
        # Running total of candidates enqueued so far (grows as sources drain).
        self.candidate_len = len(initial_candidates)

        for candidate in initial_candidates:
            self.candidate_queue.put(candidate)

        self.future_line_profile_results = future_line_profile_results
        self.future_all_refinements = future_all_refinements

    def get_next_candidate(self) -> OptimizedCandidate | None:
        """Return the next queued candidate, draining async sources if needed.

        Returns:
            The next candidate, or ``None`` when the queue is empty and both
            the line-profiler and refinement sources have been drained.

        Raises:
            Exception: Whatever a pending future raised; it propagates from
                ``Future.result()`` and the corresponding ``*_done`` flag is
                left unset.
        """
        try:
            return self.candidate_queue.get_nowait()
        except queue.Empty:
            return self._handle_empty_queue()

    def _handle_empty_queue(self) -> OptimizedCandidate | None:
        """Drain the next pending async source, or return ``None`` if none is left."""
        if not self.line_profiler_done:
            return self._process_line_profiler_results()
        # The line profiler is known to be done past this point.
        if not self.refinement_done:
            return self._process_refinement_results()
        return None  # All done

    def _process_line_profiler_results(self) -> OptimizedCandidate | None:
        """Wait for the line-profiler future and enqueue its candidates."""
        logger.debug("all candidates processed, await candidates from line profiler")
        concurrent.futures.wait([self.future_line_profile_results])
        # A None result is treated as "no candidates" instead of crashing below.
        line_profile_results = self.future_line_profile_results.result() or []

        for candidate in line_profile_results:
            self.candidate_queue.put(candidate)

        self.candidate_len += len(line_profile_results)
        logger.info(f"Added results from line profiler to candidates, total candidates now: {self.candidate_len}")
        self.line_profiler_done = True

        # Re-enter so that an empty result falls through to the refinements.
        return self.get_next_candidate()

    def _process_refinement_results(self) -> OptimizedCandidate | None:
        """Wait for all refinement futures and enqueue one candidate per response."""
        concurrent.futures.wait(self.future_all_refinements)
        refinement_response = []

        for future_refinement in self.future_all_refinements:
            possible_refinement = future_refinement.result()
            # Skip empty or None responses; only the first refinement is used.
            if possible_refinement:
                refinement_response.append(possible_refinement[0])

        for candidate in refinement_response:
            self.candidate_queue.put(candidate)

        self.candidate_len += len(refinement_response)
        logger.info(
            f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
        )
        self.refinement_done = True

        # Re-enter: returns a refinement if any was queued, otherwise None.
        return self.get_next_candidate()

    def is_done(self) -> bool:
        """Return True when both async sources are drained and the queue is empty."""
        return self.line_profiler_done and self.refinement_done and self.candidate_queue.empty()
183+
184+
107185
class FunctionOptimizer:
108186
def __init__(
109187
self,
@@ -378,15 +456,13 @@ def determine_best_candidate(
378456
f"{self.function_to_optimize.qualified_name}…"
379457
)
380458
console.rule()
381-
candidates = deque(candidates)
382-
refinement_done = False
383-
line_profiler_done = False
459+
384460
future_all_refinements: list[concurrent.futures.Future] = []
385461
ast_code_to_id = {}
386462
valid_optimizations = []
387463
optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
388-
# Start a new thread for AI service request, start loop in main thread
389-
# check if aiservice request is complete, when it is complete, append result to the candidates list
464+
465+
# Start a new thread for AI service request
390466
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
391467
future_line_profile_results = self.executor.submit(
392468
ai_service_client.optimize_python_code_line_profiler,
@@ -401,48 +477,23 @@ def determine_best_candidate(
401477
if self.experiment_id
402478
else None,
403479
)
480+
481+
# Initialize candidate processor
482+
processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements)
404483
candidate_index = 0
405-
original_len = len(candidates)
406-
# TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
407-
# TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
408-
while True:
409-
try:
410-
if len(candidates) > 0:
411-
candidate = candidates.popleft()
412-
else:
413-
if not line_profiler_done:
414-
logger.debug("all candidates processed, await candidates from line profiler")
415-
concurrent.futures.wait([future_line_profile_results])
416-
line_profile_results = future_line_profile_results.result()
417-
candidates.extend(line_profile_results)
418-
original_len += len(line_profile_results)
419-
logger.info(
420-
f"Added results from line profiler to candidates, total candidates now: {original_len}"
421-
)
422-
line_profiler_done = True
423-
continue
424-
if line_profiler_done and not refinement_done:
425-
concurrent.futures.wait(future_all_refinements)
426-
refinement_response = []
427-
for future_refinement in future_all_refinements:
428-
possible_refinement = future_refinement.result()
429-
if len(possible_refinement) > 0: # if the api returns a valid response
430-
refinement_response.append(possible_refinement[0])
431-
candidates.extend(refinement_response)
432-
original_len += len(refinement_response)
433-
logger.info(
434-
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
435-
)
436-
refinement_done = True
437-
continue
438-
if line_profiler_done and refinement_done:
439-
logger.debug("everything done, exiting")
440-
break
441484

485+
# Process candidates using queue-based approach
486+
while not processor.is_done():
487+
candidate = processor.get_next_candidate()
488+
if candidate is None:
489+
logger.debug("everything done, exiting")
490+
break
491+
492+
try:
442493
candidate_index += 1
443494
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
444495
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
445-
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
496+
logger.info(f"Optimization candidate {candidate_index}/{processor.candidate_len}:")
446497
code_print(candidate.source_code.flat)
447498
# map ast normalized code to diff len, unnormalized code
448499
# map opt id to the shortest unnormalized code
@@ -467,7 +518,7 @@ def determine_best_candidate(
467518
# check if this code has been evaluated before by checking the ast normalized code string
468519
normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
469520
if normalized_code in ast_code_to_id:
470-
logger.warning(
521+
logger.info(
471522
"Current candidate has been encountered before in testing, Skipping optimization candidate."
472523
)
473524
past_opt_id = ast_code_to_id[normalized_code]["optimization_id"]

0 commit comments

Comments
 (0)