 from __future__ import annotations

+import concurrent.futures
 import json
 import os
 import platform
 from codeflash.cli_cmds.console import console, logger
 from codeflash.code_utils.code_replacer import is_zero_diff
 from codeflash.code_utils.code_utils import unified_diff_strings
-from codeflash.code_utils.config_consts import N_CANDIDATES_EFFECTIVE, N_CANDIDATES_LP_EFFECTIVE
 from codeflash.code_utils.env_utils import get_codeflash_api_key
 from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
 from codeflash.code_utils.time_utils import humanize_runtime
     from codeflash.models.models import AIServiceCodeRepairRequest, AIServiceRefinerRequest
     from codeflash.result.explanation import Explanation

+multi_model_executor = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="multi_model")
+

 class AiServiceClient:
     def __init__(self) -> None:
@@ -92,7 +94,7 @@ def make_ai_service_request(
         return response

     def _get_valid_candidates(
-        self, optimizations_json: list[dict[str, Any]], source: OptimizedCandidateSource
+        self, optimizations_json: list[dict[str, Any]], source: OptimizedCandidateSource, model: str | None = None
     ) -> list[OptimizedCandidate]:
         candidates: list[OptimizedCandidate] = []
         for opt in optimizations_json:
@@ -106,6 +108,7 @@ def _get_valid_candidates(
                     optimization_id=opt["optimization_id"],
                     source=source,
                     parent_id=opt.get("parent_id", None),
+                    model=model,
                 )
             )
         return candidates
@@ -119,6 +122,7 @@ def optimize_python_code( # noqa: D417
         experiment_metadata: ExperimentMetadata | None = None,
         *,
         is_async: bool = False,
+        model: str | None = None,
     ) -> list[OptimizedCandidate]:
         """Optimize the given python code for performance by making a request to the Django endpoint.

@@ -129,6 +133,7 @@ def optimize_python_code( # noqa: D417
         - trace_id (str): Trace id of optimization run
         - num_candidates (int): Number of optimization variants to generate. Default is 10.
         - experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization
+        - model (str | None): Model name to use ("gpt-4.1" or "claude-sonnet-4-5"). Default is None (server default).

         Returns
         -------
@@ -149,8 +154,9 @@ def optimize_python_code( # noqa: D417
149154 "current_username" : get_last_commit_author_if_pr_exists (None ),
150155 "repo_owner" : git_repo_owner ,
151156 "repo_name" : git_repo_name ,
152- "n_candidates" : N_CANDIDATES_EFFECTIVE ,
157+ "n_candidates" : num_candidates ,
153158 "is_async" : is_async ,
159+ "model" : model ,
154160 }
155161
156162 logger .info ("!lsp|Generating optimized candidates…" )
@@ -167,7 +173,7 @@ def optimize_python_code( # noqa: D417
             console.rule()
             end_time = time.perf_counter()
             logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.")
-            return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE)
+            return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE, model=model)
         try:
             error = response.json()["error"]
         except Exception:
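For context on the new parameter: model is forwarded in the request payload and stamped onto each returned candidate. A minimal caller-side sketch, following the positional order shown in the docstring above; the client construction and argument values are illustrative, not part of this diff:

# Hypothetical usage sketch (values are placeholders, not from this PR):
client = AiServiceClient()
candidates = client.optimize_python_code(
    source_code,                # str: source of the function to optimize
    dependency_code,            # str: read-only dependency context
    trace_id,                   # str: trace id of this optimization run
    num_candidates=3,           # now sent directly as "n_candidates" in the payload
    model="claude-sonnet-4-5",  # or "gpt-4.1"; None keeps the server-side default
)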
@@ -185,6 +191,7 @@ def optimize_python_code_line_profiler( # noqa: D417
         line_profiler_results: str,
         num_candidates: int = 10,
         experiment_metadata: ExperimentMetadata | None = None,
+        model: str | None = None,
     ) -> list[OptimizedCandidate]:
         """Optimize the given python code for performance by making a request to the Django endpoint.

@@ -195,6 +202,7 @@ def optimize_python_code_line_profiler( # noqa: D417
         - trace_id (str): Trace id of optimization run
         - num_candidates (int): Number of optimization variants to generate. Default is 10.
         - experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization
+        - model (str | None): Model name to use ("gpt-4.1" or "claude-sonnet-4-5"). Default is None (server default).

         Returns
         -------
@@ -211,7 +219,8 @@ def optimize_python_code_line_profiler( # noqa: D417
211219 "experiment_metadata" : experiment_metadata ,
212220 "codeflash_version" : codeflash_version ,
213221 "lsp_mode" : is_LSP_enabled (),
214- "n_candidates_lp" : N_CANDIDATES_LP_EFFECTIVE ,
222+ "n_candidates_lp" : num_candidates ,
223+ "model" : model ,
215224 }
216225
217226 console .rule ()
@@ -232,7 +241,7 @@ def optimize_python_code_line_profiler( # noqa: D417
232241 f"!lsp|Generated { len (optimizations_json )} candidate optimizations using line profiler information."
233242 )
234243 console .rule ()
235- return self ._get_valid_candidates (optimizations_json , OptimizedCandidateSource .OPTIMIZE_LP )
244+ return self ._get_valid_candidates (optimizations_json , OptimizedCandidateSource .OPTIMIZE_LP , model = model )
236245 try :
237246 error = response .json ()["error" ]
238247 except Exception :
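The line-profiler path mirrors this: num_candidates now feeds "n_candidates_lp" in place of the removed config constant, and the model name rides in the same payload. A hedged sketch of the call shape, reusing the client from the earlier sketch, with placeholder values:

# Hypothetical usage sketch for the line-profiler variant:
lp_candidates = client.optimize_python_code_line_profiler(
    source_code,
    dependency_code,
    trace_id,
    line_profiler_results,  # textual line_profiler output for the target function
    num_candidates=1,
    model="gpt-4.1",
)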
@@ -242,6 +251,95 @@ def optimize_python_code_line_profiler( # noqa: D417
         console.rule()
         return []

+    def optimize_python_code_multi_model(
+        self,
+        source_code: str,
+        dependency_code: str,
+        base_trace_id: str,
+        model_distribution: list[tuple[str, int]],
+        experiment_metadata: ExperimentMetadata | None = None,
+        *,
+        is_async: bool = False,
+    ) -> list[OptimizedCandidate]:
+        """Generate optimizations using multiple models in parallel."""
+        futures: list[tuple[concurrent.futures.Future[list[OptimizedCandidate]], str]] = []
+        call_index = 0
+
+        for model_name, num_calls in model_distribution:
+            for _ in range(num_calls):
+                call_trace_id = f"{base_trace_id[:-4]}M{call_index:02d}"
+                call_index += 1
+
+                future = multi_model_executor.submit(
+                    self.optimize_python_code,
+                    source_code,
+                    dependency_code,
+                    call_trace_id,
+                    num_candidates=1,  # Each call returns 1 candidate
+                    experiment_metadata=experiment_metadata,
+                    is_async=is_async,
+                    model=model_name,
+                )
+                futures.append((future, model_name))
+
+        # Wait for all calls to complete
+        concurrent.futures.wait([f for f, _ in futures])
+
+        # Collect results
+        all_candidates: list[OptimizedCandidate] = []
+        for future, model_name in futures:
+            try:
+                candidates = future.result()
+                all_candidates.extend(candidates)
+            except Exception as e:
+                logger.warning(f"Model {model_name} call failed: {e}")
+                continue
+
+        return all_candidates
+
+    def optimize_python_code_line_profiler_multi_model(
+        self,
+        source_code: str,
+        dependency_code: str,
+        base_trace_id: str,
+        line_profiler_results: str,
+        model_distribution: list[tuple[str, int]],
+        experiment_metadata: ExperimentMetadata | None = None,
+    ) -> list[OptimizedCandidate]:
+        """Generate line profiler optimizations using multiple models in parallel."""
+        futures: list[tuple[concurrent.futures.Future[list[OptimizedCandidate]], str]] = []
+        call_index = 0
+
+        for model_name, num_calls in model_distribution:
+            for _ in range(num_calls):
+                call_trace_id = f"{base_trace_id[:-4]}L{call_index:02d}"
+                call_index += 1
+
+                future = multi_model_executor.submit(
+                    self.optimize_python_code_line_profiler,
+                    source_code,
+                    dependency_code,
+                    call_trace_id,
+                    line_profiler_results,
+                    num_candidates=1,
+                    experiment_metadata=experiment_metadata,
+                    model=model_name,
+                )
+                futures.append((future, model_name))
+
+        concurrent.futures.wait([f for f, _ in futures])
+
+        all_candidates: list[OptimizedCandidate] = []
+        for future, model_name in futures:
+            try:
+                candidates = future.result()
+                all_candidates.extend(candidates)
+            except Exception as e:
+                logger.warning(f"Line profiler model {model_name} call failed: {e}")
+                continue
+
+        return all_candidates
+
     def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
         """Optimize the given python code for performance by making a request to the Django endpoint.

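Taken together, the two new *_multi_model helpers fan requests out over the shared 10-worker multi_model_executor, one request per (model, call) pair, and derive a per-call trace id by replacing the trailing four characters of base_trace_id with an M<index> or L<index> suffix; failed calls are logged and skipped rather than raised. A usage sketch under the assumption that an upstream caller decides the per-model split; the distribution and variable names are illustrative, not from this PR:

# Hypothetical usage sketch: split four generation calls across two models.
distribution = [("claude-sonnet-4-5", 2), ("gpt-4.1", 2)]
candidates = client.optimize_python_code_multi_model(
    source_code,
    dependency_code,
    base_trace_id,                    # trailing 4 chars are rewritten to "M00", "M01", ... per call
    model_distribution=distribution,
    is_async=False,
)
# Failures are swallowed with a warning, so len(candidates) may be smaller than the total call count.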