Skip to content

Commit 79604cd

Browse files
Add majority vote rating
1 parent 68ac32a commit 79604cd

File tree

1 file changed

+157
-120
lines changed

1 file changed

+157
-120
lines changed

optillm/cepo/cepo.py

Lines changed: 157 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import json
44
import optillm
55
import time
6+
import math_verify
67

78
from optillm import conversation_logger
9+
from collections import Counter
810
from concurrent.futures import ThreadPoolExecutor, as_completed
911
from dataclasses import dataclass
1012
from typing import Literal, Any, Optional
@@ -34,6 +36,82 @@ class CepoConfig:
3436
print_output: bool = False # whether to print the output of each stage
3537

3638

39+
# Regex patterns for locating a declared multiple-choice answer (A-D) in
# free-form model output (Markdown and/or LaTeX). Ordered by priority:
# extract_abcd tries them in sequence and the earliest-matching pattern wins,
# so more specific/explicit forms come first and bare fallbacks come last.
# (The numeric labels in the comments are historical pattern IDs, not the
# list positions.)
MCQ_PATTERNS = [
    # 0) "**Answer:** A" or "*Answers* – B", i.e. markdown‐wrapped "Answer(s)" with an unwrapped letter.
    re.compile(
        r'''(?ix)                   # case‐insensitive, ignore‐space
        (?:\*{1,2}|_{1,2})          # leading *…* or _…_
        Answer[s]?                  # Answer or Answers
        \s*[:\-–]?                  # optional separator
        (?:\*{1,2}|_{1,2})          # closing wrapper
        \s*                         # optional space
        ([ABCD])\b                  # the actual letter
        ''',
        re.X
    ),

    # 0.1) Line-anchored "Answer: C" where the word and/or the letter may be
    #      markdown-wrapped on either side.
    re.compile(r'''(?ix)            # ignore case, allow verbose mode
        ^\s*                        # optional leading whitespace
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper
        Answer:?                    # the word 'answer' with an optional colon
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper again
        \s*:?\s*                    # optional colon with optional spaces
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper before letter
        ([ABCD])                    # capture the letter
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper after letter
        \s*                         # optional trailing whitespace, end of line
        ''', re.MULTILINE),

    # 1) Answer: (C) or Answers: (B)
    re.compile(r'(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*\(\s*([ABCD])\s*\)'),

    # 2) Answer: C or Answers – D
    re.compile(r'(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*([ABCD])\b'),

    # 3) Option B or Choice: C
    re.compile(r'(?ix)\b(?:Option|Choice)\b\s*[:\-–]?\s*([ABCD])\b'),

    # 7) LaTeX \boxed{...A...}, catches both \boxed{A} and
    #    \boxed{\text{A } 2.08\times10^{-6}\,\mathrm{m}} etc.
    re.compile(r'(?x)\\boxed\{[^}]*?([ABCD])[^}]*\}', re.MULTILINE),

    # 7.5) LaTeX \boxed{\textbf{...C...}}
    re.compile(r'(?x)\\boxed\{[^}]*?\\textbf\{[^}]*?([ABCD])[^}]*\}[^}]*\}', re.MULTILINE),

    # 7.51) LaTeX \boxed{\text{...C...}}
    re.compile(r'(?x)\\boxed\{[^}]*?\\text\{[^}]*?([ABCD])[^}]*\}[^}]*\}', re.MULTILINE),

    # 4) bare singletons: (A) [B]
    re.compile(r'(?x)(?<![A-Za-z0-9])[\(\[]\s*([ABCD])\s*[\)\]](?![A-Za-z0-9])'),

    # 5) Markdown‐wrapped: *A* **B** _C_ __D__
    re.compile(r'(?x)(?<![A-Za-z0-9])(?:\*{1,2}|_{1,2})([ABCD])(?:\*{1,2}|_{1,2})(?![A-Za-z0-9])'),

    # 6) LaTeX \textbf{...C...}
    re.compile(r'(?x)\\textbf\{[^}]*?([ABCD])[^}]*\}'),

    # 8) markdown‐wrapped answer plus “)” plus description, e.g. **D) …**
    re.compile(r'''(?x)             # ignore whitespace in pattern
        (?<![A-Za-z0-9])            # not preceded by word‐char
        (?:\*{1,2}|_{1,2})          # opening ** or __ or * or _
        \s*([ABCD])\)               # capture letter plus “)”
        [^*_\n]+?                   # some text inside wrapper
        (?:\*{1,2}|_{1,2})          # closing wrapper
        (?![A-Za-z0-9])             # not followed by word‐char
        '''),

    # 9) final fallback: a line that's exactly "A", "B.", "C)", "**D**", etc.
    re.compile(r'''(?x)^\s*
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper
        ([ABCD])                    # capture group for letter
        (?:\*{1,2}|_{1,2})?         # optional closing markdown
        \s*[\.\)\-–:]?              # optional separator after the letter
        \s*.*$                      # allow any following text
        ''', re.MULTILINE),
]
113+
114+
37115
# given command line arguments which includes a yaml file path, initialize a CePO configuration
38116
def init_cepo_config(cmd_line_args: dict) -> CepoConfig:
39117
# get the command line arguments
@@ -463,121 +541,6 @@ def generate_single_plan(i):
463541
print(f"\nCePO: Answer generated for one bestofn_n attempt.")
464542

