Commit f35d534

Integrate Qwen CePO flow
1 parent 79604cd commit f35d534

File tree: 4 files changed (+104, -37 lines)

optillm/cepo/cepo.py

Lines changed: 58 additions & 36 deletions
@@ -20,18 +20,22 @@ class CepoConfig:
     bestofn_n: int  # number of responses to be generated in best of n stage
     bestofn_temperature: float  # temperature for verifier in best of n stage
     bestofn_max_tokens: int  # maximum number of tokens for verifier in best of n stage
-    bestofn_rating_type: Literal["absolute", "pairwise"]  # type of rating in best of n stage
+    bestofn_rating_type: Literal["absolute", "pairwise", "majority"]  # type of rating in best of n stage
     planning_n: int  # number of plans generated in planning stage
     planning_m: int  # number of attempts to generate n plans in planning stage
     planning_temperature_step1: float  # temperature for generator in step 1 of planning stage
     planning_temperature_step2: float  # temperature for generator in step 2 of planning stage
+    planning_temperature_direct_resp: float  # temperature for the direct answer generated when planning fails after step 2
     planning_temperature_step3: float  # temperature for generator in step 3 of planning stage
     planning_temperature_step4: float  # temperature for generator in step 4 of planning stage
     planning_max_tokens_step1: int  # maximum number of tokens in step 1 of planning stage
     planning_max_tokens_step2: int  # maximum number of tokens in step 2 of planning stage
+    planning_max_tokens_direct_resp: int  # maximum number of tokens for the direct answer generated when planning fails after step 2
     planning_max_tokens_step3: int  # maximum number of tokens in step 3 of planning stage
     planning_max_tokens_step4: int  # maximum number of tokens in step 4 of planning stage
     use_plan_diversity: bool  # whether to use plan diversity
+    use_reasoning_fallback: bool  # whether to fall back to lower reasoning effort levels when a higher level fails
+    num_of_retries: int  # number of retries if the llm call fails; 0 for no retries
     rating_model: Optional[str] = None  # model to be used for rating
     print_output: bool = False  # whether to print the output of each stage

@@ -203,6 +207,7 @@ def extract_llm_response(response):
 def llm_call(
     client: Any,
     provider_request: dict,
+    cepo_config: CepoConfig,
 ) -> tuple[str, str, int]:
     """
     Call the LLM with retries on transient errors.
@@ -220,7 +225,7 @@ def llm_call(
         - finish_reason: Why generation stopped.
         - completion_tokens: Number of tokens generated.
     """
-    retries = 2  # total attempts = retries + 1 initial call
+    retries = cepo_config.num_of_retries  # total attempts = retries + 1 initial call
     for attempt in range(retries):
         try:
             response_object = client.chat.completions.create(
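
The retry count now comes from cepo_config.num_of_retries instead of a hard-coded 2. A self-contained sketch of the documented semantics (the comment promises num_of_retries + 1 total attempts); call_once and TransientError are illustrative stand-ins, not names from this repo:

import time

class TransientError(Exception):
    """Stand-in for the transient API errors llm_call retries on."""

def call_with_retries(call_once, num_of_retries: int):
    # Total attempts = num_of_retries + 1 initial call, matching the comment
    # in llm_call; num_of_retries=0 means a single attempt and no retry.
    for attempt in range(num_of_retries + 1):
        try:
            return call_once()
        except TransientError:
            if attempt == num_of_retries:
                raise  # retries exhausted, surface the error
            time.sleep(2 ** attempt)  # simple exponential backoff
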
@@ -247,7 +252,8 @@
 def llm_call_reason_effort_fallback(
     client: Any,
     provider_request: dict,
-    reasoning_effort_levels: list
+    reasoning_effort_levels: list,
+    cepo_config: CepoConfig
 ) -> tuple[Optional[Any], str, int]:
     """
     Call LLM with fallback on reasoning effort levels.
@@ -291,13 +297,16 @@ def llm_call_reason_effort_fallback(
     automatically, but a permanent fix may require upstream changes
     (see https://github.com/pydantic/pydantic-ai/issues/2449).
     """
+    if not cepo_config.use_reasoning_fallback:
+        reasoning_effort_levels = ["high"]
     for effort in reasoning_effort_levels:
         try:
             # Try with the current reasoning effort level
             provider_request["reasoning_effort"] = effort
             response, finish_reason, completion_tokens = llm_call(
                 client=client,
                 provider_request=provider_request,
+                cepo_config=cepo_config
             )
             if response is not None and finish_reason != "length":
                 return response, finish_reason, completion_tokens
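
With use_reasoning_fallback disabled, the gating above collapses the ladder to a single "high" attempt, so a failure at high effort is terminal instead of being retried at "medium" or "low". A minimal sketch of that control flow, with call_at_effort as a hypothetical stand-in for the wrapped llm_call:

from typing import Callable, Optional, Tuple

def try_effort_ladder(
    call_at_effort: Callable[[str], Tuple[Optional[str], str]],
    use_reasoning_fallback: bool,
    levels: tuple = ("high", "medium", "low"),
) -> Optional[str]:
    if not use_reasoning_fallback:
        levels = ("high",)  # single attempt, no stepping down
    for effort in levels:
        response, finish_reason = call_at_effort(effort)
        if response is not None and finish_reason != "length":
            return response  # accept the first complete response
    return None  # every effort level failed or was truncated
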
@@ -310,27 +319,6 @@
     return None, "error", 0
 
 
-def fallback_direct_answer(client, model, question, max_tokens=None, temperature=1.0, top_p=1.0):  # TODO clean-up
-    messages = [
-        {"role": "user", "content": question},
-    ]
-
-    response, finish_reason, completion_tokens = llm_call_reason_effort_fallback(
-        messages=messages,
-        client=client,
-        model=model,
-        max_tokens=max_tokens,
-        temperature=temperature,
-        top_p=top_p,
-        reasoning_effort_levels=["high", "medium", "low"]
-    )
-    if response is None or finish_reason == "length":
-        print("Direct answer failed, empty response or length")
-        response = ""
-    messages.append({"role": "assistant", "content": response})
-    return response, messages
-
-
 def generate_completion(system_prompt: str, task: str, client: Any, model: str, cepo_config: CepoConfig, approach: Optional[str] = None, request_id: str = None) -> str:
     """
     Generates a completion based on the provided system prompt and task.
@@ -385,7 +373,8 @@ def generate_single_plan(i):
     response, finish_reason, completion_tokens = llm_call_reason_effort_fallback(
         client=client,
         provider_request=provider_request,
-        reasoning_effort_levels=["high", "medium"]
+        reasoning_effort_levels=["high", "medium"],
+        cepo_config=cepo_config
     )
     local_completion_tokens += completion_tokens
     # Log provider call if conversation logging is enabled
@@ -418,7 +407,8 @@ def generate_single_plan(i):
     response, finish_reason, completion_tokens = llm_call_reason_effort_fallback(
         client=client,
         provider_request=provider_request,
-        reasoning_effort_levels=["high", "medium"]
+        reasoning_effort_levels=["high", "medium"],
+        cepo_config=cepo_config
     )
     local_completion_tokens += completion_tokens

@@ -453,10 +443,39 @@ def generate_single_plan(i):
     plans = [plan for _, plan in sorted(plans)]  # keep original order
 
     if not plans:
-        # Fallback plan
-        fallback_generation, fallback_messages = fallback_direct_answer(client, model, question_only)
-        plans.append(fallback_generation)
-        cb_log[f"messages_planning_fallback_used"] = fallback_messages
+        # If no plans were generated, attempt to answer directly
+        messages = [
+            {"role": "user", "content": question_only},
+        ]
+
+        provider_request = {
+            "model": model,
+            "messages": messages,
+            "max_tokens": cepo_config.planning_max_tokens_direct_resp,
+            "temperature": cepo_config.planning_temperature_direct_resp,
+            "top_p": 0.95,
+        }
+
+        response, finish_reason, completion_tokens = llm_call_reason_effort_fallback(
+            client=client,
+            provider_request=provider_request,
+            reasoning_effort_levels=["high", "medium", "low"],
+            cepo_config=cepo_config
+        )
+        local_completion_tokens += completion_tokens
+
+        # Log provider call if conversation logging is enabled
+        if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
+            response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
+            optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
+
+        if response is None or finish_reason == "length":
+            print("Direct answer failed, empty response or length")
+            response = ""
+        messages.append({"role": "assistant", "content": response})
+
+        plans.append(response)
+        cb_log["messages_planning_fallback_used"] = messages
     if cepo_config.print_output:
         print(f"\nCePO: No plans generated successfully. Taking the fallback.\n")
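
A small detail in the inlined fallback above: the response is duck-typed before logging, since the call may hand back a pydantic-style object or a plain dict. The guard in isolation — a sketch, with to_loggable as a hypothetical name:

def to_loggable(response):
    # pydantic-style response objects expose model_dump(); anything else
    # (plain dicts, None) passes through unchanged.
    return response.model_dump() if hasattr(response, "model_dump") else response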

@@ -483,7 +502,8 @@ def generate_single_plan(i):
     response, finish_reason, completion_tokens_ = llm_call_reason_effort_fallback(
         client=client,
         provider_request=provider_request,
-        reasoning_effort_levels=["high", "medium"]
+        reasoning_effort_levels=["high", "medium"],
+        cepo_config=cepo_config
     )
     completion_tokens += completion_tokens_

@@ -519,7 +539,8 @@ def generate_single_plan(i):
     response, finish_reason, completion_tokens_ = llm_call_reason_effort_fallback(
         client=client,
         provider_request=provider_request,
-        reasoning_effort_levels=["high", "medium"]
+        reasoning_effort_levels=["high", "medium"],
+        cepo_config=cepo_config
    )
     completion_tokens += completion_tokens_

@@ -718,7 +739,8 @@ def rate_completions_absolute(system_prompt: str, initial_query: str, client: An
     rating_response, _, completion_tokens = llm_call_reason_effort_fallback(
         client=client,
         provider_request=provider_request,
-        reasoning_effort_levels=["high", "medium"]
+        reasoning_effort_levels=["high", "medium"],
+        cepo_config=cepo_config
     )
 
     # Log provider call if conversation logging is enabled
@@ -906,7 +928,7 @@ def majority_vote_mcq(completions, last_n_chars=100):
     return response, count
 
 
-def rate_completions_majority_vote(completions: list[str], last_n_chars: int = 150) -> tuple[str, int, dict]:
+def rate_completions_majority(completions: list[str], last_n_chars: int = 150) -> tuple[str, int, dict]:
     mcq_majority, count = majority_vote_mcq(completions, last_n_chars)
     if mcq_majority is None:
         return majority_vote_math(completions, last_n_chars)
@@ -948,8 +970,8 @@ def cepo(system_prompt: str, initial_query: str, client: Any, model: str, cepo_c
         best_completion, completion_tokens_rating, cb_log = rate_completions_absolute(system_prompt, initial_query, client, rating_model, completions, cepo_config, cb_log, request_id)
     elif cepo_config.bestofn_rating_type == "pairwise":
         best_completion, completion_tokens_rating, cb_log = rate_completions_pairwise(system_prompt, initial_query, client, rating_model, completions, cepo_config, cb_log, request_id)
-    elif cepo_config.bestofn_rating_type == "majority_with_code_exec":
-        best_completion, _ = rate_completions_majority_vote(completions)
+    elif cepo_config.bestofn_rating_type == "majority":
+        best_completion, _ = rate_completions_majority(completions)
         completion_tokens_rating = 0
     else:
         raise ValueError("Invalid rating type in cepo_config")
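
The renamed "majority" rating path votes over the tails of the candidate completions: MCQ-style first, then a math-style vote as fallback. A compact sketch of the underlying idea; the repo's majority_vote_mcq/majority_vote_math are more careful than this simple Counter over raw tails:

from collections import Counter

def majority_over_tails(completions: list[str], last_n_chars: int = 150) -> tuple[str, int]:
    # Final answers tend to sit at the end of a completion, so vote on tails.
    tails = [c[-last_n_chars:].strip() for c in completions]
    winner_tail, count = Counter(tails).most_common(1)[0]
    # Return the first full completion whose tail matches the winning vote.
    winner = next(c for c in completions if c[-last_n_chars:].strip() == winner_tail)
    return winner, count
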
Lines changed: 4 additions & 1 deletion
@@ -1,17 +1,20 @@
 bestofn_n: 3
 bestofn_temperature: 0.1
 bestofn_max_tokens: 4096
-bestofn_rating_type: "absolute" # or "pairwise"
+bestofn_rating_type: "absolute" # or "pairwise", "majority"
 planning_n: 3
 planning_m: 6
 planning_temperature_step1: 0.55
 planning_temperature_step2: 0.25
+planning_temperature_direct_resp: 0.1
 planning_temperature_step3: 0.1
 planning_temperature_step4: 0
 planning_max_tokens_step1: 4096
 planning_max_tokens_step2: 4096
+planning_max_tokens_direct_resp: 4096
 planning_max_tokens_step3: 4096
 planning_max_tokens_step4: 4096
 use_plan_diversity: False
 rating_model: null
+use_reasoning_fallback: False
 print_output: False
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+bestofn_n: 1
+bestofn_temperature: 0.6
+bestofn_max_tokens: 40960
+bestofn_rating_type: "absolute"
+planning_n: 2
+planning_m: 4
+planning_temperature_step1: 1.0
+planning_temperature_step2: 1.0
+planning_temperature_direct_resp: 0.6
+planning_temperature_step3: 1.0
+planning_temperature_step4: 0.5
+planning_max_tokens_step1: 40960
+planning_max_tokens_step2: 40960
+planning_max_tokens_direct_resp: 32768
+planning_max_tokens_step3: 40960
+planning_max_tokens_step4: 40960
+use_plan_diversity: False
+rating_model: null
+use_reasoning_fallback: True
+num_of_retries: 2
+print_output: true
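
This preset supplies every field the updated CepoConfig requires, so it can be loaded directly. A sketch assuming a plain yaml.safe_load into the dataclass; the loader optillm actually uses may differ:

import yaml  # PyYAML, assumed available
from optillm.cepo.cepo import CepoConfig  # import path taken from this diff

def load_cepo_config(path: str) -> CepoConfig:
    with open(path) as f:
        raw = yaml.safe_load(f)
    return CepoConfig(**raw)  # every YAML key must match a dataclass field
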
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+bestofn_n: 3
+bestofn_temperature: 0.6
+bestofn_max_tokens: 20480
+bestofn_rating_type: "majority"
+planning_n: 2
+planning_m: 4
+planning_temperature_step1: 0.8
+planning_temperature_step2: 0.8
+planning_temperature_direct_resp: 0.6
+planning_temperature_step3: 0.8
+planning_temperature_step4: 0.8
+planning_max_tokens_step1: 28672
+planning_max_tokens_step2: 24576
+planning_max_tokens_direct_resp: 32768
+planning_max_tokens_step3: 20481
+planning_max_tokens_step4: 20482
+use_plan_diversity: False
+rating_model: null
+use_reasoning_fallback: False
+num_of_retries: 0
+print_output: False
