Skip to content

Commit 51c5e5a

Browse files
authored
Merge pull request #659 from codeflash-ai/candidate-loop-prevent-exit
Prevent candidate loop from exiting early
2 parents ea12bb3 + 05e9ab1 commit 51c5e5a

File tree

2 files changed

+46
-47
lines changed

2 files changed

+46
-47
lines changed

codeflash/api/aiservice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def optimize_python_code_line_profiler( # noqa: D417
202202

203203
if response.status_code == 200:
204204
optimizations_json = response.json()["optimizations"]
205-
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
205+
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
206206
console.rule()
207207
return [
208208
OptimizedCandidate(

codeflash/optimization/function_optimizer.py

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ def determine_best_candidate(
380380
console.rule()
381381
candidates = deque(candidates)
382382
refinement_done = False
383+
line_profiler_done = False
383384
future_all_refinements: list[concurrent.futures.Future] = []
384385
ast_code_to_id = {}
385386
valid_optimizations = []
@@ -400,19 +401,45 @@ def determine_best_candidate(
400401
if self.experiment_id
401402
else None,
402403
)
403-
try:
404-
candidate_index = 0
405-
original_len = len(candidates)
406-
while candidates:
404+
candidate_index = 0
405+
original_len = len(candidates)
406+
# TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
407+
# TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
408+
while True:
409+
try:
410+
if len(candidates) > 0:
411+
candidate = candidates.popleft()
412+
else:
413+
if not line_profiler_done:
414+
logger.debug("all candidates processed, await candidates from line profiler")
415+
concurrent.futures.wait([future_line_profile_results])
416+
line_profile_results = future_line_profile_results.result()
417+
candidates.extend(line_profile_results)
418+
original_len += len(line_profile_results)
419+
logger.info(
420+
f"Added results from line profiler to candidates, total candidates now: {original_len}"
421+
)
422+
line_profiler_done = True
423+
continue
424+
if line_profiler_done and not refinement_done:
425+
concurrent.futures.wait(future_all_refinements)
426+
refinement_response = []
427+
for future_refinement in future_all_refinements:
428+
possible_refinement = future_refinement.result()
429+
if len(possible_refinement) > 0: # if the api returns a valid response
430+
refinement_response.append(possible_refinement[0])
431+
candidates.extend(refinement_response)
432+
original_len += len(refinement_response)
433+
logger.info(
434+
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
435+
)
436+
refinement_done = True
437+
continue
438+
if line_profiler_done and refinement_done:
439+
logger.debug("everything done, exiting")
440+
break
441+
407442
candidate_index += 1
408-
line_profiler_done = True if future_line_profile_results is None else future_line_profile_results.done()
409-
if line_profiler_done and (future_line_profile_results is not None):
410-
line_profile_results = future_line_profile_results.result()
411-
candidates.extend(line_profile_results)
412-
original_len += len(line_profile_results)
413-
logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
414-
future_line_profile_results = None
415-
candidate = candidates.popleft()
416443
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
417444
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
418445
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
@@ -474,7 +501,6 @@ def determine_best_candidate(
474501
file_path_to_helper_classes=file_path_to_helper_classes,
475502
)
476503
console.rule()
477-
478504
if not is_successful(run_results):
479505
optimized_runtimes[candidate.optimization_id] = None
480506
is_correct[candidate.optimization_id] = False
@@ -528,7 +554,6 @@ def determine_best_candidate(
528554
optimized_runtime_ns=candidate_replay_runtime,
529555
)
530556
benchmark_tree.add(f"{benchmark_key}: {replay_perf_gain[benchmark_key] * 100:.1f}%")
531-
532557
best_optimization = BestOptimization(
533558
candidate=candidate,
534559
helper_functions=code_context.helper_functions,
@@ -571,38 +596,12 @@ def determine_best_candidate(
571596
self.write_code_and_helpers(
572597
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
573598
)
574-
575-
if (
576-
(not len(candidates)) and (not line_profiler_done)
577-
): # all original candidates processed but lp results haven't been processed, doesn't matter at the moment if we're done refining or not
578-
concurrent.futures.wait([future_line_profile_results])
579-
line_profile_results = future_line_profile_results.result()
580-
candidates.extend(line_profile_results)
581-
original_len += len(line_profile_results)
582-
logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
583-
future_line_profile_results = None
584-
# all original candidates and lp candidates processed, collect refinement candidates and append to candidate list
585-
if (not len(candidates)) and line_profiler_done and not refinement_done:
586-
# waiting just in case not all calls are finished, nothing else to do
587-
concurrent.futures.wait(future_all_refinements)
588-
refinement_response = []
589-
for future_refinement in future_all_refinements:
590-
possible_refinement = future_refinement.result()
591-
if len(possible_refinement) > 0: # if the api returns a valid response
592-
refinement_response.append(possible_refinement[0])
593-
candidates.extend(refinement_response)
594-
original_len += len(refinement_response)
595-
logger.info(
596-
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
597-
)
598-
refinement_done = True
599-
except KeyboardInterrupt as e:
600-
self.write_code_and_helpers(
601-
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
602-
)
603-
logger.exception(f"Optimization interrupted: {e}")
604-
raise
605-
599+
except KeyboardInterrupt as e:
600+
self.write_code_and_helpers(
601+
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
602+
)
603+
logger.exception(f"Optimization interrupted: {e}")
604+
raise
606605
if not valid_optimizations:
607606
return None
608607
# need to figure out the best candidate here before we return best_optimization

0 commit comments

Comments
 (0)