465543
return final_output, completion_tokens, cb_log
466-
467-
468-
# Log provider call if conversation logging is enabled
469-
if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
470-
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
471-
optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
472-
completion_tokens += response.usage.completion_tokens
473-
474-
if response.choices[0].finish_reason == "length":
475-
# Skipping plan generation due to exceeding the token budget. Usually it means the plan is incomplete.
476-
continue
477-
478-
# Step 2 - Execute the plan
479-
content = f"Can you execute the above plan step-by-step to produce the final answer. "\
480-
f"Be extra careful when executing steps where your confidence is lower."
481-
messages.extend([{"role": "assistant", "content": response.choices[0].message.content}, {"role": "user", "content": content}])
482-
483-
# Prepare request for logging
484-
provider_request = {
485-
"model": model,
486-
"messages": messages,
487-
"max_tokens": cepo_config.planning_max_tokens_step2,
488-
"temperature": cepo_config.planning_temperature_step2,
489-
"stream": False,
490-
}
491-
492-
response = client.chat.completions.create(**provider_request)
493-
494-
# Log provider call if conversation logging is enabled
495-
if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
496-
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
497-
optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
498-
completion_tokens += response.usage.completion_tokens
499-
500-
if response.choices[0].finish_reason == "length":
501-
messages.append({"role": "assistant", "content": response.choices[0].message.content})
502-
cb_log[f"messages_planning_{i}_rejected_due_to_length"] = messages
503-
if cepo_config.print_output:
504-
print(f"\nCePO: Plan proposal rejected due to length. Attempt {i + 1} out of {cepo_config.planning_m}.\nMessages: {messages}")
505-
continue
506-
507-
plans.append(response.choices[0].message.content)
508-
messages.append({"role": "assistant", "content": response.choices[0].message.content})
509-
cb_log[f"messages_planning_{i}"] = messages
510-
if cepo_config.print_output:
511-
print(f"\nCePO: Plan proposal generated. Attempt {i + 1} out of {cepo_config.planning_m}.\nMessages: {messages}")
512-
513-
if len(plans) == cepo_config.planning_n:
514-
break
515-
516-
if not plans:
517-
# If no plans were generated succesfully, take the last one even if it was rejected due to length
518-
plans.append(response.choices[0].message.content)
519-
messages.append({"role": "assistant", "content": response.choices[0].message.content})
520-
cb_log[f"messages_planning_{i}_no_plans_so_taking_the_last_one"] = messages
521-
if cepo_config.print_output:
522-
print(f"\nCePO: No plans generated successfully. Taking the last one from rejected due to length.\nMessages: {messages}")
523-
524-
# Step 3 - Review and address inconsistencies
525-
try:
526-
plans_message = ""
527-
for i, plan in enumerate(plans):
528-
plans_message += f"Response {i + 1}:\n{plan}\n\n"
529-
plans_message = plans_message[:-2] # remove the last 2x newline
530-
content = f"Can you review your last {len(plans)} responses and identify any inconsistency between them. After that, can you address "\
531-
f"it and present a final step-by-step solution to the problem? Here is the question:\n{question_only}"
532-
messages = [{"role": "assistant", "content": plans_message}, {"role": "user", "content": content}]
533-
534-
# Prepare request for logging
535-
provider_request = {
536-
"model": model,
537-
"messages": messages,
538-
"max_tokens": cepo_config.planning_max_tokens_step3,
539-
"temperature": cepo_config.planning_temperature_step3,
540-
"stream": False,
541-
}
542-
543-
response = client.chat.completions.create(**provider_request)
544-
545-
# Log provider call if conversation logging is enabled
546-
if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
547-
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
548-
optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
549-
final_solution = response.choices[0].message.content
550-
completion_tokens += response.usage.completion_tokens
551-
except (CerebrasBadRequestError, OpenAIBadRequestError) as e:
552-
# In case of an error, take the first plan as the final solution
553-
final_solution = plans[0]
554-
messages = []
555-
556-
# Step 4 - Answer the question
557-
content = f"Use your final solution from above to correctly answer the question. Here is the question:\n{task}"
558-
messages = [{"role": "assistant", "content": final_solution}, {"role": "user", "content": content}]
559-
560-
# Prepare request for logging
561-
provider_request = {
562-
"model": model,
563-
"messages": messages,
564-
"max_tokens": cepo_config.planning_max_tokens_step4,
565-
"temperature": cepo_config.planning_temperature_step4,
566-
"stream": False,
567-
}
568-
569-
response = client.chat.completions.create(**provider_request)
570-
571-
# Log provider call if conversation logging is enabled
572-
if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id:
573-
response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
574-
optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict)
575-
completion_tokens += response.usage.completion_tokens
576-
577-
cb_log["messages"] = messages
578-
if cepo_config.print_output:
579-
print(f"\nCePO: Answer generated.\nMessages: {messages}")
580-
return response.choices[0].message.content, completion_tokens, cb_log
581544

