Skip to content

Commit 7fc1621

Browse files
committed
simplified loop
1 parent 1a9ee7f commit 7fc1621

File tree

2 files changed

+44
-47
lines changed

2 files changed

+44
-47
lines changed

codeflash/api/aiservice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def optimize_python_code_line_profiler( # noqa: D417
202202

203203
if response.status_code == 200:
204204
optimizations_json = response.json()["optimizations"]
205-
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
205+
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
206206
console.rule()
207207
return [
208208
OptimizedCandidate(

codeflash/optimization/function_optimizer.py

Lines changed: 43 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,9 @@ def determine_best_candidate(
380380
console.rule()
381381
candidates = deque(candidates)
382382
refinement_done = False
383+
refinements_added = False
384+
line_profiler_done = False
385+
line_profiler_added = False
383386
future_all_refinements: list[concurrent.futures.Future] = []
384387
ast_code_to_id = {}
385388
valid_optimizations = []
@@ -400,19 +403,41 @@ def determine_best_candidate(
400403
if self.experiment_id
401404
else None,
402405
)
403-
try:
404-
candidate_index = 0
405-
original_len = len(candidates)
406-
while candidates:
406+
candidate_index = 0
407+
original_len = len(candidates)
408+
while True:
409+
try:
410+
if len(candidates)>0:
411+
candidate = candidates.popleft()
412+
else:
413+
if not line_profiler_done:
414+
logger.debug("all candidates processed, await candidates from line profiler")
415+
concurrent.futures.wait([future_line_profile_results])
416+
line_profile_results = future_line_profile_results.result()
417+
candidates.extend(line_profile_results)
418+
original_len += len(line_profile_results)
419+
logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
420+
line_profiler_done = True
421+
continue
422+
if line_profiler_done and not refinement_done:
423+
concurrent.futures.wait(future_all_refinements)
424+
refinement_response = []
425+
for future_refinement in future_all_refinements:
426+
possible_refinement = future_refinement.result()
427+
if len(possible_refinement) > 0: # if the api returns a valid response
428+
refinement_response.append(possible_refinement[0])
429+
candidates.extend(refinement_response)
430+
original_len += len(refinement_response)
431+
logger.info(
432+
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
433+
)
434+
refinement_done = True
435+
continue
436+
if line_profiler_done and refinement_done:
437+
logger.debug("everything done, exiting")
438+
break
439+
407440
candidate_index += 1
408-
line_profiler_done = True if future_line_profile_results is None else future_line_profile_results.done()
409-
if line_profiler_done and (future_line_profile_results is not None):
410-
line_profile_results = future_line_profile_results.result()
411-
candidates.extend(line_profile_results)
412-
original_len += len(line_profile_results)
413-
logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
414-
future_line_profile_results = None
415-
candidate = candidates.popleft()
416441
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
417442
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
418443
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
@@ -474,7 +499,6 @@ def determine_best_candidate(
474499
file_path_to_helper_classes=file_path_to_helper_classes,
475500
)
476501
console.rule()
477-
478502
if not is_successful(run_results):
479503
optimized_runtimes[candidate.optimization_id] = None
480504
is_correct[candidate.optimization_id] = False
@@ -528,7 +552,6 @@ def determine_best_candidate(
528552
optimized_runtime_ns=candidate_replay_runtime,
529553
)
530554
benchmark_tree.add(f"{benchmark_key}: {replay_perf_gain[benchmark_key] * 100:.1f}%")
531-
532555
best_optimization = BestOptimization(
533556
candidate=candidate,
534557
helper_functions=code_context.helper_functions,
@@ -571,38 +594,12 @@ def determine_best_candidate(
571594
self.write_code_and_helpers(
572595
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
573596
)
574-
575-
if (
576-
(not len(candidates)) and (not line_profiler_done)
577-
): # all original candidates processed but lp results haven't been processed, doesn't matter at the moment if we're done refining or not
578-
concurrent.futures.wait([future_line_profile_results])
579-
line_profile_results = future_line_profile_results.result()
580-
candidates.extend(line_profile_results)
581-
original_len += len(line_profile_results)
582-
logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
583-
future_line_profile_results = None
584-
# all original candidates and lp candidates processed, collect refinement candidates and append to candidate list
585-
if (not len(candidates)) and line_profiler_done and not refinement_done:
586-
# waiting just in case not all calls are finished, nothing else to do
587-
concurrent.futures.wait(future_all_refinements)
588-
refinement_response = []
589-
for future_refinement in future_all_refinements:
590-
possible_refinement = future_refinement.result()
591-
if len(possible_refinement) > 0: # if the api returns a valid response
592-
refinement_response.append(possible_refinement[0])
593-
candidates.extend(refinement_response)
594-
original_len += len(refinement_response)
595-
logger.info(
596-
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
597-
)
598-
refinement_done = True
599-
except KeyboardInterrupt as e:
600-
self.write_code_and_helpers(
601-
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
602-
)
603-
logger.exception(f"Optimization interrupted: {e}")
604-
raise
605-
597+
except KeyboardInterrupt as e:
598+
self.write_code_and_helpers(
599+
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
600+
)
601+
logger.exception(f"Optimization interrupted: {e}")
602+
raise
606603
if not valid_optimizations:
607604
return None
608605
# need to figure out the best candidate here before we return best_optimization

0 commit comments

Comments
 (0)