|
3 | 3 | import json |
4 | 4 | import optillm |
5 | 5 | import time |
| 6 | +import math_verify |
6 | 7 |
|
7 | 8 | from optillm import conversation_logger |
| 9 | +from collections import Counter |
8 | 10 | from concurrent.futures import ThreadPoolExecutor, as_completed |
9 | 11 | from dataclasses import dataclass |
10 | 12 | from typing import Literal, Any, Optional |
@@ -34,6 +36,82 @@ class CepoConfig: |
34 | 36 | print_output: bool = False # whether to print the output of each stage |
35 | 37 |
|
36 | 38 |
|
# Ordered list of compiled regexes for pulling a multiple-choice answer
# letter (A-D) out of a model response. The list index is the pattern's
# priority (lower index = higher priority); extract_abcd tries them all
# and prefers the highest-priority, shortest match. All patterns are
# verbose-mode, so whitespace inside the raw strings is not significant.
MCQ_PATTERNS = [
    # 0)"**Answer:** A" or "*Answers* – B", i.e. markdown‐wrapped "Answer(s)" with an unwrapped letter.
    re.compile(
        r'''(?ix)                   # case‐insensitive, ignore‐space
        (?:\*{1,2}|_{1,2})          # leading *…* or _…_
        Answer[s]?                  # Answer or Answers
        \s*[:\-–]?                  # optional separator
        (?:\*{1,2}|_{1,2})          # closing wrapper
        \s*                         # optional space
        ([ABCD])\b                  # the actual letter
        ''',
        re.X
    ),

    # 0.1)
    re.compile(r'''(?ix)            # ignore case, allow verbose mode
        ^\s*                        # optional leading whitespace
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper
        Answer:?                    # the word 'answer' with an optional colon
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper again
        \s*:?\s*                    # optional colon with optional spaces
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper before letter
        ([ABCD])                    # capture the letter
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper after letter
        \s*                         # optional trailing whitespace, end of line
        ''', re.MULTILINE),

    # 1) Answer: (C) or Answers: (B)
    re.compile(r'(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*\(\s*([ABCD])\s*\)'),

    # 2) Answer: C or Answers – D
    re.compile(r'(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*([ABCD])\b'),

    # 3) Option B or Choice: C
    re.compile(r'(?ix)\b(?:Option|Choice)\b\s*[:\-–]?\s*([ABCD])\b'),

    # 7) LaTeX \boxed{...A...}, catches both \boxed{A} and
    #    \boxed{\text{A } 2.08\times10^{-6}\,\mathrm{m}} etc.
    re.compile(r'(?x)\\boxed\{[^}]*?([ABCD])[^}]*\}', re.MULTILINE),

    # 7.5) LaTeX \boxed{\textbf{...C...}}
    re.compile(r'(?x)\\boxed\{[^}]*?\\textbf\{[^}]*?([ABCD])[^}]*\}[^}]*\}', re.MULTILINE),

    # 7.51) LaTeX \boxed{\text{...C...}}
    re.compile(r'(?x)\\boxed\{[^}]*?\\text\{[^}]*?([ABCD])[^}]*\}[^}]*\}', re.MULTILINE),

    # 4) bare singletons: (A) [B]
    re.compile(r'(?x)(?<![A-Za-z0-9])[\(\[]\s*([ABCD])\s*[\)\]](?![A-Za-z0-9])'),

    # 5) Markdown‐wrapped: *A* **B** _C_ __D__
    re.compile(r'(?x)(?<![A-Za-z0-9])(?:\*{1,2}|_{1,2})([ABCD])(?:\*{1,2}|_{1,2})(?![A-Za-z0-9])'),

    # 6) LaTeX \textbf{...C...}
    re.compile(r'(?x)\\textbf\{[^}]*?([ABCD])[^}]*\}'),

    # 8) markdown‐wrapped answer plus “)” plus description, e.g. **D) …**
    re.compile(r'''(?x)             # ignore whitespace in pattern
        (?<![A-Za-z0-9])            # not preceded by word‐char
        (?:\*{1,2}|_{1,2})          # opening ** or __ or * or _
        \s*([ABCD])\)               # capture letter plus “)”
        [^*_\n]+?                   # some text inside wrapper
        (?:\*{1,2}|_{1,2})          # closing wrapper
        (?![A-Za-z0-9])             # not followed by word‐char
        '''),

    # 9) final fallback: a line that's exactly "A", "B.", "C)", "**D**", etc.
    re.compile(r'''(?x)^\s*
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper
        ([ABCD])                    # capture group for letter
        (?:\*{1,2}|_{1,2})?         # optional closing markdown
        \s*[\.\)\-–:]?              # optional separator after the letter
        \s*.*$                      # allow any following text
        ''', re.MULTILINE),
]
| 113 | + |
| 114 | + |
37 | 115 | # given command line arguments which includes a yaml file path, initialize a CePO configuration |
38 | 116 | def init_cepo_config(cmd_line_args: dict) -> CepoConfig: |
39 | 117 | # get the command line arguments |
@@ -463,121 +541,6 @@ def generate_single_plan(i): |
463 | 541 | print(f"\nCePO: Answer generated for one bestofn_n attempt.") |
464 | 542 |
|
465 | 543 | return final_output, completion_tokens, cb_log |
466 | | - |
467 | | - |
468 | | - # Log provider call if conversation logging is enabled |
469 | | - if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id: |
470 | | - response_dict = response.model_dump() if hasattr(response, 'model_dump') else response |
471 | | - optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict) |
472 | | - completion_tokens += response.usage.completion_tokens |
473 | | - |
474 | | - if response.choices[0].finish_reason == "length": |
475 | | - # Skipping plan generation due to exceeding the token budget. Usually it means the plan is incomplete. |
476 | | - continue |
477 | | - |
478 | | - # Step 2 - Execute the plan |
479 | | - content = f"Can you execute the above plan step-by-step to produce the final answer. "\ |
480 | | - f"Be extra careful when executing steps where your confidence is lower." |
481 | | - messages.extend([{"role": "assistant", "content": response.choices[0].message.content}, {"role": "user", "content": content}]) |
482 | | - |
483 | | - # Prepare request for logging |
484 | | - provider_request = { |
485 | | - "model": model, |
486 | | - "messages": messages, |
487 | | - "max_tokens": cepo_config.planning_max_tokens_step2, |
488 | | - "temperature": cepo_config.planning_temperature_step2, |
489 | | - "stream": False, |
490 | | - } |
491 | | - |
492 | | - response = client.chat.completions.create(**provider_request) |
493 | | - |
494 | | - # Log provider call if conversation logging is enabled |
495 | | - if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id: |
496 | | - response_dict = response.model_dump() if hasattr(response, 'model_dump') else response |
497 | | - optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict) |
498 | | - completion_tokens += response.usage.completion_tokens |
499 | | - |
500 | | - if response.choices[0].finish_reason == "length": |
501 | | - messages.append({"role": "assistant", "content": response.choices[0].message.content}) |
502 | | - cb_log[f"messages_planning_{i}_rejected_due_to_length"] = messages |
503 | | - if cepo_config.print_output: |
504 | | - print(f"\nCePO: Plan proposal rejected due to length. Attempt {i + 1} out of {cepo_config.planning_m}.\nMessages: {messages}") |
505 | | - continue |
506 | | - |
507 | | - plans.append(response.choices[0].message.content) |
508 | | - messages.append({"role": "assistant", "content": response.choices[0].message.content}) |
509 | | - cb_log[f"messages_planning_{i}"] = messages |
510 | | - if cepo_config.print_output: |
511 | | - print(f"\nCePO: Plan proposal generated. Attempt {i + 1} out of {cepo_config.planning_m}.\nMessages: {messages}") |
512 | | - |
513 | | - if len(plans) == cepo_config.planning_n: |
514 | | - break |
515 | | - |
516 | | - if not plans: |
517 | | - # If no plans were generated succesfully, take the last one even if it was rejected due to length |
518 | | - plans.append(response.choices[0].message.content) |
519 | | - messages.append({"role": "assistant", "content": response.choices[0].message.content}) |
520 | | - cb_log[f"messages_planning_{i}_no_plans_so_taking_the_last_one"] = messages |
521 | | - if cepo_config.print_output: |
522 | | - print(f"\nCePO: No plans generated successfully. Taking the last one from rejected due to length.\nMessages: {messages}") |
523 | | - |
524 | | - # Step 3 - Review and address inconsistencies |
525 | | - try: |
526 | | - plans_message = "" |
527 | | - for i, plan in enumerate(plans): |
528 | | - plans_message += f"Response {i + 1}:\n{plan}\n\n" |
529 | | - plans_message = plans_message[:-2] # remove the last 2x newline |
530 | | - content = f"Can you review your last {len(plans)} responses and identify any inconsistency between them. After that, can you address "\ |
531 | | - f"it and present a final step-by-step solution to the problem? Here is the question:\n{question_only}" |
532 | | - messages = [{"role": "assistant", "content": plans_message}, {"role": "user", "content": content}] |
533 | | - |
534 | | - # Prepare request for logging |
535 | | - provider_request = { |
536 | | - "model": model, |
537 | | - "messages": messages, |
538 | | - "max_tokens": cepo_config.planning_max_tokens_step3, |
539 | | - "temperature": cepo_config.planning_temperature_step3, |
540 | | - "stream": False, |
541 | | - } |
542 | | - |
543 | | - response = client.chat.completions.create(**provider_request) |
544 | | - |
545 | | - # Log provider call if conversation logging is enabled |
546 | | - if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id: |
547 | | - response_dict = response.model_dump() if hasattr(response, 'model_dump') else response |
548 | | - optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict) |
549 | | - final_solution = response.choices[0].message.content |
550 | | - completion_tokens += response.usage.completion_tokens |
551 | | - except (CerebrasBadRequestError, OpenAIBadRequestError) as e: |
552 | | - # In case of an error, take the first plan as the final solution |
553 | | - final_solution = plans[0] |
554 | | - messages = [] |
555 | | - |
556 | | - # Step 4 - Answer the question |
557 | | - content = f"Use your final solution from above to correctly answer the question. Here is the question:\n{task}" |
558 | | - messages = [{"role": "assistant", "content": final_solution}, {"role": "user", "content": content}] |
559 | | - |
560 | | - # Prepare request for logging |
561 | | - provider_request = { |
562 | | - "model": model, |
563 | | - "messages": messages, |
564 | | - "max_tokens": cepo_config.planning_max_tokens_step4, |
565 | | - "temperature": cepo_config.planning_temperature_step4, |
566 | | - "stream": False, |
567 | | - } |
568 | | - |
569 | | - response = client.chat.completions.create(**provider_request) |
570 | | - |
571 | | - # Log provider call if conversation logging is enabled |
572 | | - if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id: |
573 | | - response_dict = response.model_dump() if hasattr(response, 'model_dump') else response |
574 | | - optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict) |
575 | | - completion_tokens += response.usage.completion_tokens |
576 | | - |
577 | | - cb_log["messages"] = messages |
578 | | - if cepo_config.print_output: |
579 | | - print(f"\nCePO: Answer generated.\nMessages: {messages}") |
580 | | - return response.choices[0].message.content, completion_tokens, cb_log |
581 | 544 |
|
582 | 545 |
|
583 | 546 | def generate_approaches(system_prompt: str, initial_query: str, num_approach: int, client: Any, model: str, cepo_config: CepoConfig, max_retry: int = 2, request_id: str = None) -> tuple[list[str], int]: |
@@ -877,6 +840,79 @@ def rate_completions_pairwise(system_prompt: str, initial_query: str, client: An |
877 | 840 | return completions[best_index], completion_tokens, cb_log |
878 | 841 |
|
879 | 842 |
|
def extract_answer_mathverify(response_str, last_n_chars=100):
    """Extract a parsed mathematical answer from a model response.

    Args:
        response_str: Model output (any type; coerced to ``str``). A purely
            numeric response is returned directly as ``[float(response_str)]``.
        last_n_chars: Parse only the final ``last_n_chars`` characters of the
            response (after any ``</think>`` marker); pass ``None`` to parse
            the whole remainder.

    Returns:
        The list produced by ``math_verify.parse`` for the trimmed response,
        or a one-element ``[float]`` list for numeric input.
    """
    response_str = str(response_str)
    try:
        # Fast path: the entire response is already a bare number.
        return [float(response_str)]
    except ValueError:
        # Drop any chain-of-thought preceding the closing </think> tag.
        response_str = response_str.split("</think>", 1)[1] if "</think>" in response_str else response_str
        if last_n_chars is not None:
            response_str = response_str[-last_n_chars:]
        return math_verify.parse(response_str, parsing_timeout=None)
| 854 | + |
| 855 | + |
def extract_abcd(text: str) -> str | None:
    """
    Scan text (with Markdown/LaTeX wrappers intact) and return
    'A', 'B', 'C', or 'D' if a correct-answer declaration is found.

    If no pattern matches, fall back to the first character of the text
    (after stripping a leading '**'); note this fallback may not be a
    valid letter and is '' for empty input.
    """
    # Collect (priority, match, letter) for every pattern that hits.
    matches = []
    for prio, pat in enumerate(MCQ_PATTERNS):
        m = pat.search(text)
        if m:
            letter = m.group(1).upper()
            if letter in 'ABCD':
                matches.append((prio, m, letter))

    if matches:
        # Prefer the highest-priority (lowest-index) pattern; break ties
        # by the shortest overall match.
        matches.sort(key=lambda triple: (triple[0], len(triple[1].group(0))))
        return matches[0][2]
    return text.removeprefix('**')[:1]
| 877 | + |
| 878 | + |
def majority_vote_math(completions, last_n_chars=100):
    """Return (completion, votes) for the most common extracted math answer."""
    voted = []
    for completion in completions:
        parsed = extract_answer_mathverify(completion, last_n_chars)
        voted.append((completion, parsed[0] if parsed else None))

    tally = Counter(answer for _, answer in voted)
    winning_answer, votes = tally.most_common(1)[0]
    # TODO it may return all "None", we probably should handle this case
    # Hand back the first completion that produced the winning answer.
    for completion, answer in voted:
        if answer == winning_answer:
            return completion, votes
| 893 | + |
| 894 | + |
def majority_vote_mcq(completions, last_n_chars=100):
    """Return (completion, votes) for the most common extracted MCQ letter."""
    voted = [(completion, extract_abcd(completion[-last_n_chars:]))
             for completion in completions]

    tally = Counter(answer for _, answer in voted)
    winning_answer, votes = tally.most_common(1)[0]
    # TODO it may return all "None", we probably should handle this case
    # Hand back the first completion that produced the winning letter.
    for completion, answer in voted:
        if answer == winning_answer:
            return completion, votes
| 907 | + |
| 908 | + |
def rate_completions_majority_vote(completions: list[str], last_n_chars: int = 150) -> tuple[str, int]:
    """Select the best completion by majority vote over extracted answers.

    Tries MCQ letter voting first and falls back to math-answer voting
    when the MCQ vote yields no usable completion.

    Args:
        completions: Candidate responses to vote over.
        last_n_chars: How many trailing characters of each completion to
            scan for an answer.

    Returns:
        (best_completion, vote_count). Note: both voting helpers return a
        2-tuple, so the annotation reflects that (the caller unpacks two
        values).
    """
    mcq_majority, count = majority_vote_mcq(completions, last_n_chars)
    if mcq_majority is None:
        return majority_vote_math(completions, last_n_chars)
    return mcq_majority, count
| 914 | + |
| 915 | + |
880 | 916 | def cepo(system_prompt: str, initial_query: str, client: Any, model: str, cepo_config: CepoConfig, request_id: str = None) -> tuple[str, int]: |
881 | 917 | """ |
882 | 918 | Applies CePO reasoning flow for the given task. First, it generates multiple completions, and then rates them to select the best one. |
@@ -907,14 +943,15 @@ def cepo(system_prompt: str, initial_query: str, client: Any, model: str, cepo_c |
907 | 943 | completions = [c for c in completions if c] # safeguard in case completion is None (observed with GPT OSS) |
908 | 944 |
|
909 | 945 | # Rate the completions |
| 946 | + rating_model = cepo_config.rating_model if cepo_config.rating_model else model |
910 | 947 | if cepo_config.bestofn_rating_type == "absolute": |
911 | | - rate_completions_fn = rate_completions_absolute |
| 948 | + best_completion, completion_tokens_rating, cb_log = rate_completions_absolute(system_prompt, initial_query, client, rating_model, completions, cepo_config, cb_log, request_id) |
912 | 949 | elif cepo_config.bestofn_rating_type == "pairwise": |
913 | | - rate_completions_fn = rate_completions_pairwise |
| 950 | + best_completion, completion_tokens_rating, cb_log = rate_completions_pairwise(system_prompt, initial_query, client, rating_model, completions, cepo_config, cb_log, request_id) |
| 951 | + elif cepo_config.bestofn_rating_type == "majority_with_code_exec": |
| 952 | + best_completion, _ = rate_completions_majority_vote(completions) |
| 953 | + completion_tokens_rating = 0 |
914 | 954 | else: |
915 | 955 | raise ValueError("Invalid rating type in cepo_config") |
916 | | - rating_model = cepo_config.rating_model if cepo_config.rating_model else model |
917 | | - |
918 | | - best_completion, completion_tokens_rating, cb_log = rate_completions_fn(system_prompt, initial_query, client, rating_model, completions, cepo_config, cb_log, request_id) |
919 | 956 |
|
920 | 957 | return best_completion, completion_tokens_planning + completion_tokens_rating |
0 commit comments