Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def optimize_python_code_line_profiler( # noqa: D417

if response.status_code == 200:
optimizations_json = response.json()["optimizations"]
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
console.rule()
return [
OptimizedCandidate(
Expand Down
89 changes: 43 additions & 46 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def determine_best_candidate(
console.rule()
candidates = deque(candidates)
refinement_done = False
line_profiler_done = False
future_all_refinements: list[concurrent.futures.Future] = []
ast_code_to_id = {}
valid_optimizations = []
Expand All @@ -400,19 +401,43 @@ def determine_best_candidate(
if self.experiment_id
else None,
)
try:
candidate_index = 0
original_len = len(candidates)
while candidates:
candidate_index = 0
original_len = len(candidates)
while True:
try:
if len(candidates) > 0:
candidate = candidates.popleft()
else:
if not line_profiler_done:
logger.debug("all candidates processed, await candidates from line profiler")
concurrent.futures.wait([future_line_profile_results])
line_profile_results = future_line_profile_results.result()
candidates.extend(line_profile_results)
original_len += len(line_profile_results)
logger.info(
f"Added results from line profiler to candidates, total candidates now: {original_len}"
)
line_profiler_done = True
continue
if line_profiler_done and not refinement_done:
concurrent.futures.wait(future_all_refinements)
refinement_response = []
for future_refinement in future_all_refinements:
possible_refinement = future_refinement.result()
if len(possible_refinement) > 0: # if the api returns a valid response
refinement_response.append(possible_refinement[0])
candidates.extend(refinement_response)
original_len += len(refinement_response)
logger.info(
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
)
refinement_done = True
continue
if line_profiler_done and refinement_done:
logger.debug("everything done, exiting")
break

candidate_index += 1
line_profiler_done = True if future_line_profile_results is None else future_line_profile_results.done()
if line_profiler_done and (future_line_profile_results is not None):
line_profile_results = future_line_profile_results.result()
candidates.extend(line_profile_results)
original_len += len(line_profile_results)
logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
future_line_profile_results = None
candidate = candidates.popleft()
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
Expand Down Expand Up @@ -474,7 +499,6 @@ def determine_best_candidate(
file_path_to_helper_classes=file_path_to_helper_classes,
)
console.rule()

if not is_successful(run_results):
optimized_runtimes[candidate.optimization_id] = None
is_correct[candidate.optimization_id] = False
Expand Down Expand Up @@ -528,7 +552,6 @@ def determine_best_candidate(
optimized_runtime_ns=candidate_replay_runtime,
)
benchmark_tree.add(f"{benchmark_key}: {replay_perf_gain[benchmark_key] * 100:.1f}%")

best_optimization = BestOptimization(
candidate=candidate,
helper_functions=code_context.helper_functions,
Expand Down Expand Up @@ -571,38 +594,12 @@ def determine_best_candidate(
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)

if (
(not len(candidates)) and (not line_profiler_done)
): # all original candidates processed but lp results haven't been processed, doesn't matter at the moment if we're done refining or not
concurrent.futures.wait([future_line_profile_results])
line_profile_results = future_line_profile_results.result()
candidates.extend(line_profile_results)
original_len += len(line_profile_results)
logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
future_line_profile_results = None
# all original candidates and lp candidates processed, collect refinement candidates and append to candidate list
if (not len(candidates)) and line_profiler_done and not refinement_done:
# waiting just in case not all calls are finished, nothing else to do
concurrent.futures.wait(future_all_refinements)
refinement_response = []
for future_refinement in future_all_refinements:
possible_refinement = future_refinement.result()
if len(possible_refinement) > 0: # if the api returns a valid response
refinement_response.append(possible_refinement[0])
candidates.extend(refinement_response)
original_len += len(refinement_response)
logger.info(
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
)
refinement_done = True
except KeyboardInterrupt as e:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
logger.exception(f"Optimization interrupted: {e}")
raise

except KeyboardInterrupt as e:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
logger.exception(f"Optimization interrupted: {e}")
raise
if not valid_optimizations:
return None
# need to figure out the best candidate here before we return best_optimization
Expand Down
Loading