2 changes: 1 addition & 1 deletion codeflash/api/aiservice.py
@@ -202,7 +202,7 @@ def optimize_python_code_line_profiler( # noqa: D417

if response.status_code == 200:
optimizations_json = response.json()["optimizations"]
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
console.rule()
return [
OptimizedCandidate(
91 changes: 45 additions & 46 deletions codeflash/optimization/function_optimizer.py
@@ -380,6 +380,7 @@ def determine_best_candidate(
console.rule()
candidates = deque(candidates)
refinement_done = False
line_profiler_done = False
future_all_refinements: list[concurrent.futures.Future] = []
ast_code_to_id = {}
valid_optimizations = []
Expand All @@ -400,19 +401,45 @@ def determine_best_candidate(
if self.experiment_id
else None,
)
try:
candidate_index = 0
original_len = len(candidates)
while candidates:
candidate_index = 0
original_len = len(candidates)
# TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
# TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
while True:
try:
if len(candidates) > 0:
candidate = candidates.popleft()
else:
if not line_profiler_done:
logger.debug("all candidates processed, await candidates from line profiler")
concurrent.futures.wait([future_line_profile_results])
line_profile_results = future_line_profile_results.result()
candidates.extend(line_profile_results)
original_len += len(line_profile_results)
logger.info(
f"Added results from line profiler to candidates, total candidates now: {original_len}"
)
line_profiler_done = True
continue
if line_profiler_done and not refinement_done:
concurrent.futures.wait(future_all_refinements)
refinement_response = []
for future_refinement in future_all_refinements:
possible_refinement = future_refinement.result()
if len(possible_refinement) > 0: # if the api returns a valid response
refinement_response.append(possible_refinement[0])
candidates.extend(refinement_response)
original_len += len(refinement_response)
logger.info(
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
)
refinement_done = True
continue
if line_profiler_done and refinement_done:
logger.debug("everything done, exiting")
break

candidate_index += 1
line_profiler_done = True if future_line_profile_results is None else future_line_profile_results.done()
if line_profiler_done and (future_line_profile_results is not None):
line_profile_results = future_line_profile_results.result()
candidates.extend(line_profile_results)
original_len += len(line_profile_results)
logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
future_line_profile_results = None
candidate = candidates.popleft()
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
@@ -474,7 +501,6 @@ def determine_best_candidate(
file_path_to_helper_classes=file_path_to_helper_classes,
)
console.rule()

if not is_successful(run_results):
optimized_runtimes[candidate.optimization_id] = None
is_correct[candidate.optimization_id] = False
@@ -528,7 +554,6 @@ def determine_best_candidate(
optimized_runtime_ns=candidate_replay_runtime,
)
benchmark_tree.add(f"{benchmark_key}: {replay_perf_gain[benchmark_key] * 100:.1f}%")

best_optimization = BestOptimization(
candidate=candidate,
helper_functions=code_context.helper_functions,
@@ -571,38 +596,12 @@ def determine_best_candidate(
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)

if (
(not len(candidates)) and (not line_profiler_done)
): # all original candidates processed but lp results haven't been processed, doesn't matter at the moment if we're done refining or not
concurrent.futures.wait([future_line_profile_results])
line_profile_results = future_line_profile_results.result()
candidates.extend(line_profile_results)
original_len += len(line_profile_results)
logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
future_line_profile_results = None
# all original candidates and lp candidates processed, collect refinement candidates and append to candidate list
if (not len(candidates)) and line_profiler_done and not refinement_done:
# waiting just in case not all calls are finished, nothing else to do
concurrent.futures.wait(future_all_refinements)
refinement_response = []
for future_refinement in future_all_refinements:
possible_refinement = future_refinement.result()
if len(possible_refinement) > 0: # if the api returns a valid response
refinement_response.append(possible_refinement[0])
candidates.extend(refinement_response)
original_len += len(refinement_response)
logger.info(
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
)
refinement_done = True
except KeyboardInterrupt as e:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
logger.exception(f"Optimization interrupted: {e}")
raise

except KeyboardInterrupt as e:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
logger.exception(f"Optimization interrupted: {e}")
raise
if not valid_optimizations:
return None
# need to figure out the best candidate here before we return best_optimization
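For readers following the function_optimizer.py hunks above, here is a minimal standalone sketch of the control flow the rewritten loop introduces: drain the ordinary candidates first, then fold in the line-profiler batch, then the refinement batch, and exit only once both asynchronous sources have been consumed. This is an illustrative simplification, not the actual method; the drain_candidates name, its arguments, and the process callback are invented stand-ins.

import concurrent.futures
from collections import deque


def drain_candidates(initial_candidates, future_line_profile_results, future_all_refinements, process):
    # Illustrative reimplementation of the loop structure in this PR, not the real code.
    candidates = deque(initial_candidates)
    line_profiler_done = future_line_profile_results is None
    refinement_done = False
    while True:
        if candidates:
            process(candidates.popleft())
            continue
        if not line_profiler_done:
            # Queue is empty but the line-profiler call may still be running: block, then re-enter the loop.
            concurrent.futures.wait([future_line_profile_results])
            candidates.extend(future_line_profile_results.result())
            line_profiler_done = True
            continue
        if not refinement_done:
            # Refinement futures are awaited only after the line-profiler batch has been consumed.
            concurrent.futures.wait(future_all_refinements)
            for future_refinement in future_all_refinements:
                possible_refinement = future_refinement.result()
                if possible_refinement:  # keep only non-empty API responses
                    candidates.append(possible_refinement[0])
            refinement_done = True
            continue
        break  # both asynchronous sources drained and the queue is empty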
Expand Down
Loading
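The TODO in the @@ -400 hunk proposes rewriting this loop as a class that owns the candidate container and hides the ordering constraint between the line-profiler and refinement calls. One possible shape for that refactor, purely hypothetical and not part of this PR (the CandidateStream name and its interface are invented for illustration), is sketched below.

import concurrent.futures
from collections import deque


class CandidateStream:
    # Hypothetical container: yields candidates and pulls in late-arriving
    # line-profiler and refinement results itself, so the caller no longer
    # juggles line_profiler_done/refinement_done flags.
    def __init__(self, initial_candidates, line_profile_future, refinement_futures):
        self._queue = deque(initial_candidates)
        self._line_profile_future = line_profile_future
        self._refinement_futures = list(refinement_futures)

    def __iter__(self):
        while True:
            if self._queue:
                yield self._queue.popleft()
            elif self._line_profile_future is not None:
                # Empty queue: wait for the line-profiler candidates first.
                concurrent.futures.wait([self._line_profile_future])
                self._queue.extend(self._line_profile_future.result())
                self._line_profile_future = None
            elif self._refinement_futures:
                # Refinements depend on the line-profiler call having completed.
                concurrent.futures.wait(self._refinement_futures)
                for future in self._refinement_futures:
                    result = future.result()
                    if result:
                        self._queue.append(result[0])
                self._refinement_futures = []
            else:
                return  # nothing queued and no pending futures: iteration ends

The caller would then reduce to a plain for-loop, e.g. for candidate in CandidateStream(candidates, future_line_profile_results, future_all_refinements): ..., with the try/except bookkeeping kept inside the loop body.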