1212from codeflash .cli_cmds .console import console , logger
1313from codeflash .code_utils .env_utils import get_codeflash_api_key , is_LSP_enabled
1414from codeflash .code_utils .git_utils import get_last_commit_author_if_pr_exists , get_repo_owner_and_name
15- from codeflash .models .models import OptimizedCandidate
15+ from codeflash .models .ExperimentMetadata import ExperimentMetadata
16+ from codeflash .models .models import AIServiceRefinerRequest , OptimizedCandidate
1617from codeflash .telemetry .posthog_cf import ph
1718from codeflash .version import __version__ as codeflash_version
1819
2122
2223 from codeflash .discovery .functions_to_optimize import FunctionToOptimize
2324 from codeflash .models .ExperimentMetadata import ExperimentMetadata
25+ from codeflash .models .models import AIServiceRefinerRequest
2426
2527
2628class AiServiceClient :
@@ -36,7 +38,11 @@ def get_aiservice_base_url(self) -> str:
3638 return "https://app.codeflash.ai"
3739
3840 def make_ai_service_request (
39- self , endpoint : str , method : str = "POST" , payload : dict [str , Any ] | None = None , timeout : float | None = None
41+ self ,
42+ endpoint : str ,
43+ method : str = "POST" ,
44+ payload : dict [str , Any ] | list [dict [str , Any ]] | None = None ,
45+ timeout : float | None = None ,
4046 ) -> requests .Response :
4147 """Make an API request to the given endpoint on the AI service.
4248
@@ -98,11 +104,7 @@ def optimize_python_code( # noqa: D417
98104
99105 """
100106 start_time = time .perf_counter ()
101- try :
102- git_repo_owner , git_repo_name = get_repo_owner_and_name ()
103- except Exception as e :
104- logger .warning (f"Could not determine repo owner and name: { e } " )
105- git_repo_owner , git_repo_name = None , None
107+ git_repo_owner , git_repo_name = safe_get_repo_owner_and_name ()
106108
107109 payload = {
108110 "source_code" : source_code ,
@@ -219,13 +221,72 @@ def optimize_python_code_line_profiler( # noqa: D417
219221 console .rule ()
220222 return []
221223
224+ def optimize_python_code_refinement (self , request : list [AIServiceRefinerRequest ]) -> list [OptimizedCandidate ]:
225+ """Optimize the given python code for performance by making a request to the Django endpoint.
226+
227+ Args:
228+ request: A list of optimization candidate details for refinement
229+
230+ Returns:
231+ -------
232+ - List[OptimizationCandidate]: A list of Optimization Candidates.
233+
234+ """
235+ payload = [
236+ {
237+ "optimization_id" : opt .optimization_id ,
238+ "original_source_code" : opt .original_source_code ,
239+ "read_only_dependency_code" : opt .read_only_dependency_code ,
240+ "original_line_profiler_results" : opt .original_line_profiler_results ,
241+ "original_code_runtime" : opt .original_code_runtime ,
242+ "optimized_source_code" : opt .optimized_source_code ,
243+ "optimized_explanation" : opt .optimized_explanation ,
244+ "optimized_line_profiler_results" : opt .optimized_line_profiler_results ,
245+ "optimized_code_runtime" : opt .optimized_code_runtime ,
246+ "speedup" : opt .speedup ,
247+ "trace_id" : opt .trace_id ,
248+ }
249+ for opt in request
250+ ]
251+ logger .info (f"Refining { len (request )} optimizations…" )
252+ console .rule ()
253+ try :
254+ response = self .make_ai_service_request ("/refinement" , payload = payload , timeout = 600 )
255+ except requests .exceptions .RequestException as e :
256+ logger .exception (f"Error generating optimization refinements: { e } " )
257+ ph ("cli-optimize-error-caught" , {"error" : str (e )})
258+ return []
259+
260+ if response .status_code == 200 :
261+ refined_optimizations = response .json ()["refinements" ]
262+ logger .info (f"Generated { len (refined_optimizations )} candidate refinements." )
263+ console .rule ()
264+ return [
265+ OptimizedCandidate (
266+ source_code = opt ["source_code" ],
267+ explanation = opt ["explanation" ],
268+ optimization_id = opt ["optimization_id" ][:- 4 ] + "refi" ,
269+ )
270+ for opt in refined_optimizations
271+ ]
272+ try :
273+ error = response .json ()["error" ]
274+ except Exception :
275+ error = response .text
276+ logger .error (f"Error generating optimized candidates: { response .status_code } - { error } " )
277+ ph ("cli-optimize-error-response" , {"response_status_code" : response .status_code , "error" : error })
278+ console .rule ()
279+ return []
280+
222281 def log_results ( # noqa: D417
223282 self ,
224283 function_trace_id : str ,
225284 speedup_ratio : dict [str , float | None ] | None ,
226285 original_runtime : float | None ,
227286 optimized_runtime : dict [str , float | None ] | None ,
228287 is_correct : dict [str , bool ] | None ,
288+ optimized_line_profiler_results : dict [str , str ] | None ,
289+ metadata : dict [str , Any ] | None ,
229290 ) -> None :
230291 """Log features to the database.
231292
@@ -236,6 +297,8 @@ def log_results( # noqa: D417
236297 - original_runtime (Optional[Dict[str, float]]): The original runtime.
237298 - optimized_runtime (Optional[Dict[str, float]]): The optimized runtime.
238299 - is_correct (Optional[Dict[str, bool]]): Whether the optimized code is correct.
300+ - optimized_line_profiler_results: line_profiler results for every candidate mapped to their optimization_id
301+ - metadata: contains the best optimization id
239302
240303 """
241304 payload = {
@@ -245,6 +308,8 @@ def log_results( # noqa: D417
245308 "optimized_runtime" : optimized_runtime ,
246309 "is_correct" : is_correct ,
247310 "codeflash_version" : codeflash_version ,
311+ "optimized_line_profiler_results" : optimized_line_profiler_results ,
312+ "metadata" : metadata ,
248313 }
249314 try :
250315 self .make_ai_service_request ("/log_features" , payload = payload , timeout = 5 )
@@ -331,3 +396,12 @@ class LocalAiServiceClient(AiServiceClient):
331396 def get_aiservice_base_url (self ) -> str :
332397 """Get the base URL for the local AI service."""
333398 return "http://localhost:8000"
399+
400+
401+ def safe_get_repo_owner_and_name () -> tuple [str | None , str | None ]:
402+ try :
403+ git_repo_owner , git_repo_name = get_repo_owner_and_name ()
404+ except Exception as e :
405+ logger .warning (f"Could not determine repo owner and name: { e } " )
406+ git_repo_owner , git_repo_name = None , None
407+ return git_repo_owner , git_repo_name
0 commit comments