diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 1a46f91c5..176d928c5 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -202,7 +202,7 @@ def optimize_python_code_line_profiler( # noqa: D417 if response.status_code == 200: optimizations_json = response.json()["optimizations"] - logger.info(f"Generated {len(optimizations_json)} candidate optimizations.") + logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.") console.rule() return [ OptimizedCandidate( diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 610569aa3..320fc1d77 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -380,6 +380,7 @@ def determine_best_candidate( console.rule() candidates = deque(candidates) refinement_done = False + line_profiler_done = False future_all_refinements: list[concurrent.futures.Future] = [] ast_code_to_id = {} valid_optimizations = [] @@ -400,19 +401,45 @@ def determine_best_candidate( if self.experiment_id else None, ) - try: - candidate_index = 0 - original_len = len(candidates) - while candidates: + candidate_index = 0 + original_len = len(candidates) + # TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls, + # TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably + while True: + try: + if len(candidates) > 0: + candidate = candidates.popleft() + else: + if not line_profiler_done: + logger.debug("all candidates processed, await candidates from line profiler") + concurrent.futures.wait([future_line_profile_results]) + line_profile_results = future_line_profile_results.result() + candidates.extend(line_profile_results) + original_len += len(line_profile_results) + logger.info( + f"Added results from line profiler to candidates, total candidates now: {original_len}" + ) + line_profiler_done = True + continue + if line_profiler_done and not refinement_done: + concurrent.futures.wait(future_all_refinements) + refinement_response = [] + for future_refinement in future_all_refinements: + possible_refinement = future_refinement.result() + if len(possible_refinement) > 0: # if the api returns a valid response + refinement_response.append(possible_refinement[0]) + candidates.extend(refinement_response) + original_len += len(refinement_response) + logger.info( + f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}" + ) + refinement_done = True + continue + if line_profiler_done and refinement_done: + logger.debug("everything done, exiting") + break + candidate_index += 1 - line_profiler_done = True if future_line_profile_results is None else future_line_profile_results.done() - if line_profiler_done and (future_line_profile_results is not None): - line_profile_results = future_line_profile_results.result() - candidates.extend(line_profile_results) - original_len += len(line_profile_results) - logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}") - future_line_profile_results = None - candidate = candidates.popleft() get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) logger.info(f"Optimization candidate {candidate_index}/{original_len}:") @@ -474,7 +501,6 @@ def determine_best_candidate( file_path_to_helper_classes=file_path_to_helper_classes, ) console.rule() - if not is_successful(run_results): optimized_runtimes[candidate.optimization_id] = None is_correct[candidate.optimization_id] = False @@ -528,7 +554,6 @@ def determine_best_candidate( optimized_runtime_ns=candidate_replay_runtime, ) benchmark_tree.add(f"{benchmark_key}: {replay_perf_gain[benchmark_key] * 100:.1f}%") - best_optimization = BestOptimization( candidate=candidate, helper_functions=code_context.helper_functions, @@ -571,38 +596,12 @@ def determine_best_candidate( self.write_code_and_helpers( self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path ) - - if ( - (not len(candidates)) and (not line_profiler_done) - ): # all original candidates processed but lp results haven't been processed, doesn't matter at the moment if we're done refining or not - concurrent.futures.wait([future_line_profile_results]) - line_profile_results = future_line_profile_results.result() - candidates.extend(line_profile_results) - original_len += len(line_profile_results) - logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}") - future_line_profile_results = None - # all original candidates and lp candidates processed, collect refinement candidates and append to candidate list - if (not len(candidates)) and line_profiler_done and not refinement_done: - # waiting just in case not all calls are finished, nothing else to do - concurrent.futures.wait(future_all_refinements) - refinement_response = [] - for future_refinement in future_all_refinements: - possible_refinement = future_refinement.result() - if len(possible_refinement) > 0: # if the api returns a valid response - refinement_response.append(possible_refinement[0]) - candidates.extend(refinement_response) - original_len += len(refinement_response) - logger.info( - f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}" - ) - refinement_done = True - except KeyboardInterrupt as e: - self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path - ) - logger.exception(f"Optimization interrupted: {e}") - raise - + except KeyboardInterrupt as e: + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + logger.exception(f"Optimization interrupted: {e}") + raise if not valid_optimizations: return None # need to figure out the best candidate here before we return best_optimization