33import ast
44import concurrent .futures
55import os
6+ import queue
67import random
78import subprocess
89import time
910import uuid
10- from collections import defaultdict , deque
11+ from collections import defaultdict
1112from pathlib import Path
1213from typing import TYPE_CHECKING
1314
103104 from codeflash .verification .verification_utils import TestConfig
104105
105106
107+ class CandidateProcessor :
108+ """Handles candidate processing using a queue-based approach."""
109+
110+ def __init__ (
111+ self ,
112+ initial_candidates : list ,
113+ future_line_profile_results : concurrent .futures .Future ,
114+ future_all_refinements : list ,
115+ ) -> None :
116+ self .candidate_queue = queue .Queue ()
117+ self .line_profiler_done = False
118+ self .refinement_done = False
119+ self .candidate_len = len (initial_candidates )
120+
121+ # Initialize queue with initial candidates
122+ for candidate in initial_candidates :
123+ self .candidate_queue .put (candidate )
124+
125+ self .future_line_profile_results = future_line_profile_results
126+ self .future_all_refinements = future_all_refinements
127+
128+ def get_next_candidate (self ) -> OptimizedCandidate | None :
129+ """Get the next candidate from the queue, handling async results as needed."""
130+ try :
131+ return self .candidate_queue .get_nowait ()
132+ except queue .Empty :
133+ return self ._handle_empty_queue ()
134+
135+ def _handle_empty_queue (self ) -> OptimizedCandidate | None :
136+ """Handle empty queue by checking for pending async results."""
137+ if not self .line_profiler_done :
138+ return self ._process_line_profiler_results ()
139+ if self .line_profiler_done and not self .refinement_done :
140+ return self ._process_refinement_results ()
141+ return None # All done
142+
143+ def _process_line_profiler_results (self ) -> OptimizedCandidate | None :
144+ """Process line profiler results and add to queue."""
145+ logger .debug ("all candidates processed, await candidates from line profiler" )
146+ concurrent .futures .wait ([self .future_line_profile_results ])
147+ line_profile_results = self .future_line_profile_results .result ()
148+
149+ for candidate in line_profile_results :
150+ self .candidate_queue .put (candidate )
151+
152+ self .candidate_len += len (line_profile_results )
153+ logger .info (f"Added results from line profiler to candidates, total candidates now: { self .candidate_len } " )
154+ self .line_profiler_done = True
155+
156+ return self .get_next_candidate ()
157+
158+ def _process_refinement_results (self ) -> OptimizedCandidate | None :
159+ """Process refinement results and add to queue."""
160+ concurrent .futures .wait (self .future_all_refinements )
161+ refinement_response = []
162+
163+ for future_refinement in self .future_all_refinements :
164+ possible_refinement = future_refinement .result ()
165+ if len (possible_refinement ) > 0 :
166+ refinement_response .append (possible_refinement [0 ])
167+
168+ for candidate in refinement_response :
169+ self .candidate_queue .put (candidate )
170+
171+ self .candidate_len += len (refinement_response )
172+ logger .info (
173+ f"Added { len (refinement_response )} candidates from refinement, total candidates now: { self .candidate_len } "
174+ )
175+ self .refinement_done = True
176+
177+ return self .get_next_candidate ()
178+
179+ def is_done (self ) -> bool :
180+ """Check if processing is complete."""
181+ return self .line_profiler_done and self .refinement_done and self .candidate_queue .empty ()
182+
183+
106184class FunctionOptimizer :
107185 def __init__ (
108186 self ,
@@ -372,15 +450,13 @@ def determine_best_candidate(
372450 f"{ self .function_to_optimize .qualified_name } …"
373451 )
374452 console .rule ()
375- candidates = deque (candidates )
376- refinement_done = False
377- line_profiler_done = False
453+
378454 future_all_refinements : list [concurrent .futures .Future ] = []
379455 ast_code_to_id = {}
380456 valid_optimizations = []
381457 optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
382- # Start a new thread for AI service request, start loop in main thread
383- # check if aiservice request is complete, when it is complete, append result to the candidates list
458+
459+ # Start a new thread for AI service request
384460 ai_service_client = self .aiservice_client if exp_type == "EXP0" else self .local_aiservice_client
385461 future_line_profile_results = self .executor .submit (
386462 ai_service_client .optimize_python_code_line_profiler ,
@@ -395,48 +471,23 @@ def determine_best_candidate(
395471 if self .experiment_id
396472 else None ,
397473 )
474+
475+ # Initialize candidate processor
476+ processor = CandidateProcessor (candidates , future_line_profile_results , future_all_refinements )
398477 candidate_index = 0
399- original_len = len (candidates )
400- # TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
401- # TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
402- while True :
403- try :
404- if len (candidates ) > 0 :
405- candidate = candidates .popleft ()
406- else :
407- if not line_profiler_done :
408- logger .debug ("all candidates processed, await candidates from line profiler" )
409- concurrent .futures .wait ([future_line_profile_results ])
410- line_profile_results = future_line_profile_results .result ()
411- candidates .extend (line_profile_results )
412- original_len += len (line_profile_results )
413- logger .info (
414- f"Added results from line profiler to candidates, total candidates now: { original_len } "
415- )
416- line_profiler_done = True
417- continue
418- if line_profiler_done and not refinement_done :
419- concurrent .futures .wait (future_all_refinements )
420- refinement_response = []
421- for future_refinement in future_all_refinements :
422- possible_refinement = future_refinement .result ()
423- if len (possible_refinement ) > 0 : # if the api returns a valid response
424- refinement_response .append (possible_refinement [0 ])
425- candidates .extend (refinement_response )
426- original_len += len (refinement_response )
427- logger .info (
428- f"Added { len (refinement_response )} candidates from refinement, total candidates now: { original_len } "
429- )
430- refinement_done = True
431- continue
432- if line_profiler_done and refinement_done :
433- logger .debug ("everything done, exiting" )
434- break
435478
479+ # Process candidates using queue-based approach
480+ while not processor .is_done ():
481+ candidate = processor .get_next_candidate ()
482+ if candidate is None :
483+ logger .debug ("everything done, exiting" )
484+ break
485+
486+ try :
436487 candidate_index += 1
437488 get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .bin" )).unlink (missing_ok = True )
438489 get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .sqlite" )).unlink (missing_ok = True )
439- logger .info (f"Optimization candidate { candidate_index } /{ original_len } :" )
490+ logger .info (f"Optimization candidate { candidate_index } /{ processor . candidate_len } :" )
440491 code_print (candidate .source_code .flat )
441492 # map ast normalized code to diff len, unnormalized code
442493 # map opt id to the shortest unnormalized code
@@ -461,7 +512,7 @@ def determine_best_candidate(
461512 # check if this code has been evaluated before by checking the ast normalized code string
462513 normalized_code = ast .unparse (ast .parse (candidate .source_code .flat .strip ()))
463514 if normalized_code in ast_code_to_id :
464- logger .warning (
515+ logger .info (
465516 "Current candidate has been encountered before in testing, Skipping optimization candidate."
466517 )
467518 past_opt_id = ast_code_to_id [normalized_code ]["optimization_id" ]
0 commit comments