Skip to content

Commit 36eaca4

Browse files
yelkurdiYousef El-KurdiYousef El-Kurdijakelorocco
authored
feat: majority voting sampling strategy (#142)
* initial commit for majority voting * adds .gitignore to test dir * removed naive exact-match and replaced it with Math-Verify * adjusted string comparison ref vs pred relation - fixed some type checking * adds MBRD rougeL similarity - refactors MBRD base class * fixes tests removes dependency on vLLM and uses default ollama * minor fix to pyproject.toml * allow self-voting, and default symmetry, removed redundant variables in devirved classes * Added special ficture for tests to run on github action runners * fix: minor tweaks to get automation to pass, remove redundant mellea checks, and improve readability --------- Co-authored-by: Yousef El-Kurdi <[email protected]> Co-authored-by: Yousef El-Kurdi <[email protected]> Co-authored-by: jakelorocco <[email protected]> Co-authored-by: jakelorocco <[email protected]>
1 parent 16ca219 commit 36eaca4

File tree

4 files changed

+447
-2
lines changed

4 files changed

+447
-2
lines changed
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
"""Sampling Strategies for Minimum Bayes Risk Decoding (MBRD)."""
2+
3+
import abc
4+
import asyncio
5+
6+
import numpy as np
7+
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
8+
from rouge_score.rouge_scorer import RougeScorer # codespell:ignore
9+
10+
from mellea.backends import Backend, BaseModelSubclass
11+
from mellea.stdlib.requirement import Requirement
12+
from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult
13+
from mellea.stdlib.sampling.base import Component, Context
14+
15+
16+
class BaseMBRDSampling(RejectionSamplingStrategy):
17+
"""Abstract Minimum Bayes Risk Decoding (MBRD) Sampling Strategy."""
18+
19+
number_of_samples: int
20+
weighted: bool
21+
symmetric: bool
22+
23+
def __init__(
24+
self,
25+
*,
26+
number_of_samples: int = 8,
27+
weighted: bool = False,
28+
loop_budget: int = 1,
29+
requirements: list[Requirement] | None = None,
30+
):
31+
"""Initialize a new abstract Minimum Bayes Risk Decoding (MBRD) Sampling Strategy with default parameters.
32+
33+
Inherits from RejectionSamplingStrategy. Will generate up to loop_budget x number_of_samples requests. If no
34+
requirements are provided here or in sample(...), will only generate number_of_samples requests.
35+
36+
Classes that inherit from this must implement the `compare_strings` function.
37+
38+
Args:
39+
number_of_samples: Number of samples to generate and use for majority voting
40+
weighted: Not Implemented. If True, weights the score before getting the final majority vote
41+
loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0.
42+
requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.
43+
44+
Raises:
45+
AssertionError: If loop_budget is not greater than 0.
46+
"""
47+
super().__init__(loop_budget=loop_budget, requirements=requirements)
48+
self.number_of_samples = number_of_samples
49+
self.weighted = weighted
50+
self.symmetric = True
51+
52+
@abc.abstractmethod
53+
def compare_strings(self, ref: str, pred: str) -> float:
54+
"""This method is the abstract method for MBRD similarity metric."""
55+
56+
def maybe_apply_weighted(self, scr: np.ndarray):
57+
"""Applies weights if self.weighted is True. Not Implemented."""
58+
# TODO: not implemented yet
59+
if self.weighted:
60+
weights = np.asarray([1.0 for _ in range(len(scr))])
61+
scr = scr * weights
62+
63+
return scr
64+
65+
async def sample(
66+
self,
67+
action: Component,
68+
context: Context,
69+
backend: Backend,
70+
requirements: list[Requirement] | None,
71+
*,
72+
validation_ctx: Context | None = None,
73+
format: type[BaseModelSubclass] | None = None,
74+
model_options: dict | None = None,
75+
tool_calls: bool = False,
76+
show_progress: bool = True,
77+
) -> SamplingResult:
78+
"""Samples using majority voting.
79+
80+
Args:
81+
action : The action object to be sampled.
82+
context: The context to be passed to the sampling strategy.
83+
backend: The backend used for generating samples.
84+
requirements: List of requirements to test against (merged with global requirements).
85+
validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx.
86+
format: output format for structured outputs; ignored for this sampling strategy.
87+
model_options: model options to pass to the backend during generation / validation.
88+
tool_calls: True if tool calls should be used during this sampling strategy.
89+
show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog.
90+
91+
Returns:
92+
SamplingResult: A result object indicating the success or failure of the sampling process.
93+
"""
94+
# execute sampling concurrently
95+
tasks: list[asyncio.Task[SamplingResult]] = []
96+
for i in range(self.number_of_samples):
97+
task = asyncio.create_task(
98+
super().sample(
99+
action,
100+
context,
101+
backend,
102+
requirements,
103+
validation_ctx=validation_ctx,
104+
model_options=model_options,
105+
tool_calls=tool_calls,
106+
show_progress=show_progress,
107+
)
108+
)
109+
tasks.append(task)
110+
111+
sampling_results = await asyncio.gather(*tasks)
112+
113+
# collect results
114+
results: list[tuple[str, SamplingResult]] = []
115+
for result in sampling_results:
116+
output = str(result.result)
117+
results.append((output, result))
118+
assert len(results) > 0
119+
120+
# Create an array of len(results) x len(results) initialized to 0.0.
121+
scr = np.asarray(
122+
[[0.0 for _ in range(len(results))] for _ in range(len(results))]
123+
)
124+
for i in range(len(results)):
125+
for j in range(len(results)):
126+
if j == i:
127+
scr[i][j] = 1.0 # self voting is 1.
128+
continue
129+
130+
# upper triangle
131+
# For sample i compute votes against all j references
132+
if j > i:
133+
scr[i][j] = float(
134+
self.compare_strings(results[j][0], results[i][0])
135+
)
136+
continue
137+
138+
else:
139+
if self.symmetric:
140+
scr[i][j] = scr[j][i]
141+
else:
142+
scr[i][j] = float(
143+
self.compare_strings(results[j][0], results[i][0])
144+
)
145+
continue
146+
147+
# count votes
148+
summed_scr: np.ndarray = scr.sum(axis=0)
149+
150+
# Apply weights
151+
weighed_scr = self.maybe_apply_weighted(summed_scr)
152+
153+
maxR = int(weighed_scr.argmax())
154+
155+
return results[maxR][1] # return one of the MV answers
156+
157+
158+
class MajorityVotingStrategyForMath(BaseMBRDSampling):
159+
"""MajorityVoting Sampling Strategy for Math Expressions."""
160+
161+
number_of_samples: int
162+
match_types: list[str]
163+
float_rounding: int
164+
strict: bool
165+
allow_set_relation_comp: bool
166+
167+
def __init__(
168+
self,
169+
*,
170+
number_of_samples: int = 8,
171+
float_rounding: int = 6,
172+
strict: bool = True,
173+
allow_set_relation_comp: bool = False,
174+
weighted: bool = False,
175+
loop_budget: int = 1,
176+
requirements: list[Requirement] | None = None,
177+
):
178+
"""Initialize a new instance of MajorityVoting Sampling Strategy for Math with default parameters.
179+
180+
Will generate up to loop_budget x number_of_samples requests. If no
181+
requirements are provided here or in sample(...), will only generate number_of_samples requests.
182+
183+
Args:
184+
number_of_samples: Number of samples to generate and use for majority voting
185+
float_rounding: Number of decimal places to round floats to. Defaults to 6.
186+
strict: Whether to enforce strict comparison mode. Defaults to True.
187+
- In strict mode: Variables matter and sets are not comparable with tuples
188+
- In non-strict mode: Variables are matched by position and sets can be compared with tuples
189+
allow_set_relation_comp: Whether to allow set - relation (e.g 1 < x < 2 and (1, 2)) comparison. Defaults to False.
190+
- If True, set - relation comparison will be allowed in all cases.
191+
- If False, set - relation comparison will be allowed only if the prediction is a set.
192+
weighted: Not Implemented. If True, weights the score before getting the final majority vote
193+
loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0.
194+
requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.
195+
196+
Raises:
197+
AssertionError: If loop_budget is not greater than 0.
198+
"""
199+
super().__init__(
200+
number_of_samples=number_of_samples,
201+
weighted=weighted,
202+
loop_budget=loop_budget,
203+
requirements=requirements,
204+
)
205+
self.number_of_samples = number_of_samples
206+
# match_type: type of match latex, expr (match only so far)
207+
# - For math use "latex" or "expr" or both
208+
# - For general text similarity use "rougel"
209+
MATCH_TYPES = ["latex", "axpr"]
210+
self.match_types = MATCH_TYPES
211+
self.float_rounding = float_rounding
212+
self.strict = strict
213+
self.allow_set_relation_comp = allow_set_relation_comp
214+
215+
# Note: symmetry is not implied for certain expressions, see: https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/README.md?plain=1#L183
216+
self.symmetric = True
217+
218+
# https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36
219+
def compare_strings(self, ref: str, pred: str) -> float:
220+
"""Helper function to compare strings using the math extraction metrics."""
221+
# Convert string match_types to ExtractionTarget objects
222+
extraction_targets = []
223+
for match_type in self.match_types:
224+
if match_type == "latex":
225+
extraction_targets.append(LatexExtractionConfig(boxed_match_priority=0))
226+
elif match_type == "expr":
227+
extraction_targets.append(ExprExtractionConfig())
228+
229+
# NOTE: Math-Verify parse and verify functions don't support threaded environment due to usage of signal.alarm() in timeout mechanism. If you need to run in multithreaded environment it's recommended to set the parsing_timeout=None
230+
gold_parsed = parse(ref, extraction_targets, parsing_timeout=None) # type: ignore
231+
pred_parsed = parse(pred, extraction_targets, parsing_timeout=None) # type: ignore
232+
return float(
233+
verify(
234+
gold_parsed,
235+
pred_parsed,
236+
float_rounding=self.float_rounding,
237+
strict=self.strict,
238+
allow_set_relation_comp=self.allow_set_relation_comp,
239+
timeout_seconds=None,
240+
)
241+
)
242+
243+
244+
class MBRDRougeLStrategy(BaseMBRDSampling):
245+
"""Sampling Strategy that uses RougeL to compute symbol-level distances for majority voting."""
246+
247+
match_types: list[str]
248+
scorer: RougeScorer
249+
250+
def __init__(
251+
self,
252+
*,
253+
number_of_samples: int = 8,
254+
weighted: bool = False,
255+
loop_budget: int = 1,
256+
requirements: list[Requirement] | None = None,
257+
):
258+
"""Initialize a new instance of MBRDRougeL Sampling Strategy with default parameters.
259+
260+
Will generate up to loop_budget x number_of_samples requests. If no
261+
requirements are provided here or in sample(...), will only generate number_of_samples requests.
262+
263+
Args:
264+
number_of_samples: Number of samples to generate and use for majority voting
265+
weighted: Not Implemented. If True, weights the score before getting the final majority vote
266+
loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0.
267+
requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.
268+
269+
Raises:
270+
AssertionError: If loop_budget is not greater than 0.
271+
"""
272+
super().__init__(
273+
number_of_samples=number_of_samples,
274+
weighted=weighted,
275+
loop_budget=loop_budget,
276+
requirements=requirements,
277+
)
278+
self.match_types = ["rougeL"]
279+
self.symmetric = True
280+
self.scorer = RougeScorer(self.match_types, use_stemmer=True)
281+
282+
# https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36
283+
def compare_strings(self, ref: str, pred: str) -> float:
284+
"""Helper function to compare strings using the math extraction metrics."""
285+
scr: float = self.scorer.score(ref, pred)[self.match_types[-1]].fmeasure
286+
return scr

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ dependencies = [
4242
"mistletoe>=1.4.0",
4343
"huggingface-hub>=0.33.4",
4444
"pillow",
45+
"math_verify", # Needed for Majority Voting Sampling Strategies.
46+
"rouge_score" # Needed for Majority Voting Sampling Strategies.
4547
]
4648

4749
[project.scripts]
@@ -160,7 +162,7 @@ combine-as-imports = true
160162
split-on-trailing-comma = false
161163

162164
[tool.codespell]
163-
ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,mot'
165+
ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,mot,rouge,Rouge'
164166
check-filenames = true
165167
check-hidden = false
166168
regex = "(?<![a-z])[a-z'`]+|[A-Z][a-z'`]*|[a-z]+'[a-z]*|[a-z]+(?=[_-])|[a-z]+(?=[A-Z])|\\d+"
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from mellea.backends import ModelOption
2+
from mellea import start_session, MelleaSession
3+
from mellea.stdlib.requirement import check, req, simple_validate
4+
from mellea.stdlib.sampling.majority_voting import (
5+
MBRDRougeLStrategy,
6+
MajorityVotingStrategyForMath
7+
)
8+
import pytest
9+
10+
from mellea.stdlib.sampling.types import SamplingResult
11+
12+
13+
@pytest.fixture(scope="module")
14+
def m_session(gh_run):
15+
if gh_run == 1:
16+
m = start_session(
17+
"ollama",
18+
model_id="llama3.2:1b",
19+
model_options={ModelOption.MAX_NEW_TOKENS: 5},
20+
)
21+
else:
22+
m = start_session(
23+
"ollama",
24+
model_id="llama3.2:1b",
25+
)
26+
yield m
27+
del m
28+
29+
30+
def test_majority_voting_for_math(m_session: MelleaSession):
31+
query = "Compute 1+1"
32+
prompt_suffix = "\nPlease reason step by step, use \n\n to end each step, and put your final answer within \\boxed{}."
33+
prompt = query + prompt_suffix
34+
35+
result = m_session.instruct(
36+
prompt,
37+
strategy=MajorityVotingStrategyForMath(number_of_samples=8, loop_budget=1),
38+
return_sampling_results=True,
39+
)
40+
output = str(result.result)
41+
42+
print(output)
43+
assert output
44+
45+
46+
def test_MBRDRougeL(m_session: MelleaSession):
47+
requirements = [
48+
req("The email should have a salutation"), # == r1
49+
req(
50+
"Use only lower-case letters",
51+
validation_fn=simple_validate(lambda x: x.lower() == x),
52+
), # == r2
53+
check("Do not mention purple elephants."), # == r3
54+
]
55+
56+
name = "Olivia"
57+
notes = "Olivia helped the lab over the last few weeks by organizing intern events, advertising the speaker series, and handling issues with snack delivery."
58+
email_candidate: SamplingResult = m_session.instruct(
59+
"Write an email to {{name}} using the notes following: {{notes}}.",
60+
requirements=requirements, # type: ignore
61+
strategy=MBRDRougeLStrategy(number_of_samples=8, loop_budget=1),
62+
user_variables={"name": name, "notes": notes},
63+
return_sampling_results=True,
64+
)
65+
66+
output = str(email_candidate.result)
67+
68+
print(output)
69+
assert output
70+
71+
72+
if __name__ == "__main__":
73+
pytest.main(["-s", __file__])

0 commit comments

Comments
 (0)