3
3
import ast
4
4
import concurrent .futures
5
5
import os
6
+ import queue
6
7
import random
7
8
import subprocess
8
9
import time
9
10
import uuid
10
- from collections import defaultdict , deque
11
+ from collections import defaultdict
11
12
from pathlib import Path
12
13
from typing import TYPE_CHECKING
13
14
104
105
from codeflash .verification .verification_utils import TestConfig
105
106
106
107
108
class CandidateProcessor:
    """Serve optimization candidates one at a time from a FIFO queue.

    The queue starts out holding the initial candidates. Whenever it runs dry
    the processor refills it in two phases, in this order:

    1. wait for the line-profiler future and enqueue every candidate it returns;
    2. wait for all refinement futures and enqueue the first entry of each
       non-empty result.

    Once both phases have run and the queue is empty there is nothing left to
    serve: ``get_next_candidate`` returns ``None`` and ``is_done`` returns ``True``.

    ``future_all_refinements`` is stored by reference, not copied, so futures
    added to that list before the refinement phase starts are waited on too
    (presumably the caller appends to it while evaluating candidates - confirm
    against the caller).

    Attributes are plain and public: ``candidate_queue``, ``candidate_len``
    (running total of candidates ever enqueued), ``line_profiler_done`` and
    ``refinement_done``.
    """

    def __init__(
        self,
        initial_candidates: list[OptimizedCandidate],
        future_line_profile_results: concurrent.futures.Future,
        future_all_refinements: list[concurrent.futures.Future],
    ) -> None:
        """Seed the queue and remember the pending asynchronous sources.

        Args:
            initial_candidates: Candidates available immediately; enqueued in order.
            future_line_profile_results: Future resolving to an iterable, sized
                collection of additional candidates.
            future_all_refinements: Futures each resolving to a sequence of
                refinement candidates; only the first element of each is used.
        """
        self.candidate_queue = queue.Queue()
        self.line_profiler_done = False
        self.refinement_done = False
        self.candidate_len = len(initial_candidates)

        for candidate in initial_candidates:
            self.candidate_queue.put(candidate)

        self.future_line_profile_results = future_line_profile_results
        self.future_all_refinements = future_all_refinements

    def get_next_candidate(self) -> OptimizedCandidate | None:
        """Return the next queued candidate, refilling from pending futures if needed.

        Blocks while waiting on a pending future. Returns ``None`` only when the
        queue is empty and both refill phases have already been consumed.
        Exceptions raised by the underlying futures propagate to the caller.
        """
        try:
            return self.candidate_queue.get_nowait()
        except queue.Empty:
            return self._handle_empty_queue()

    def _handle_empty_queue(self) -> OptimizedCandidate | None:
        """Run the next outstanding refill phase, or return ``None`` if none remain."""
        if not self.line_profiler_done:
            return self._process_line_profiler_results()
        # Reaching this point already implies the line-profiler phase is finished.
        if not self.refinement_done:
            return self._process_refinement_results()
        return None  # All done

    def _process_line_profiler_results(self) -> OptimizedCandidate | None:
        """Wait for the line-profiler future, enqueue its candidates, then retry."""
        logger.debug("all candidates processed, await candidates from line profiler")
        concurrent.futures.wait([self.future_line_profile_results])
        line_profile_results = self.future_line_profile_results.result()

        for candidate in line_profile_results:
            self.candidate_queue.put(candidate)

        self.candidate_len += len(line_profile_results)
        logger.info(f"Added results from line profiler to candidates, total candidates now: {self.candidate_len}")
        # Mark the phase finished before retrying so an empty result set moves
        # on to the refinement phase instead of waiting on this future again.
        self.line_profiler_done = True

        return self.get_next_candidate()

    def _process_refinement_results(self) -> OptimizedCandidate | None:
        """Wait for every refinement future, enqueue one candidate per result, then retry."""
        concurrent.futures.wait(self.future_all_refinements)

        refinement_response = []
        for future_refinement in self.future_all_refinements:
            possible_refinement = future_refinement.result()
            # Truthiness skips both an empty response and ``None``; ``len(None)``
            # would otherwise abort the whole refinement phase with a TypeError.
            if possible_refinement:
                refinement_response.append(possible_refinement[0])

        for candidate in refinement_response:
            self.candidate_queue.put(candidate)

        self.candidate_len += len(refinement_response)
        logger.info(
            f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
        )
        self.refinement_done = True

        return self.get_next_candidate()

    def is_done(self) -> bool:
        """Return ``True`` once both refill phases have run and the queue is empty."""
        return self.line_profiler_done and self.refinement_done and self.candidate_queue.empty()
