Skip to content

Commit c5ca795

Browse files
committed
Update majority_voting_plugin.py
1 parent 351411b commit c5ca795

File tree

1 file changed

+60
-104
lines changed

1 file changed

+60
-104
lines changed

optillm/plugins/majority_voting_plugin.py

Lines changed: 60 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
11
"""
2-
Majority Voting Plugin V2 for OptILLM
2+
Majority Voting Plugin for OptILLM
33
4-
Enhanced version with:
5-
- Category-aware answer extraction
6-
- Adaptive temperature control
7-
- Improved answer normalization
8-
- Response quality filtering
9-
- Smart fallback strategies
4+
Generic implementation that generates multiple candidates and selects
5+
the most common response through simple voting.
106
"""
117

128
import re
139
import logging
1410
from typing import Tuple, Dict, Any, List, Optional
1511
from collections import Counter
16-
import json
1712

1813
logger = logging.getLogger(__name__)
1914

@@ -24,89 +19,58 @@
2419
DEFAULT_K = 8
2520
DEFAULT_TEMPERATURE = 0.6 # Unified temperature for consistency
2621

27-
def detect_category(query: str) -> str:
22+
23+
def normalize_response(response: str) -> str:
2824
"""
29-
Try to detect the problem category from the query.
30-
31-
Returns:
32-
Category string or 'default' if unknown
25+
Basic normalization for comparing responses.
26+
Removes extra whitespace, punctuation at ends, and lowercases.
3327
"""
34-
query_lower = query.lower()
28+
if not response:
29+
return ""
3530

36-
# GSM8K patterns
37-
if "###" in query or ("calculate" in query_lower and any(word in query_lower for word in ["total", "sum", "difference", "product"])):
38-
return "gsm8k"
31+
# Remove thinking blocks if present
32+
response = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
3933

40-
# MMLU patterns (multiple choice)
41-
if re.search(r'\b[A-E]\s*[:\)]\s*', query) or "which of the following" in query_lower:
42-
return "mmlu_math"
34+
# Basic normalization
35+
response = response.strip()
36+
response = response.lower()
4337

44-
# BoolQ patterns
45-
if query_lower.strip().endswith("?") and any(word in query_lower for word in ["is", "are", "was", "were", "does", "do", "did", "can", "could", "will", "would"]):
46-
return "boolq"
38+
# Remove trailing punctuation
39+
response = response.rstrip('.,;:!?')
4740

48-
# AQUA-RAT patterns
49-
if re.search(r'options?:\s*[A-E]', query, re.IGNORECASE):
50-
return "aqua_rat"
41+
# Normalize whitespace
42+
response = ' '.join(response.split())
5143

52-
return "default"
44+
return response
5345

5446

55-
56-
57-
def extract_answer_simple(response: str, category: str) -> Optional[str]:
47+
def extract_final_answer(response: str) -> str:
5848
"""
59-
Extract answer using same logic as evaluation script for consistency.
49+
Try to extract just the final answer from a response.
50+
This is generic and looks for common patterns.
6051
"""
6152
if not response:
62-
return None
53+
return response
6354

64-
# Remove thinking blocks if present
55+
# Remove thinking blocks
6556
response = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL).strip()
6657

67-
if category == "gsm8k":
68-
# Extract number after ###
69-
match = re.search(r'###\s*(-?\d*\.?\d+)', response)
58+
# Look for common answer patterns
59+
patterns = [
60+
r'(?:final answer|answer):\s*(.+?)(?:\n|$)',
61+
r'(?:the answer is|answer is)\s*(.+?)(?:\n|$)',
62+
r'###\s*(.+?)(?:\n|$)', # Common in math problems
63+
r'^([A-E])\b', # Single letter at start
64+
r'\b([A-E])\b\s*$', # Single letter at end
65+
]
66+
67+
for pattern in patterns:
68+
match = re.search(pattern, response, re.IGNORECASE | re.MULTILINE)
7069
if match:
71-
return match.group(1)
70+
return match.group(1).strip()
7271

