2 changes: 1 addition & 1 deletion mellea/stdlib/funcs.py
@@ -145,7 +145,7 @@ async def _act(
)

# if there is no reason to sample, just generate from the context.
-if strategy is None or requirements is None or len(requirements) == 0:
+if strategy is None:
if strategy is None and requirements is not None:
FancyLogger.get_logger().warning(
"Calling the function with NO strategy BUT requirements. No requirement is being checked!"
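For readers skimming the hunk: after this change the sampling path is skipped only when no strategy is given, and passing requirements without a strategy merely logs a warning. A minimal sketch of the resulting control flow (the `generate_from_context` helper name is hypothetical; the real body lives in `_act`):

# Sketch only, assuming the surrounding _act() code; not the actual implementation.
if strategy is None:
    if requirements is not None:
        FancyLogger.get_logger().warning(
            "Calling the function with NO strategy BUT requirements. "
            "No requirement is being checked!"
        )
    result = await generate_from_context()  # hypothetical helper
else:
    result = await strategy.sample(action, context, backend, requirements)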
2 changes: 1 addition & 1 deletion mellea/stdlib/sampling/base.py
@@ -86,7 +86,7 @@ async def sample(
action: Component,
context: Context,
backend: Backend,
-requirements: list[Requirement],
+requirements: list[Requirement] | None,
*,
validation_ctx: Context | None = None,
format: type[BaseModelSubclass] | None = None,
2 changes: 1 addition & 1 deletion mellea/stdlib/sampling/best_of_n.py
@@ -22,7 +22,7 @@ async def sample(
action: Component,
context: Context,
backend: Backend,
-requirements: list[Requirement],
+requirements: list[Requirement] | None,
*,
validation_ctx: Context | None = None,
format: type[BaseModelSubclass] | None = None,
255 changes: 255 additions & 0 deletions mellea/stdlib/sampling/majority_voting.py
@@ -0,0 +1,255 @@
import abc
from asyncio import TaskGroup # type: ignore[attr-defined]

import numpy as np
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
from rouge_score.rouge_scorer import RougeScorer # codespell:ignore

from mellea.backends import Backend, BaseModelSubclass
from mellea.stdlib.requirement import Requirement
from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult
from mellea.stdlib.sampling.base import Component, Context


class BaseMBRDSampling(RejectionSamplingStrategy):
    """Base strategy for Minimum Bayes Risk Decoding (MBRD) style majority voting.

    Subclasses supply the pairwise similarity metric via `compare_strings`.
    """

number_of_samples: int
weighted: bool
symmetric: bool

def __init__(
self,
*,
number_of_samples: int = 8,
weighted: bool = False,
loop_budget: int = 1,
requirements: list[Requirement] | None = None,
):
"""Initialize a new instance of the class with default parameters.

Args:
            number_of_samples: Number of samples to generate and use for majority voting.
            weighted: Whether to weight the votes (not yet implemented; votes are currently uniform).
            loop_budget: Number of times the inner rejection-sampling loop may iterate. Must be greater than 0.
            requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.

Raises:
AssertionError: If loop_budget is not greater than 0.
"""
super().__init__(loop_budget=loop_budget, requirements=requirements)
self.number_of_samples = number_of_samples
self.weighted = weighted
self.symmetric = True

@abc.abstractmethod
def compare_strings(self, ref: str, pred: str) -> float:
"""This method is the abstract method for MBRD similarity metric."""

    def maybe_apply_weighted(self, scr):
        # TODO: real weighting is not implemented yet; the uniform weights
        # below leave the vote scores unchanged.
        if self.weighted:
            weights = np.ones(len(scr))
            scr = scr * weights

        return scr

async def sample(
self,
action: Component,
context: Context,
backend: Backend,
requirements: list[Requirement] | None,
*,
validation_ctx: Context | None = None,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
tool_calls: bool = False,
show_progress: bool = True,
) -> SamplingResult:
# execute sampling concurrently
tasks = []
async with TaskGroup() as tg:
for i in range(self.number_of_samples):
task = tg.create_task(
super().sample(
action,
context,
backend,
requirements,
validation_ctx=validation_ctx,
model_options=model_options,
tool_calls=tool_calls,
show_progress=show_progress,
)
)
tasks.append(task)

# collect results
results = []
for task in tasks:
result = task.result()
if result.success:
output = str(result.result)
else:
# avoid type checker error
assert isinstance(result.sample_generations, list)
output = str(result.sample_generations[0].value)

results.append((output, result))

assert len(results) > 0

        scr = np.zeros((len(results), len(results)))
        for i in range(len(results)):
            for j in range(len(results)):
                if j == i:
                    scr[i][j] = 1.0  # self voting is 1.
                elif j > i:
                    # upper triangle: vote of sample i against reference j
                    scr[i][j] = float(
                        self.compare_strings(results[j][0], results[i][0])
                    )
                elif self.symmetric:
                    # lower triangle mirrors the upper triangle
                    scr[i][j] = scr[j][i]
                else:
                    scr[i][j] = float(
                        self.compare_strings(results[j][0], results[i][0])
                    )

# count votes
scr = scr.sum(axis=0)

# Apply weights
scr = self.maybe_apply_weighted(scr)

maxR = int(scr.argmax())

        return results[maxR][1]  # return the sampling result that won the majority vote
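For intuition, the selection above computes votes_j = sum_i sim(sample_i, sample_j) and returns the sample with the most votes. A minimal self-contained sketch with a toy exact-match similarity (the strings and the `mbr_select` helper are illustrative, not part of this PR):

import numpy as np

def mbr_select(outputs: list[str]) -> str:
    # Toy similarity: 1.0 on exact match, else 0.0.
    def sim(ref: str, pred: str) -> float:
        return 1.0 if ref == pred else 0.0

    n = len(outputs)
    scr = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            scr[i][j] = 1.0 if i == j else sim(outputs[j], outputs[i])

    votes = scr.sum(axis=0)  # votes[j] = total similarity mass for sample j
    return outputs[int(votes.argmax())]

print(mbr_select(["2", "3", "2", "2"]))  # -> "2", the majority answer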


class MajorityVotingStrategyForMath(BaseMBRDSampling):
    """Majority voting over math answers, with equivalence checked by Math-Verify."""

number_of_samples: int
match_types: list[str]
float_rounding: int
strict: bool
allow_set_relation_comp: bool

def __init__(
self,
*,
number_of_samples: int = 8,
float_rounding: int = 6,
strict: bool = True,
allow_set_relation_comp: bool = False,
weighted: bool = False,
loop_budget: int = 1,
requirements: list[Requirement] | None = None,
):
"""Initialize a new instance of the class with default parameters.

Args:
number_of_samples: Number of samples to generate and use for majority voting
float_rounding: Number of decimal places to round floats to. Defaults to 6.
strict: Whether to enforce strict comparison mode. Defaults to True.
- In strict mode: Variables matter and sets are not comparable with tuples
- In non-strict mode: Variables are matched by position and sets can be compared with tuples
            allow_set_relation_comp: Whether to allow set-relation comparison (e.g., 1 < x < 2 vs (1, 2)). Defaults to False.
                - If True, set-relation comparison is allowed in all cases.
                - If False, it is allowed only if the prediction is a set.
            weighted: Whether to weight the votes (not yet implemented; votes are currently uniform).
            loop_budget: Number of times the inner rejection-sampling loop may iterate. Must be greater than 0.
requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.

Raises:
AssertionError: If loop_budget is not greater than 0.
"""
super().__init__(
number_of_samples=number_of_samples,
weighted=weighted,
loop_budget=loop_budget,
requirements=requirements,
)
        # match_types: extraction targets for Math-Verify.
        # - For math use "latex" or "expr" or both.
        # - For general text similarity use MBRDRougeLStrategy instead.
        self.match_types = ["latex", "expr"]
self.float_rounding = float_rounding
self.strict = strict
self.allow_set_relation_comp = allow_set_relation_comp

        # Note: Math-Verify equality is not symmetric for certain expressions, see: https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/README.md?plain=1#L183
        # We nevertheless treat the vote matrix as symmetric as an approximation.
        self.symmetric = True

# https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36
    def compare_strings(self, ref: str, pred: str) -> float:
        """Compare two strings using Math-Verify parsing and verification."""
# Convert string match_types to ExtractionTarget objects
extraction_targets = []
for match_type in self.match_types:
if match_type == "latex":
extraction_targets.append(LatexExtractionConfig(boxed_match_priority=0))
elif match_type == "expr":
extraction_targets.append(ExprExtractionConfig())

# NOTE: Math-Verify parse and verify functions don't support threaded environment due to usage of signal.alarm() in timeout mechanism. If you need to run in multithreaded environment it's recommended to set the parsing_timeout=None
gold_parsed = parse(ref, extraction_targets, parsing_timeout=None)
pred_parsed = parse(pred, extraction_targets, parsing_timeout=None)
return verify(
gold_parsed,
pred_parsed,
float_rounding=self.float_rounding,
strict=self.strict,
allow_set_relation_comp=self.allow_set_relation_comp,
timeout_seconds=None,
)
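As a rough usage illustration of the comparison above (a minimal sketch; the answer strings are made up, and the call pattern mirrors the code in this file):

from math_verify import LatexExtractionConfig, parse, verify

targets = [LatexExtractionConfig(boxed_match_priority=0)]
# parsing_timeout=None because signal-based timeouts do not work off the main thread.
gold = parse(r"The answer is \boxed{2}", targets, parsing_timeout=None)
pred = parse(r"So we get \boxed{2.0}", targets, parsing_timeout=None)

# Numerically equal answers should verify as equivalent.
print(verify(gold, pred, float_rounding=6, timeout_seconds=None))  # expected: True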


class MBRDRougeLStrategy(BaseMBRDSampling):
    """MBRD majority voting with the ROUGE-L F-measure as the similarity metric."""

match_types: list[str]
scorer: RougeScorer

def __init__(
self,
*,
number_of_samples: int = 8,
weighted: bool = False,
loop_budget: int = 1,
requirements: list[Requirement] | None = None,
):
"""Initialize a new instance of the class with default parameters.

Args:
number_of_samples: Number of samples to generate and use for majority voting
            weighted: Whether to weight the votes (not yet implemented; votes are currently uniform).
            loop_budget: Number of times the inner rejection-sampling loop may iterate. Must be greater than 0.
requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.

Raises:
AssertionError: If loop_budget is not greater than 0.
"""
super().__init__(
number_of_samples=number_of_samples,
weighted=weighted,
loop_budget=loop_budget,
requirements=requirements,
)
self.match_types = ["rougeL"]
self.symmetric = True
self.scorer = RougeScorer(self.match_types, use_stemmer=True)

    def compare_strings(self, ref: str, pred: str) -> float:
        """Compare two strings using the ROUGE-L F-measure."""

scr = self.scorer.score(ref, pred)[self.match_types[-1]].fmeasure
return scr
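For reference, a minimal sketch of the underlying scorer call (the example strings are made up):

from rouge_score.rouge_scorer import RougeScorer  # codespell:ignore

scorer = RougeScorer(["rougeL"], use_stemmer=True)
result = scorer.score("the cat sat on the mat", "a cat sat on a mat")
print(result["rougeL"].fmeasure)  # LCS-based F-measure in [0, 1]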
2 changes: 1 addition & 1 deletion mellea/stdlib/sampling/types.py
@@ -81,7 +81,7 @@ async def sample(
action: Component,
context: Context,
backend: Backend,
-requirements: list[Requirement],
+requirements: list[Requirement] | None,
*,
validation_ctx: Context | None = None,
format: type[BaseModelSubclass] | None = None,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -160,7 +160,7 @@ combine-as-imports = true
split-on-trailing-comma = false

[tool.codespell]
-ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,mot'
+ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,mot,rouge,Rouge'
check-filenames = true
check-hidden = false
regex = "(?<![a-z])[a-z'`]+|[A-Z][a-z'`]*|[a-z]+'[a-z]*|[a-z]+(?=[_-])|[a-z]+(?=[A-Z])|\\d+"
63 changes: 63 additions & 0 deletions test/stdlib_basics/test_majority_voting.py
@@ -0,0 +1,63 @@
import mellea
from mellea.stdlib.base import SimpleContext
from mellea.stdlib.requirement import check, req, simple_validate
from mellea.stdlib.sampling.majority_voting import (
MBRDRougeLStrategy,
MajorityVotingStrategyForMath
)
import pytest


class TestMajorityVoting:

def test_majority_voting_for_math(self):
m = mellea.start_session(ctx=SimpleContext())
query = "Compute 1+1"
prompt_suffix = "\nPlease reason step by step, use \n\n to end each step, and put your final answer within \\boxed{}."
prompt = query + prompt_suffix

result = m.instruct(
prompt,
strategy=MajorityVotingStrategyForMath(number_of_samples=8, loop_budget=1),
return_sampling_results=True,
)
if result.success:
output = str(result.result)
else:
output = result.sample_generations[0].value

print(output)
assert output

def test_MBRDRougeL(self):
m = mellea.start_session(ctx=SimpleContext())
requirements = [
req("The email should have a salutation"), # == r1
req(
"Use only lower-case letters",
validation_fn=simple_validate(lambda x: x.lower() == x),
), # == r2
check("Do not mention purple elephants."), # == r3
]

name = "Olivia"
notes = "Olivia helped the lab over the last few weeks by organizing intern events, advertising the speaker series, and handling issues with snack delivery."
email_candidate = m.instruct(
"Write an email to {{name}} using the notes following: {{notes}}.",
requirements=requirements,
strategy=MBRDRougeLStrategy(number_of_samples=8, loop_budget=1),
user_variables={"name": name, "notes": notes},
return_sampling_results=True,
)

if email_candidate.success:
output = str(email_candidate.result)
else:
output = email_candidate.sample_generations[0].value

print(output)
assert output


if __name__ == "__main__":
pytest.main(["-s", __file__])