183
+
184
+
107
185
class FunctionOptimizer :
108
186
def __init__ (
109
187
self ,
@@ -378,15 +456,13 @@ def determine_best_candidate(
378
456
f"{ self .function_to_optimize .qualified_name } …"
379
457
)
380
458
console .rule ()
381
- candidates = deque (candidates )
382
- refinement_done = False
383
- line_profiler_done = False
459
+
384
460
future_all_refinements : list [concurrent .futures .Future ] = []
385
461
ast_code_to_id = {}
386
462
valid_optimizations = []
387
463
optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
388
- # Start a new thread for AI service request, start loop in main thread
389
- # check if aiservice request is complete, when it is complete, append result to the candidates list
464
+
465
+ # Start a new thread for AI service request
390
466
ai_service_client = self .aiservice_client if exp_type == "EXP0" else self .local_aiservice_client
391
467
future_line_profile_results = self .executor .submit (
392
468
ai_service_client .optimize_python_code_line_profiler ,
@@ -401,48 +477,23 @@ def determine_best_candidate(
401
477
if self .experiment_id
402
478
else None ,
403
479
)
480
+
481
+ # Initialize candidate processor
482
+ processor = CandidateProcessor (candidates , future_line_profile_results , future_all_refinements )
404
483
candidate_index = 0
405
- original_len = len (candidates )
406
- # TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
407
- # TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
408
- while True :
409
- try :
410
- if len (candidates ) > 0 :
411
- candidate = candidates .popleft ()
412
- else :
413
- if not line_profiler_done :
414
- logger .debug ("all candidates processed, await candidates from line profiler" )
415
- concurrent .futures .wait ([future_line_profile_results ])
416
- line_profile_results = future_line_profile_results .result ()
417
- candidates .extend (line_profile_results )
418
- original_len += len (line_profile_results )
419
- logger .info (
420
- f"Added results from line profiler to candidates, total candidates now: { original_len } "
421
- )
422
- line_profiler_done = True
423
- continue
424
- if line_profiler_done and not refinement_done :
425
- concurrent .futures .wait (future_all_refinements )
426
- refinement_response = []
427
- for future_refinement in future_all_refinements :
428
- possible_refinement = future_refinement .result ()
429
- if len (possible_refinement ) > 0 : # if the api returns a valid response
430
- refinement_response .append (possible_refinement [0 ])
431
- candidates .extend (refinement_response )
432
- original_len += len (refinement_response )
433
- logger .info (
434
- f"Added { len (refinement_response )} candidates from refinement, total candidates now: { original_len } "
435
- )
436
- refinement_done = True
437
- continue
438
- if line_profiler_done and refinement_done :
439
- logger .debug ("everything done, exiting" )
440
- break
441
484
485
+ # Process candidates using queue-based approach
486
+ while not processor .is_done ():
487
+ candidate = processor .get_next_candidate ()
488
+ if candidate is None :
489
+ logger .debug ("everything done, exiting" )
490
+ break
491
+
492
+ try :
442
493
candidate_index += 1
443
494
get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .bin" )).unlink (missing_ok = True )
444
495
get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .sqlite" )).unlink (missing_ok = True )
445
- logger .info (f"Optimization candidate { candidate_index } /{ original_len } :" )
496
+ logger .info (f"Optimization candidate { candidate_index } /{ processor . candidate_len } :" )
446
497
code_print (candidate .source_code .flat )
447
498
# map ast normalized code to diff len, unnormalized code
448
499
# map opt id to the shortest unnormalized code
@@ -467,7 +518,7 @@ def determine_best_candidate(
467
518
# check if this code has been evaluated before by checking the ast normalized code string
468
519
normalized_code = ast .unparse (ast .parse (candidate .source_code .flat .strip ()))
469
520
if normalized_code in ast_code_to_id :
470
- logger .warning (
521
+ logger .info (
471
522
"Current candidate has been encountered before in testing, Skipping optimization candidate."
472
523
)
473
524
past_opt_id = ast_code_to_id [normalized_code ]["optimization_id" ]
@@ -1300,6 +1351,7 @@ def process_review(
1300
1351
return
1301
1352
1302
1353
def revert_code_and_helpers (self , original_helper_code : dict [Path , str ]) -> None :
1354
+ logger .info ("Reverting code and helpers..." )
1303
1355
self .write_code_and_helpers (
1304
1356
self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
1305
1357
)
0 commit comments