284 changes: 284 additions & 0 deletions mellea/stdlib/sampling.py
@@ -1,11 +1,17 @@
"""sampling methods go here."""

import abc
import re
from asyncio import TaskGroup # type: ignore[attr-defined]
from collections import Counter
from collections.abc import Callable, Coroutine
from copy import deepcopy
from typing import Any

import numpy as np
import tqdm
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
from rouge_score.rouge_scorer import RougeScorer # codespell:ignore

from mellea import LinearContext
from mellea.helpers.fancy_logger import FancyLogger
@@ -622,3 +628,281 @@ def repair(
repair_string=f"The following requirements failed before:\n{last_failed_reqs_str}"
)
return past_actions[-1]


class BaseMBRDSampling(RejectionSamplingStrategy):
"""Base strategy for Minimum Bayes Risk Decoding (MBRD) sampling: generate several candidates and return the one most similar to the others."""

number_of_samples: int
weighted: bool
symmetric: bool

def __init__(
self,
*,
number_of_samples: int = 8,
weighted: bool = False,
loop_budget: int = 1,
validate: Callable[
[list[Requirement], Context, Any, Any],
Coroutine[Any, Any, list[ValidationResult]],
]
| None = None,
generate: (Callable[[Component, Context], ModelOutputThunk] | None) = None,
requirements: list[Requirement] | None = None,
):
"""Initialize a new instance of the class with default parameters.

Args:
number_of_samples: Number of samples to generate and use for majority voting.
weighted: Whether to apply per-sample weights when counting votes (not implemented yet; currently a no-op).
loop_budget: Number of times to iterate through the inner rejection-sampling process. Must be greater than 0.
validate: Function to validate the results against requirements. If None, validation is provided later through setter.
generate: Function to generate new model output thunks. If None, generate is provided later through setter.
requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.

Raises:
AssertionError: If loop_budget is not greater than 0.
"""
super().__init__(
loop_budget=loop_budget,
validate=validate,
generate=generate,
requirements=requirements,
)
self.number_of_samples = number_of_samples
self.weighted = weighted
self.symmetric = False

@abc.abstractmethod
def compare_strings(self, ref: str, pred: str) -> float:
"""This method is the abstract method for MBRD similarity metric."""

def maybe_apply_weighted(self, scr: np.ndarray) -> np.ndarray:
"""Optionally apply per-sample weights to the vote scores."""
# TODO: weighting is not implemented yet; all weights are currently 1.0, so this is a no-op.
if self.weighted:
weights = np.asarray([1.0 for _ in range(len(scr))])
scr = scr * weights

return scr

async def sample(
self,
action: Component,
context: Context,
requirements: list[Requirement],
*,
show_progress: bool = True,
validation_ctx: Context | None = None,
) -> SamplingResult:
# execute sampling concurrently
tasks = []
async with TaskGroup() as tg:
for i in range(self.number_of_samples):
task = tg.create_task(
super().sample(
action,
context,
requirements,
show_progress=show_progress,
validation_ctx=validation_ctx,
)
)
tasks.append(task)

# collect results
results = []
for task in tasks:
result = task.result()
if result.success:
output = str(result.result)
else:
# avoid type checker error
assert isinstance(result.sample_generations, list)
output = str(result.sample_generations[0].value)

results.append((output, result))

assert len(results) > 0

scr = np.zeros((len(results), len(results)))
for i in range(len(results)):
for j in range(len(results)):
if j == i:
scr[i][j] = 0.0 # self voting is 0.
continue

# upper triangle
# For sample i compute votes against all j references
if j > i:
scr[i][j] = float(
self.compare_strings(results[j][0], results[i][0])
)
continue

else:
if self.symmetric:
scr[i][j] = scr[j][i]
else:
scr[i][j] = float(
self.compare_strings(results[j][0], results[i][0])
)
continue

# count votes
scr = scr.sum(axis=0)

# Apply weights
scr = self.maybe_apply_weighted(scr)

maxR = int(scr.argmax())

return results[maxR][1] # return one of the MV answers
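For intuition, the selection above reduces to building a pairwise similarity matrix over the candidates and returning the one the others agree with most. A minimal standalone sketch of that idea (toy helper names, not part of this PR; with exact-match similarity it degenerates to plain plurality voting):

import numpy as np

def select_consensus(candidates: list[str], similarity) -> str:
    """Return the candidate with the highest total similarity to all other candidates."""
    n = len(candidates)
    scores = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i != j:
                # similarity of prediction i measured against reference j
                scores[i][j] = similarity(candidates[j], candidates[i])
    votes = scores.sum(axis=0)  # votes[j]: how strongly the other candidates agree with candidate j
    return candidates[int(votes.argmax())]

# With exact-match similarity this is ordinary majority voting: picks "2".
print(select_consensus(["2", "3", "2", "2"], lambda ref, pred: float(ref == pred)))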


class MajorityVotingStrategyForMath(BaseMBRDSampling):
number_of_samples: int
match_types: list[str]
float_rounding: int
strict: bool
allow_set_relation_comp: bool
weighted: bool
symmetric: bool

def __init__(
self,
*,
number_of_samples: int = 8,
float_rounding: int = 6,
strict: bool = True,
allow_set_relation_comp: bool = False,
weighted: bool = False,
loop_budget: int = 1,
validate: Callable[
[list[Requirement], Context, Any, Any],
Coroutine[Any, Any, list[ValidationResult]],
]
| None = None,
generate: (Callable[[Component, Context], ModelOutputThunk] | None) = None,
requirements: list[Requirement] | None = None,
):
"""Initialize a new instance of the class with default parameters.

Args:
number_of_samples: Number of samples to generate and use for majority voting.
float_rounding: Number of decimal places to round floats to. Defaults to 6.
strict: Whether to enforce strict comparison mode. Defaults to True.
- In strict mode: variables matter and sets are not comparable with tuples.
- In non-strict mode: variables are matched by position and sets can be compared with tuples.
allow_set_relation_comp: Whether to allow set-relation comparison (e.g. 1 < x < 2 vs. (1, 2)). Defaults to False.
- If True, set-relation comparison is allowed in all cases.
- If False, set-relation comparison is allowed only if the prediction is a set.
weighted: Whether to apply per-sample weights when counting votes (not implemented yet).
loop_budget: Number of times to iterate through the inner rejection-sampling process. Must be greater than 0.
validate: Function to validate the results against requirements. If None, validation is provided later through setter.
generate: Function to generate new model output thunks. If None, generate is provided later through setter.
requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.

Raises:
AssertionError: If loop_budget is not greater than 0.
"""
super().__init__(
number_of_samples=number_of_samples,
weighted=weighted,
loop_budget=loop_budget,
validate=validate,
generate=generate,
requirements=requirements,
)
self.number_of_samples = number_of_samples
# match_types: the extraction/match modes used by math_verify.
# - For math, use "latex" or "expr" (or both).
# - For general text similarity, use MBRDRougeLStrategy instead.
self.match_types = ["latex", "expr"]
self.float_rounding = float_rounding
self.strict = strict
self.allow_set_relation_comp = allow_set_relation_comp
self.weighted = weighted

# Note: symmetry is not implied for certain expressions, see: https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/README.md?plain=1#L183
self.symmetric = False

# https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36
def compare_strings(self, ref: str, pred: str) -> float:
"""Helper function to compare two strings using the math_verify extraction metrics."""
# Convert string match_types to ExtractionTarget objects
extraction_targets = []
for match_type in self.match_types:
if match_type == "latex":
extraction_targets.append(LatexExtractionConfig(boxed_match_priority=0))
elif match_type == "expr":
extraction_targets.append(ExprExtractionConfig())

# NOTE: Math-Verify's parse and verify functions don't support threaded environments because their timeout mechanism uses signal.alarm(). When running in a multithreaded environment, it's recommended to set parsing_timeout=None.
gold_parsed = parse(ref, extraction_targets, parsing_timeout=None)
pred_parsed = parse(pred, extraction_targets, parsing_timeout=None)
return verify(
gold_parsed,
pred_parsed,
float_rounding=self.float_rounding,
strict=self.strict,
allow_set_relation_comp=self.allow_set_relation_comp,
timeout_seconds=None,
)
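Roughly, the comparison boils down to parsing both strings into math objects and checking equivalence. A small sketch mirroring the calls above (assumes the math_verify package is installed; exact parse results may vary by version):

from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify

targets = [LatexExtractionConfig(boxed_match_priority=0), ExprExtractionConfig()]
# parsing_timeout=None because the signal-based timeout is not thread-safe (see the NOTE above).
gold = parse("\\boxed{2}", targets, parsing_timeout=None)
pred = parse("The answer is 2", targets, parsing_timeout=None)
print(verify(gold, pred, timeout_seconds=None))  # expected: True, since both should extract the value 2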


class MBRDRougeLStrategy(BaseMBRDSampling):
number_of_samples: int
match_types: list[str]
weighted: bool
symmetric: bool
scorer: RougeScorer

def __init__(
self,
*,
number_of_samples: int = 8,
weighted: bool = False,
loop_budget: int = 1,
validate: Callable[
[list[Requirement], Context, Any, Any],
Coroutine[Any, Any, list[ValidationResult]],
]
| None = None,
generate: (Callable[[Component, Context], ModelOutputThunk] | None) = None,
requirements: list[Requirement] | None = None,
):
"""Initialize a new instance of the class with default parameters.

Args:
number_of_samples: Number of samples to generate and use for majority voting.
weighted: Whether to apply per-sample weights when counting votes (not implemented yet).
loop_budget: Number of times to iterate through the inner rejection-sampling process. Must be greater than 0.
validate: Function to validate the results against requirements. If None, validation is provided later through setter.
generate: Function to generate new model output thunks. If None, generate is provided later through setter.
requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.

Raises:
AssertionError: If loop_budget is not greater than 0.
"""
super().__init__(
number_of_samples=number_of_samples,
weighted=weighted,
loop_budget=loop_budget,
validate=validate,
generate=generate,
requirements=requirements,
)
self.number_of_samples = number_of_samples
self.match_types = ["rougeL"]
self.weighted = weighted
self.symmetric = True
self.scorer = RougeScorer(self.match_types, use_stemmer=True)

def compare_strings(self, ref: str, pred: str) -> float:
"""Helper function to compare two strings using the ROUGE-L F-measure."""

scr = self.scorer.score(ref, pred)[self.match_types[-1]].fmeasure
return scr
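A quick sketch of the ROUGE-L comparison used by this strategy (assumes the rouge_score package is installed; the example strings are made up):

from rouge_score.rouge_scorer import RougeScorer

scorer = RougeScorer(["rougeL"], use_stemmer=True)
ref = "hi olivia, thanks for organizing the intern events"
pred = "hi olivia, thank you for organizing intern events"
# ROUGE-L F-measure is symmetric in its two arguments, which is why MBRDRougeLStrategy sets symmetric=True.
print(scorer.score(ref, pred)["rougeL"].fmeasure)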
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -160,7 +160,7 @@ combine-as-imports = true
split-on-trailing-comma = false

[tool.codespell]
ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,mot'
ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,rouge,Rouge'
check-filenames = true
check-hidden = false
regex = "(?<![a-z])[a-z'`]+|[A-Z][a-z'`]*|[a-z]+'[a-z]*|[a-z]+(?=[_-])|[a-z]+(?=[A-Z])|\\d+"
69 changes: 69 additions & 0 deletions test/stdlib_basics/test_majority_voting.py
@@ -0,0 +1,69 @@
import pytest

import mellea
from mellea.stdlib.base import SimpleContext
from mellea.stdlib.requirement import check, req, simple_validate
from mellea.stdlib.sampling import MBRDRougeLStrategy, MajorityVotingStrategyForMath


class TestMajorityVoting:
m = mellea.start_session(ctx=SimpleContext())

def test_majority_voting_for_math(self):

query = "Compute 1+1"
prompt_suffix = "\nPlease reason step by step, use \n\n to end each step, and put your final answer within \\boxed{}."
prompt = query + prompt_suffix

result = self.m.instruct(
prompt,
# requirements=requirements,
strategy=MajorityVotingStrategyForMath(number_of_samples=8, loop_budget=1),
# user_variables={"name": name, "notes": notes},
return_sampling_results=True,
)
if result.success:
output = str(result.result)
else:
output = result.sample_generations[0].value

print(output)
assert output


def test_MBRDRougeL(self):

requirements = [
req("The email should have a salutation"), # == r1
req(
"Use only lower-case letters",
validation_fn=simple_validate(lambda x: x.lower() == x),
), # == r2
check("Do not mention purple elephants."), # == r3
]

name = "Olivia"
notes = "Olivia helped the lab over the last few weeks by organizing intern events, advertising the speaker series, and handling issues with snack delivery."
email_candidate = self.m.instruct(
"Write an email to {{name}} using the notes following: {{notes}}.",
requirements=requirements,
strategy=MBRDRougeLStrategy(number_of_samples=8, loop_budget=1),
user_variables={"name": name, "notes": notes},
return_sampling_results=True,
)

if email_candidate.success:
output = str(email_candidate.result)
else:
output = email_candidate.sample_generations[0].value

print(output)
assert output

if __name__ == "__main__":
pytest.main(["-s", __file__])