@@ -20,18 +20,22 @@ class CepoConfig:
2020     bestofn_n: int  # number of responses to be generated in best of n stage
2121     bestofn_temperature: float  # temperature for verifier in best of n stage
2222     bestofn_max_tokens: int  # maximum number of tokens for verifier in best of n stage
23-      bestofn_rating_type: Literal["absolute", "pairwise"]  # type of rating in best of n stage
23+      bestofn_rating_type: Literal["absolute", "pairwise", "majority"]  # type of rating in best of n stage
2424     planning_n: int  # number of plans generated in planning stage
2525     planning_m: int  # number of attempts to generate n plans in planning stage
2626     planning_temperature_step1: float  # temperature for generator in step 1 of planning stage
2727     planning_temperature_step2: float  # temperature for generator in step 2 of planning stage
28+      planning_temperature_direct_resp: float  # temperature for the direct answer when planning fails after step 2
2829     planning_temperature_step3: float  # temperature for generator in step 3 of planning stage
2930     planning_temperature_step4: float  # temperature for generator in step 4 of planning stage
3031     planning_max_tokens_step1: int  # maximum number of tokens in step 1 of planning stage
3132     planning_max_tokens_step2: int  # maximum number of tokens in step 2 of planning stage
33+      planning_max_tokens_direct_resp: int  # maximum number of tokens for the direct answer when planning fails after step 2
3234     planning_max_tokens_step3: int  # maximum number of tokens in step 3 of planning stage
3335     planning_max_tokens_step4: int  # maximum number of tokens in step 4 of planning stage
3436     use_plan_diversity: bool  # whether to use plan diversity
37+      use_reasoning_fallback: bool  # whether to fall back to lower reasoning effort when a higher level fails
38+      num_of_retries: int  # number of retries if the LLM call fails; 0 for no retries
3539     rating_model: Optional[str] = None  # model to be used for rating
3640     print_output: bool = False  # whether to print the output of each stage
3741
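For orientation, a quick sketch of how the newly added knobs fit together; the values below are illustrative only, and the full CepoConfig has many more required fields:

```python
# Illustrative values, not project defaults; CepoConfig requires more fields.
fallback_knobs = dict(
    bestofn_rating_type="majority",  # newly accepted alongside "absolute" and "pairwise"
    use_reasoning_fallback=True,     # step down reasoning effort when "high" fails
    num_of_retries=2,                # retries after the initial llm_call attempt; 0 disables retrying
)
```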
@@ -203,6 +207,7 @@ def extract_llm_response(response):
203207 def llm_call(
204208     client: Any,
205209     provider_request: dict,
210+        cepo_config: CepoConfig
206211 ) -> tuple[str, str, int]:
207212     """
208213     Call the LLM with retries on transient errors.
@@ -220,7 +225,7 @@ def llm_call(
220225         - finish_reason: Why generation stopped.
221226         - completion_tokens: Number of tokens generated.
222227     """
223-     retries = 2  # total attempts = retries + 1 initial call
224-     for attempt in range(retries):
228+     retries = cepo_config.num_of_retries  # total attempts = retries + 1 initial call
229+     for attempt in range(retries + 1):
225230         try:
226231             response_object = client.chat.completions.create(
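The hunk above is truncated at the provider call, but the retry pattern it configures is simple. Here is a self-contained sketch under hypothetical names (`call_with_retries` and `backoff_s` are not in the codebase), with a linear backoff added purely for illustration:

```python
import time

def call_with_retries(fn, retries: int = 2, backoff_s: float = 1.0):
    """One initial attempt plus `retries` retries, per the comment above."""
    last_exc = None
    for attempt in range(retries + 1):
        try:
            return fn()  # e.g. a closure over client.chat.completions.create
        except Exception as exc:  # llm_call narrows this to transient errors
            last_exc = exc
            if attempt < retries:
                time.sleep(backoff_s * (attempt + 1))  # linear backoff between tries
    raise last_exc
```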
@@ -247,7 +252,8 @@ def llm_call(
247252 def llm_call_reason_effort_fallback(
248253     client: Any,
249254     provider_request: dict,
250-     reasoning_effort_levels: list
255+     reasoning_effort_levels: list,
256+     cepo_config: CepoConfig
251257 ) -> tuple[Optional[Any], str, int]:
252258 """
253259 Call LLM with fallback on reasoning effort levels.
@@ -291,13 +297,16 @@ def llm_call_reason_effort_fallback(
291297     automatically, but a permanent fix may require upstream changes
292298     (see https://github.com/pydantic/pydantic-ai/issues/2449).
293299     """
300+     if not cepo_config.use_reasoning_fallback:
301+         reasoning_effort_levels = ["high"]
294302     for effort in reasoning_effort_levels:
295303         try:
296304             # Try with the current reasoning effort level
297305             provider_request["reasoning_effort"] = effort
298306             response, finish_reason, completion_tokens = llm_call(
299307                 client=client,
300308                 provider_request=provider_request,
309+                 cepo_config=cepo_config
301310             )
302311             if response is not None and finish_reason != "length":
303312                 return response, finish_reason, completion_tokens
@@ -310,27 +319,6 @@ def llm_call_reason_effort_fallback(
310319     return None, "error", 0
311320
312321
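Stripped of provider details, the effort fallback reduces to a short loop. A sketch, assuming `make_call(effort)` wraps one `llm_call` at the given level:

```python
def try_effort_levels(make_call, levels=("high", "medium", "low")):
    """Return the first usable response; truncation or an error drops a level."""
    for effort in levels:
        try:
            response, finish_reason, tokens = make_call(effort)
            if response is not None and finish_reason != "length":
                return response, finish_reason, tokens
        except Exception:
            continue  # fall through to the next (lower) effort level
    return None, "error", 0
```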
313- def fallback_direct_answer(client, model, question, max_tokens=None, temperature=1.0, top_p=1.0):  # TODO clean-up
314-     messages = [
315-         {"role": "user", "content": question},
316-     ]
317-
318-     response, finish_reason, completion_tokens = llm_call_reason_effort_fallback(
319-         messages=messages,
320-         client=client,
321-         model=model,
322-         max_tokens=max_tokens,
323-         temperature=temperature,
324-         top_p=top_p,
325-         reasoning_effort_levels=["high", "medium", "low"]
326-     )
327-     if response is None or finish_reason == "length":
328-         print("Direct answer failed, empty response or length")
329-         response = ""
330-     messages.append({"role": "assistant", "content": response})
331-     return response, messages
332-
333-
334322 def generate_completion(system_prompt: str, task: str, client: Any, model: str, cepo_config: CepoConfig, approach: Optional[str] = None, request_id: str = None) -> str:
335323     """
336324     Generates a completion based on the provided system prompt and task.
@@ -385,7 +373,8 @@ def generate_single_plan(i):
385373         response, finish_reason, completion_tokens = llm_call_reason_effort_fallback(
386374             client=client,
387375             provider_request=provider_request,
388-             reasoning_effort_levels=["high", "medium"]
376+             reasoning_effort_levels=["high", "medium"],
377+             cepo_config=cepo_config
389378         )
390379         local_completion_tokens += completion_tokens
391380         # Log provider call if conversation logging is enabled
@@ -418,7 +407,8 @@ def generate_single_plan(i):
418407         response, finish_reason, completion_tokens = llm_call_reason_effort_fallback(
419408             client=client,
420409             provider_request=provider_request,
421-             reasoning_effort_levels=["high", "medium"]
410+             reasoning_effort_levels=["high", "medium"],
411+             cepo_config=cepo_config
422412         )
423413         local_completion_tokens += completion_tokens
424414
@@ -453,10 +443,39 @@ def generate_single_plan(i):
453443     plans = [plan for _, plan in sorted(plans)]  # keep original order
454444
455445     if not plans:
456-         # Fallback plan
457-         fallback_generation, fallback_messages = fallback_direct_answer(client, model, question_only)
458-         plans.append(fallback_generation)
459-         cb_log[f"messages_planning_fallback_used"] = fallback_messages
446+         # If no plans were generated, attempt to answer directly
447+         messages = [
448+             {"role": "user", "content": question_only},
449+         ]
450+
451+         provider_request = {
452+             "model": model,
453+             "messages": messages,
454+ "max_tokens" : cepo_config .planning_max_tokens_step2_direct ,
455+ "temperature" :cepo_config .planning_temperature_step2_direct ,
456+ "top_p" : 0.95 ,
457+ "reasoning_effort_levels" : ["high" , "medium" , "low" ]
458+ }
459+
460+ response , finish_reason , completion_tokens = llm_call_reason_effort_fallback (
461+ client = client ,
462+ provider_request = provider_request ,
463+ cepo_config = cepo_config
464+ )
465+         local_completion_tokens += completion_tokens
466+
467+         # Log provider call if conversation logging is enabled
468+         if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
469+             response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
470+             optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
471+
472+         if response is None or finish_reason == "length":
473+             print("Direct answer failed, empty response or length")
474+             response = ""
475+         messages.append({"role": "assistant", "content": response})
476+
477+         plans.append(response)
478+         cb_log["messages_planning_fallback_used"] = messages
460479         if cepo_config.print_output:
461480             print(f"\nCePO: No plans generated successfully. Taking the fallback.\n")
462481
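The conversation-logging guard in this hunk repeats at every call site. The same check factored into a helper, assuming (as the code above implies) that `optillm` exposes an optional `conversation_logger` attribute:

```python
import optillm  # assumed importable, as in cepo.py

def log_provider_call_safely(request_id, provider_request, response):
    """No-op unless the optional conversation logger is configured."""
    logger = getattr(optillm, "conversation_logger", None)
    if logger and request_id:
        response_dict = response.model_dump() if hasattr(response, "model_dump") else response
        logger.log_provider_call(request_id, provider_request, response_dict)
```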
@@ -483,7 +502,8 @@ def generate_single_plan(i):
483502     response, finish_reason, completion_tokens_ = llm_call_reason_effort_fallback(
484503         client=client,
485504         provider_request=provider_request,
486-         reasoning_effort_levels=["high", "medium"]
505+         reasoning_effort_levels=["high", "medium"],
506+         cepo_config=cepo_config
487507     )
488508     completion_tokens += completion_tokens_
489509
@@ -519,7 +539,8 @@ def generate_single_plan(i):
519539     response, finish_reason, completion_tokens_ = llm_call_reason_effort_fallback(
520540         client=client,
521541         provider_request=provider_request,
522-         reasoning_effort_levels=["high", "medium"]
542+         reasoning_effort_levels=["high", "medium"],
543+         cepo_config=cepo_config
523544     )
524545     completion_tokens += completion_tokens_
525546
@@ -718,7 +739,8 @@ def rate_completions_absolute(system_prompt: str, initial_query: str, client: An
718739         rating_response, _, completion_tokens = llm_call_reason_effort_fallback(
719740             client=client,
720741             provider_request=provider_request,
721-             reasoning_effort_levels=["high", "medium"]
742+             reasoning_effort_levels=["high", "medium"],
743+             cepo_config=cepo_config
722744         )
723745
724746         # Log provider call if conversation logging is enabled
@@ -906,7 +928,7 @@ def majority_vote_mcq(completions, last_n_chars=100):
906928     return response, count
907929
908930
909- def rate_completions_majority_vote(completions: list[str], last_n_chars: int = 150) -> tuple[str, int, dict]:
931+ def rate_completions_majority(completions: list[str], last_n_chars: int = 150) -> tuple[str, int, dict]:
910932     mcq_majority, count = majority_vote_mcq(completions, last_n_chars)
911933     if mcq_majority is None:
912934         return majority_vote_math(completions, last_n_chars)
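`rate_completions_majority` tries MCQ-style voting first and falls back to math-answer voting. The core bucketing idea, in toy form (`simple_majority` is illustrative, not the project's implementation):

```python
from collections import Counter

def simple_majority(completions: list[str], last_n_chars: int = 150) -> tuple[str, int]:
    """Vote on each completion's tail, where the final answer usually sits."""
    tails = Counter(c[-last_n_chars:] for c in completions)
    winning_tail, count = tails.most_common(1)[0]
    winner = next(c for c in completions if c.endswith(winning_tail))
    return winner, count
```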
@@ -948,8 +970,8 @@ def cepo(system_prompt: str, initial_query: str, client: Any, model: str, cepo_c
948970         best_completion, completion_tokens_rating, cb_log = rate_completions_absolute(system_prompt, initial_query, client, rating_model, completions, cepo_config, cb_log, request_id)
949971     elif cepo_config.bestofn_rating_type == "pairwise":
950972         best_completion, completion_tokens_rating, cb_log = rate_completions_pairwise(system_prompt, initial_query, client, rating_model, completions, cepo_config, cb_log, request_id)
951-     elif cepo_config.bestofn_rating_type == "majority_with_code_exec":
952-         best_completion, _ = rate_completions_majority_vote(completions)
973+     elif cepo_config.bestofn_rating_type == "majority":
974+         best_completion, _ = rate_completions_majority(completions)
953975         completion_tokens_rating = 0
954976     else:
955977         raise ValueError("Invalid rating type in cepo_config")
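The if/elif chain above could also be table-driven; a sketch of the equivalent dispatch (the three rater functions are assumed to be in scope in cepo.py):

```python
from typing import Callable

def pick_rater(rating_type: str, raters: dict[str, Callable]) -> Callable:
    """Map a bestofn_rating_type value to its rating function."""
    try:
        return raters[rating_type]
    except KeyError:
        raise ValueError("Invalid rating type in cepo_config")

# Usage sketch, with the raters defined earlier in cepo.py:
# rater = pick_rater(cepo_config.bestofn_rating_type,
#                    {"absolute": rate_completions_absolute,
#                     "pairwise": rate_completions_pairwise,
#                     "majority": rate_completions_majority})
```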