582545

583546
def generate_approaches(system_prompt: str, initial_query: str, num_approach: int, client: Any, model: str, cepo_config: CepoConfig, max_retry: int = 2, request_id: str = None) -> tuple[list[str], int]:
@@ -877,6 +840,79 @@ def rate_completions_pairwise(system_prompt: str, initial_query: str, client: An
877840
return completions[best_index], completion_tokens, cb_log
878841

879842

843+
def extract_answer_mathverify(response_str, last_n_chars=100):
    """Extract a numeric/symbolic final answer from a model response.

    Fast path: if the whole response parses as a float, return it as a
    single-element list. Otherwise, drop any chain-of-thought up to a closing
    ``</think>`` tag, keep only the trailing ``last_n_chars`` characters
    (``None`` keeps the full text), and delegate to ``math_verify.parse``.

    Args:
        response_str: Model output; coerced to ``str`` first.
        last_n_chars: How much of the tail to parse, or ``None`` for all.

    Returns:
        A list of parsed answers (possibly empty when nothing parses).
    """
    response_str = str(response_str)
    try:
        # Whole response is a bare number, e.g. "3.14" — no parsing needed.
        return [float(response_str)]
    except ValueError:
        # Only a failed float() conversion should fall through to the
        # heavyweight parser (was a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit).
        if "</think>" in response_str:
            response_str = response_str.split("</think>", 1)[1]
        if last_n_chars is not None:
            # The final answer is expected near the end of the response.
            response_str = response_str[-last_n_chars:]
        return math_verify.parse(response_str, parsing_timeout=None)
854+
855+
856+
def extract_abcd(text: str) -> str | None:
    """
    Scan text (with Markdown/LaTeX wrappers intact) for a declared
    multiple-choice answer and return 'A', 'B', 'C', or 'D'.

    Patterns are tried in priority order (``MCQ_PATTERNS``); the match from
    the highest-priority pattern wins, with ties broken by the shortest
    overall match. When no pattern matches, fall back to the first character
    of the text after stripping a leading '**' — note this may be any
    character, not necessarily a letter in A-D.
    """
    matches = []
    for prio, pat in enumerate(MCQ_PATTERNS):
        m = pat.search(text)
        if m:
            letter = m.group(1).upper()
            if letter in 'ABCD':
                matches.append((prio, m, letter))

    if matches:
        # min() with this key is equivalent to sorting and taking the first
        # element, and never compares the (unorderable) re.Match objects.
        _, _, letter = min(matches, key=lambda t: (t[0], len(t[1].group(0))))
        return letter
    # Last-resort heuristic for answers like "**B ..." with no wrapper close.
    return text.removeprefix('**')[:1]
877+
878+
879+
def majority_vote_math(completions, last_n_chars=100):
    """Pick the completion whose extracted math answer is the most common.

    Args:
        completions: Candidate response strings.
        last_n_chars: Tail length passed through to answer extraction.

    Returns:
        ``(response, count)`` — the first completion whose parsed answer
        equals the majority answer, and how many completions share it.
        Returns ``(None, 0)`` for an empty input (previously an IndexError).

    NOTE(review): when nothing parses, every extracted answer is None, None
    wins the vote, and a completion is still returned — same as before.
    """
    if not completions:
        return None, 0

    extracted_answer_map = []
    for response in completions:
        parsed = extract_answer_mathverify(response, last_n_chars)
        # parse() may return an empty list; record None so it still votes.
        extracted_answer_map.append((response, parsed[0] if parsed else None))

    counts = Counter(answer for _, answer in extracted_answer_map)
    majority_answer, count = counts.most_common(1)[0]

    # Return the first response that produced the majority answer.
    for response, answer in extracted_answer_map:
        if answer == majority_answer:
            return response, count
893+
894+
895+
def majority_vote_mcq(completions, last_n_chars=100):
    """Pick the completion whose extracted MCQ letter is the most common.

    Args:
        completions: Candidate response strings.
        last_n_chars: Only the last N characters of each completion are
            scanned for an answer letter.

    Returns:
        ``(response, count)`` — the first completion whose extracted letter
        equals the majority letter, and how many completions share it.
        Returns ``(None, 0)`` when ``completions`` is empty (previously an
        IndexError) or when the majority extraction is ``None``, so that
        rate_completions_majority_vote can fall back to math-based voting
        (resolves the earlier TODO about an all-None vote).
    """
    if not completions:
        return None, 0

    extracted_answer_map = [
        (response, extract_abcd(response[-last_n_chars:]))
        for response in completions
    ]

    counts = Counter(answer for _, answer in extracted_answer_map)
    majority_answer, count = counts.most_common(1)[0]
    if majority_answer is None:
        # Extraction failed for the plurality of completions — signal the
        # caller to try another rating strategy instead of returning an
        # arbitrary unparseable completion.
        return None, 0

    for response, answer in extracted_answer_map:
        if answer == majority_answer:
            return response, count
907+
908+
909+
def rate_completions_majority_vote(completions: list[str], last_n_chars: int = 150) -> tuple[str, int, dict]:
910+
mcq_majority, count = majority_vote_mcq(completions, last_n_chars)
911+
if mcq_majority is None:
912+
return majority_vote_math(completions, last_n_chars)
913+
return mcq_majority, count
914+
915+
880916
def cepo(system_prompt: str, initial_query: str, client: Any, model: str, cepo_config: CepoConfig, request_id: str = None) -> tuple[str, int]:
881917
"""
882918
Applies CePO reasoning flow for the given task. First, it generates multiple completions, and then rates them to select the best one.
@@ -907,14 +943,15 @@ def cepo(system_prompt: str, initial_query: str, client: Any, model: str, cepo_c
907943
completions = [c for c in completions if c] # safeguard in case completion is None (observed with GPT OSS)
908944

909945
# Rate the completions
946+
rating_model = cepo_config.rating_model if cepo_config.rating_model else model
910947
if cepo_config.bestofn_rating_type == "absolute":
911-
rate_completions_fn = rate_completions_absolute
948+
best_completion, completion_tokens_rating, cb_log = rate_completions_absolute(system_prompt, initial_query, client, rating_model, completions, cepo_config, cb_log, request_id)
912949
elif cepo_config.bestofn_rating_type == "pairwise":
913-
rate_completions_fn = rate_completions_pairwise
950+
best_completion, completion_tokens_rating, cb_log = rate_completions_pairwise(system_prompt, initial_query, client, rating_model, completions, cepo_config, cb_log, request_id)
951+
elif cepo_config.bestofn_rating_type == "majority_with_code_exec":
952+
best_completion, _ = rate_completions_majority_vote(completions)
953+
completion_tokens_rating = 0
914954
else:
915955
raise ValueError("Invalid rating type in cepo_config")
916-
rating_model = cepo_config.rating_model if cepo_config.rating_model else model
917-
918-
best_completion, completion_tokens_rating, cb_log = rate_completions_fn(system_prompt, initial_query, client, rating_model, completions, cepo_config, cb_log, request_id)
919956

920957
return best_completion, completion_tokens_planning + completion_tokens_rating

0 commit comments

Comments
 (0)