1212from codeflash .cli_cmds .console import console , logger
1313from codeflash .code_utils .env_utils import get_codeflash_api_key , is_LSP_enabled
1414from codeflash .code_utils .git_utils import get_last_commit_author_if_pr_exists , get_repo_owner_and_name
15- from codeflash .models .models import OptimizedCandidate
15+ from codeflash .models .ExperimentMetadata import ExperimentMetadata
16+ from codeflash .models .models import AIServiceRefinerRequest , OptimizedCandidate
1617from codeflash .telemetry .posthog_cf import ph
1718from codeflash .version import __version__ as codeflash_version
1819
2122
2223 from codeflash .discovery .functions_to_optimize import FunctionToOptimize
2324 from codeflash .models .ExperimentMetadata import ExperimentMetadata
25+ from codeflash .models .models import AIServiceRefinerRequest
2426
2527
2628class AiServiceClient :
@@ -36,7 +38,11 @@ def get_aiservice_base_url(self) -> str:
3638 return "https://app.codeflash.ai"
3739
3840 def make_ai_service_request (
39- self , endpoint : str , method : str = "POST" , payload : dict [str , Any ] | None = None , timeout : float | None = None
41+ self ,
42+ endpoint : str ,
43+ method : str = "POST" ,
44+ payload : dict [str , Any ] | list [dict [str , Any ]] | None = None ,
45+ timeout : float | None = None ,
4046 ) -> requests .Response :
4147 """Make an API request to the given endpoint on the AI service.
4248
@@ -98,11 +104,7 @@ def optimize_python_code( # noqa: D417
98104
99105 """
100106 start_time = time .perf_counter ()
101- try :
102- git_repo_owner , git_repo_name = get_repo_owner_and_name ()
103- except Exception as e :
104- logger .warning (f"Could not determine repo owner and name: { e } " )
105- git_repo_owner , git_repo_name = None , None
107+ git_repo_owner , git_repo_name = safe_get_repo_owner_and_name ()
106108
107109 payload = {
108110 "source_code" : source_code ,
@@ -219,13 +221,145 @@ def optimize_python_code_line_profiler( # noqa: D417
219221 console .rule ()
220222 return []
221223
224+ def optimize_python_code_refinement (self , request : list [AIServiceRefinerRequest ]) -> list [OptimizedCandidate ]:
225+ """Optimize the given python code for performance by making a request to the Django endpoint.
226+
227+ Args:
228+ request: A list of optimization candidate details for refinement
229+
230+ Returns:
231+ -------
232+ - List[OptimizationCandidate]: A list of Optimization Candidates.
233+
234+ """
235+ payload = [
236+ {
237+ "optimization_id" : opt .optimization_id ,
238+ "original_source_code" : opt .original_source_code ,
239+ "read_only_dependency_code" : opt .read_only_dependency_code ,
240+ "original_line_profiler_results" : opt .original_line_profiler_results ,
241+ "original_code_runtime" : opt .original_code_runtime ,
242+ "optimized_source_code" : opt .optimized_source_code ,
243+ "optimized_explanation" : opt .optimized_explanation ,
244+ "optimized_line_profiler_results" : opt .optimized_line_profiler_results ,
245+ "optimized_code_runtime" : opt .optimized_code_runtime ,
246+ "speedup" : opt .speedup ,
247+ "trace_id" : opt .trace_id ,
248+ }
249+ for opt in request
250+ ]
251+ logger .info (f"Refining { len (request )} optimizations…" )
252+ console .rule ()
253+ try :
254+ response = self .make_ai_service_request ("/refinement" , payload = payload , timeout = 600 )
255+ except requests .exceptions .RequestException as e :
256+ logger .exception (f"Error generating optimization refinements: { e } " )
257+ ph ("cli-optimize-error-caught" , {"error" : str (e )})
258+ return []
259+
260+ if response .status_code == 200 :
261+ refined_optimizations = response .json ()["refinements" ]
262+ logger .info (f"Generated { len (refined_optimizations )} candidate refinements." )
263+ console .rule ()
264+ return [
265+ OptimizedCandidate (
266+ source_code = opt ["source_code" ],
267+ explanation = opt ["explanation" ],
268+ optimization_id = opt ["optimization_id" ][:- 4 ] + "refi" ,
269+ )
270+ for opt in refined_optimizations
271+ ]
272+ try :
273+ error = response .json ()["error" ]
274+ except Exception :
275+ error = response .text
276+ logger .error (f"Error generating optimized candidates: { response .status_code } - { error } " )
277+ ph ("cli-optimize-error-response" , {"response_status_code" : response .status_code , "error" : error })
278+ console .rule ()
279+ return []
280+
281+ def get_new_explanation ( # noqa: D417
282+ self ,
283+ source_code : str ,
284+ optimized_code : str ,
285+ dependency_code : str ,
286+ trace_id : str ,
287+ original_line_profiler_results : str ,
288+ optimized_line_profiler_results : str ,
289+ original_code_runtime : str ,
290+ optimized_code_runtime : str ,
291+ speedup : str ,
292+ annotated_tests : str ,
293+ optimization_id : str ,
294+ original_explanation : str ,
295+ ) -> str :
296+ """Optimize the given python code for performance by making a request to the Django endpoint.
297+
298+ Parameters
299+ ----------
300+ - source_code (str): The python code to optimize.
301+ - optimized_code (str): The python code generated by the AI service.
302+ - dependency_code (str): The dependency code used as read-only context for the optimization
303+ - original_line_profiler_results: str - line profiler results for the baseline code
304+ - optimized_line_profiler_results: str - line profiler results for the optimized code
305+ - original_code_runtime: str - runtime for the baseline code
306+ - optimized_code_runtime: str - runtime for the optimized code
307+ - speedup: str - speedup of the optimized code
308+ - annotated_tests: str - test functions annotated with runtime
309+ - optimization_id: str - unique id of opt candidate
310+ - original_explanation: str - original_explanation generated for the opt candidate
311+
312+ Returns
313+ -------
314+ - List[OptimizationCandidate]: A list of Optimization Candidates.
315+
316+ """
317+ payload = {
318+ "trace_id" : trace_id ,
319+ "source_code" : source_code ,
320+ "optimized_code" : optimized_code ,
321+ "original_line_profiler_results" : original_line_profiler_results ,
322+ "optimized_line_profiler_results" : optimized_line_profiler_results ,
323+ "original_code_runtime" : original_code_runtime ,
324+ "optimized_code_runtime" : optimized_code_runtime ,
325+ "speedup" : speedup ,
326+ "annotated_tests" : annotated_tests ,
327+ "optimization_id" : optimization_id ,
328+ "original_explanation" : original_explanation ,
329+ "dependency_code" : dependency_code ,
330+ }
331+ logger .info ("Generating explanation" )
332+ console .rule ()
333+ try :
334+ response = self .make_ai_service_request ("/explain" , payload = payload , timeout = 60 )
335+ except requests .exceptions .RequestException as e :
336+ logger .exception (f"Error generating explanations: { e } " )
337+ ph ("cli-optimize-error-caught" , {"error" : str (e )})
338+ return ""
339+
340+ if response .status_code == 200 :
341+ explanation : str = response .json ()["explanation" ]
342+ logger .debug (f"New Explanation: { explanation } " )
343+ console .rule ()
344+ return explanation
345+ try :
346+ error = response .json ()["error" ]
347+ except Exception :
348+ error = response .text
349+ logger .error (f"Error generating optimized candidates: { response .status_code } - { error } " )
350+ ph ("cli-optimize-error-response" , {"response_status_code" : response .status_code , "error" : error })
351+ console .rule ()
352+ return ""
353+
222354 def log_results ( # noqa: D417
223355 self ,
224356 function_trace_id : str ,
225357 speedup_ratio : dict [str , float | None ] | None ,
226358 original_runtime : float | None ,
227359 optimized_runtime : dict [str , float | None ] | None ,
228360 is_correct : dict [str , bool ] | None ,
361+ optimized_line_profiler_results : dict [str , str ] | None ,
362+ metadata : dict [str , Any ] | None ,
229363 ) -> None :
230364 """Log features to the database.
231365
@@ -236,6 +370,8 @@ def log_results( # noqa: D417
236370 - original_runtime (Optional[Dict[str, float]]): The original runtime.
237371 - optimized_runtime (Optional[Dict[str, float]]): The optimized runtime.
238372 - is_correct (Optional[Dict[str, bool]]): Whether the optimized code is correct.
373+ - optimized_line_profiler_results: line_profiler results for every candidate mapped to their optimization_id
374+ - metadata: contains the best optimization id
239375
240376 """
241377 payload = {
@@ -245,6 +381,8 @@ def log_results( # noqa: D417
245381 "optimized_runtime" : optimized_runtime ,
246382 "is_correct" : is_correct ,
247383 "codeflash_version" : codeflash_version ,
384+ "optimized_line_profiler_results" : optimized_line_profiler_results ,
385+ "metadata" : metadata ,
248386 }
249387 try :
250388 self .make_ai_service_request ("/log_features" , payload = payload , timeout = 5 )
@@ -331,3 +469,12 @@ class LocalAiServiceClient(AiServiceClient):
331469 def get_aiservice_base_url (self ) -> str :
332470 """Get the base URL for the local AI service."""
333471 return "http://localhost:8000"
472+
473+
474+ def safe_get_repo_owner_and_name () -> tuple [str | None , str | None ]:
475+ try :
476+ git_repo_owner , git_repo_name = get_repo_owner_and_name ()
477+ except Exception as e :
478+ logger .warning (f"Could not determine repo owner and name: { e } " )
479+ git_repo_owner , git_repo_name = None , None
480+ return git_repo_owner , git_repo_name
0 commit comments