Commit 0016daa
Add majority voting plugin for candidate selection
Introduces a plugin that generates multiple candidate solutions using the OpenAI API and selects the most frequent answer via majority voting. Includes answer extraction, normalization, and a summary of the voting process. Useful for tasks with discrete answers such as math, coding, and multiple-choice problems.
1 parent 7904463 commit 0016daa
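
The plugin is meant to be routed through OptILLM's plugin mechanism via its SLUG, but the run() entry point can also be exercised directly to see the interface. A minimal sketch, assuming the file is importable as majority_voting and using a placeholder endpoint and model id (both hypothetical):

from openai import OpenAI

import majority_voting  # hypothetical module name; the diff does not show the file path

# Any OpenAI-compatible client works; endpoint and key are placeholders
client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-placeholder")

text, tokens = majority_voting.run(
    system_prompt="You are a careful math assistant.",
    initial_query="What is 17 * 24?",
    client=client,
    model="gpt-4o-mini",  # hypothetical model id
    request_config={"k": 6, "temperature": 0.6, "max_tokens": 1024},
)
print(text)  # winning response plus the appended voting summary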

1 file changed, 271 insertions(+), 0 deletions(-)

"""
Majority Voting Plugin for OptILLM

This plugin implements a majority voting approach where k candidate solutions
are generated and the most frequent answer is selected. This is particularly
effective for problems with discrete answers (math, coding, multiple choice).

The plugin uses the OpenAI API's n parameter to generate multiple responses
efficiently in a single API call.
"""

import re
import logging
from typing import Tuple, Dict, Any, List, Optional
from collections import Counter
import json

logger = logging.getLogger(__name__)

# Plugin identifier
SLUG = "majority_voting"

# Default number of candidates to generate
DEFAULT_K = 6

# Default temperature for candidate generation
DEFAULT_TEMPERATURE = 0.6

def extract_answer(text: str) -> Optional[str]:
    """
    Extract the answer from a response text.

    This function looks for common answer patterns in the response:
    1. Text after "Answer:" or "Final Answer:"
    2. Text within \\boxed{} (LaTeX format)
    3. Numbers at the end of the response
    4. The last line if it's short (likely the answer)
    5. A trailing single-letter multiple-choice answer (A-E)

    Args:
        text: The response text to extract the answer from

    Returns:
        The extracted answer, or None if no clear answer is found
    """
    # Remove any leading/trailing whitespace
    text = text.strip()

    # Pattern 1: Look for "Answer:" or "Final Answer:" patterns
    answer_patterns = [
        r'(?:final\s+)?answer\s*[:=]\s*(.+?)(?:\n|$)',
        r'(?:the\s+)?(?:final\s+)?answer\s+is\s*[:=]?\s*(.+?)(?:\n|$)',
        r'\b(?:therefore|thus|so)\b\s*,?\s*(.+?)(?:\n|$)'
    ]

    for pattern in answer_patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            answer = match.group(1).strip()
            # Clean up the answer
            answer = answer.rstrip('.,;')
            if answer:
                logger.debug(f"Extracted answer using pattern: {answer}")
                return answer

    # Pattern 2: Look for LaTeX boxed format
    boxed_match = re.search(r'\\boxed\{([^}]+)\}', text)
    if boxed_match:
        answer = boxed_match.group(1).strip()
        logger.debug(f"Extracted boxed answer: {answer}")
        return answer

    # Pattern 3: Look for standalone numbers (useful for math problems)
    # Check the last few lines for a number
    lines = text.split('\n')
    for line in reversed(lines[-3:]):  # Check last 3 lines
        line = line.strip()
        # Match numbers (including decimals, fractions, negative numbers)
        number_match = re.match(r'^-?\d+\.?\d*$|^-?\d+/\d+$', line)
        if number_match:
            logger.debug(f"Extracted number answer: {line}")
            return line

    # Pattern 4: If the last line is short (< 50 chars), it might be the answer
    if lines:
        last_line = lines[-1].strip()
        if last_line and len(last_line) < 50 and not last_line.endswith(':'):
            logger.debug(f"Using last line as answer: {last_line}")
            return last_line

    # Pattern 5: For multiple choice, look for single letter answers
    mc_match = re.search(r'\b([A-E])\b(?:\s*\))?$', text)
    if mc_match:
        answer = mc_match.group(1)
        logger.debug(f"Extracted multiple choice answer: {answer}")
        return answer

    logger.warning("Could not extract a clear answer from the response")
    return None

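# Illustrative calls (examples added for this write-up, not part of the
# committed file). Given the patterns above:
#   extract_answer("The final answer is 42.")          -> "42"   (answer phrase)
#   extract_answer("Work shown above.\n\\boxed{3/4}")  -> "3/4"  (boxed form)
#   extract_answer("Step 1\nStep 2\n408")              -> "408"  (trailing number)
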
def normalize_answer(answer: str) -> str:
    """
    Normalize an answer for comparison.

    This helps ensure that equivalent answers are treated as the same:
    - Converts to lowercase
    - Removes extra whitespace
    - Removes quotes
    - Normalizes number formats

    Args:
        answer: The answer to normalize

    Returns:
        The normalized answer
    """
    # Convert to lowercase
    answer = answer.lower().strip()

    # Remove quotes
    answer = answer.strip('"\'')

    # Normalize whitespace
    answer = ' '.join(answer.split())

    # Try to normalize numbers
    try:
        # Check if it's a float
        if '.' in answer:
            num = float(answer)
            # Format to remove trailing zeros
            answer = f"{num:g}"
        else:
            # Try integer
            num = int(answer)
            answer = str(num)
    except ValueError:
        # Not a number, keep as is
        pass

    # Handle yes/no variations
    if answer in ['yes', 'yeah', 'yep', 'true', 'correct']:
        answer = 'yes'
    elif answer in ['no', 'nope', 'false', 'incorrect']:
        answer = 'no'

    return answer

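# Illustrative normalizations (examples added for this write-up, not part of
# the committed file):
#   normalize_answer('"42.0"')        -> "42"      (quotes stripped; %g drops the .0)
#   normalize_answer("Yeah")          -> "yes"     (yes/no variants collapsed)
#   normalize_answer("  Foo   Bar ")  -> "foo bar" (case and whitespace folded)
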
def run(
    system_prompt: str,
    initial_query: str,
    client,
    model: str,
    request_config: Optional[Dict[str, Any]] = None
) -> Tuple[str, int]:
    """
    Main entry point for the majority voting plugin.

    Generates k candidate solutions and returns the most frequent answer.

    Args:
        system_prompt: System prompt for the model
        initial_query: User's query
        client: OpenAI-compatible client instance
        model: Model identifier
        request_config: Additional configuration parameters

    Returns:
        Tuple of (response_text, completion_tokens_used)
    """
    logger.info("Starting majority voting process")

    # Extract parameters from request_config
    k = DEFAULT_K
    temperature = DEFAULT_TEMPERATURE

    if request_config:
        k = request_config.get('k', DEFAULT_K)
        # Allow overriding temperature if needed
        temperature = request_config.get('temperature', DEFAULT_TEMPERATURE)
        # Respect max_tokens if provided
        max_tokens = request_config.get('max_tokens', 4096)
    else:
        max_tokens = 4096

    logger.info(f"Generating {k} candidates with temperature={temperature}")

    # Prepare messages
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": initial_query}
    ]

    try:
        # Generate k candidates in a single API call using the n parameter
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            n=k,
            temperature=temperature,
            max_tokens=max_tokens
        )

        # Extract all candidate responses
        candidates = [choice.message.content for choice in response.choices]
        total_tokens = response.usage.completion_tokens

        logger.info(f"Generated {len(candidates)} candidates. Tokens used: {total_tokens}")

        # Extract answers from each candidate
        answers = []
        answer_to_response = {}  # Map normalized answers to full responses

        for i, candidate in enumerate(candidates):
            answer = extract_answer(candidate)
            if answer:
                normalized = normalize_answer(answer)
                answers.append(normalized)
                # Keep the first full response for each unique answer
                if normalized not in answer_to_response:
                    answer_to_response[normalized] = candidate
                logger.debug(f"Candidate {i+1} answer: {answer} (normalized: {normalized})")
            else:
                logger.warning(f"Could not extract answer from candidate {i+1}")

        if not answers:
            logger.warning("No answers could be extracted from any candidate")
            # Return the first candidate as a fallback
            return candidates[0] if candidates else "Error: No candidates generated", total_tokens

        # Count answer frequencies
        answer_counts = Counter(answers)
        logger.info(f"Answer distribution: {dict(answer_counts)}")

        # Get the most common answer
        most_common_answer, count = answer_counts.most_common(1)[0]
        confidence = count / len(answers)

        logger.info(f"Most common answer: '{most_common_answer}' with {count}/{len(answers)} votes ({confidence:.1%} confidence)")

        # Get the full response corresponding to the most common answer
        winning_response = answer_to_response.get(most_common_answer, candidates[0])

        # Add a voting summary to the response
        voting_summary = "\n\n**Majority Voting Result**:\n"
        voting_summary += f"- Generated {k} candidates\n"
        voting_summary += f"- Most common answer: {most_common_answer}\n"
        voting_summary += f"- Votes: {count}/{len(answers)} ({confidence:.1%} confidence)\n"

        if len(answer_counts) > 1:
            voting_summary += "- Other answers: "
            other_answers = [f"{ans} ({cnt} votes)" for ans, cnt in answer_counts.items() if ans != most_common_answer]
            voting_summary += ", ".join(other_answers)

        # Return the full response from the winning answer with the voting summary
        final_response = winning_response + voting_summary

        return final_response, total_tokens

    except Exception as e:
        logger.error(f"Error in majority voting: {str(e)}")
        # Fall back to a single response
        logger.info("Falling back to single response generation")

        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens
        )

        return response.choices[0].message.content, response.usage.completion_tokens
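
For illustration, the summary that run() appends to the winning response would look like the following for a hypothetical six-candidate run whose extracted answers split 4/1/1 (all values invented for the example):

**Majority Voting Result**:
- Generated 6 candidates
- Most common answer: 42
- Votes: 4/6 (66.7% confidence)
- Other answers: 41 (1 votes), 43 (1 votes)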