73-
elif category == "aqua_rat":
74-
# For AQUA-RAT, be more flexible in extraction
75-
response_upper = response.upper()
76-
77-
# Try to find letter choices (A-E)
78-
patterns = [
79-
r'\b([A-E])\b(?!\w)', # Single letter not part of word
80-
r'(?:answer|choice|option)\s*:?\s*([A-E])\b',
81-
r'\(([A-E])\)', # Letter in parentheses
82-
r'^([A-E])$', # Just the letter
83-
]
84-
85-
for pattern in patterns:
86-
match = re.search(pattern, response_upper, re.IGNORECASE | re.MULTILINE)
87-
if match:
88-
return match.group(1)
89-
90-
# If no letter found, check for common wrong patterns
91-
# Map true/false/yes/no/numbers to letters (this is a heuristic)
92-
if re.search(r'\b(true|yes|1)\b', response.lower()):
93-
return "A" # Default mapping
94-
elif re.search(r'\b(false|no|0)\b', response.lower()):
95-
return "B" # Default mapping
96-
97-
elif category == "boolq":
98-
response_lower = response.lower()
99-
if 'yes' in response_lower:
100-
return 'yes'
101-
elif 'no' in response_lower:
102-
return 'no'
103-
104-
elif category == "mmlu_math":
105-
# For MMLU, just return the cleaned response
106-
return response.strip()
107-
108-
# Default: return cleaned response
109-
return response.strip()
72+
# If no pattern found, return the whole response
73+
return response
11074

11175

11276
def run(
@@ -117,20 +81,16 @@ def run(
11781
request_config: Dict[str, Any] = None
11882
) -> Tuple[str, int]:
11983
"""
120-
Simplified majority voting using consistent evaluation logic.
84+
Generic majority voting implementation.
12185
"""
12286
logger.info("Starting majority voting process")
12387

124-
# Detect category
125-
category = detect_category(initial_query)
126-
logger.info(f"Detected category: {category}")
127-
12888
# Extract parameters
12989
k = request_config.get('k', DEFAULT_K) if request_config else DEFAULT_K
13090
temperature = request_config.get('temperature', DEFAULT_TEMPERATURE) if request_config else DEFAULT_TEMPERATURE
13191
max_tokens = request_config.get('max_tokens', 4096) if request_config else 4096
13292

133-
logger.info(f"Generating {k} candidates with temperature={temperature} for category={category}")
93+
logger.info(f"Generating {k} candidates with temperature={temperature}")
13494

13595
# Prepare messages
13696
messages = [
@@ -175,40 +135,36 @@ def run(
175135
if not candidates:
176136
return "Error: Could not generate any candidates", 0
177137

178-
# Extract answers and count votes
138+
# Extract and normalize answers for voting
179139
answer_votes = Counter()
180140
answer_to_responses = {}
181141

182142
for i, candidate in enumerate(candidates):
183-
answer = extract_answer_simple(candidate, category)
184-
if answer:
185-
# Normalize answer for voting
186-
if category == "aqua_rat":
187-
answer = answer.upper() # Ensure letters are uppercase
188-
elif category == "boolq":
189-
answer = answer.lower() # Ensure yes/no are lowercase
190-
elif category == "gsm8k":
191-
# Try to normalize numbers
192-
try:
193-
answer = str(float(answer))
194-
except:
195-
pass
143+
# Try to extract just the answer part
144+
answer = extract_final_answer(candidate)
145+
146+
# Normalize for comparison
147+
normalized = normalize_response(answer)
148+
149+
if normalized:
150+
answer_votes[normalized] += 1
151+
152+
# Keep track of original responses for each normalized answer
153+
if normalized not in answer_to_responses:
154+
answer_to_responses[normalized] = []
155+
answer_to_responses[normalized].append(candidate)
196156

197-
answer_votes[answer] += 1
198-
if answer not in answer_to_responses:
199-
answer_to_responses[answer] = []
200-
answer_to_responses[answer].append(candidate)
201-
logger.debug(f"Candidate {i+1}: extracted '{answer}'")
157+
logger.debug(f"Candidate {i+1}: '{answer}' -> '{normalized}'")
202158
else:
203-
logger.warning(f"Could not extract answer from candidate {i+1}")
159+
logger.warning(f"Could not extract/normalize answer from candidate {i+1}")
204160

205161
# Select the most voted answer
206162
if answer_votes:
207-
most_common_answer, count = answer_votes.most_common(1)[0]
208-
logger.info(f"Most common answer: '{most_common_answer}' with {count}/{k} votes")
163+
most_common_normalized, count = answer_votes.most_common(1)[0]
164+
logger.info(f"Most common answer: '{most_common_normalized}' with {count}/{k} votes")
209165

210-
# Return the first response that gave this answer
211-
winning_responses = answer_to_responses[most_common_answer]
166+
# Return the first original response that mapped to this answer
167+
winning_responses = answer_to_responses[most_common_normalized]
212168
return winning_responses[0], total_tokens
213169
else:
214170
# If no answers could be extracted, return the first candidate

0 commit comments

Comments
 (0)