Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 274 additions & 0 deletions mellea/stdlib/sampling.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
"""sampling methods go here."""

import abc
import re
from collections import Counter
from collections.abc import Callable
from copy import deepcopy
from typing import Any

import numpy as np
import tqdm
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
from rouge_score.rouge_scorer import RougeScorer # codespell:ignore

from mellea import LinearContext
from mellea.helpers.fancy_logger import FancyLogger
Expand Down Expand Up @@ -381,3 +386,272 @@ def repair(
)

return next_action


class BaseMBRDSampling(RejectionSamplingStrategy):
number_of_samples: int
weighted: bool
symmetric: bool

def __init__(
self,
*,
number_of_samples: int = 8,
weighted: bool = False,
loop_budget: int = 1,
validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
| None = None,
generate: (
Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
| None
) = None,
requirements: list[Requirement] | None = None,
):
"""Initialize a new instance of the class with default parameters.

Args:
number_of_samples: Number of samples to generate and use for majority voting
loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0.
validate: Function to validate the results against requirements. If None, validation is provided later through setter.
generate: Function to generate new model output thunks. If None, generate is provided later through setter.
requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.

Raises:
AssertionError: If loop_budget is not greater than 0.
"""
super().__init__(
loop_budget=loop_budget,
validate=validate,
generate=generate,
requirements=requirements,
)
self.number_of_samples = number_of_samples
self.weighted = weighted
self.symmetric = False

@abc.abstractmethod
def compare_strings(self, ref: str, pred: str) -> float:
"""This method is the abstract method for MBRD similarity metric."""

def maybe_apply_weighted(self, scr):
# TODO not implemented yet
if self.weighted:
weights = np.asarray([1.0 for _ in range(len(scr))])
scr = scr * weights

return scr

def sample(
self,
action: Component,
context: Context,
requirements: list[Requirement],
*,
show_progress: bool = True,
generate_logs: list[GenerateLog] | None = None,
validation_ctx: Context | None = None,
) -> SamplingResult:
results = []
# Generate samples
for i in range(self.number_of_samples):
result = super().sample(
action,
context,
requirements,
show_progress=show_progress,
generate_logs=generate_logs,
validation_ctx=validation_ctx,
)
if result.success:
output = str(result.result)
else:
# avoid type checker error
assert isinstance(result.sample_generations, list)
output = str(result.sample_generations[0].value)

results.append((output, result))

assert len(results) > 0

scr = np.asarray(
[[0.0 for _ in range(len(results))] for _ in range(len(results))]
)
for i in range(len(results)):
for j in range(len(results)):
if j == i:
scr[i][j] = 0.0 # self voting is 0.
continue

# upper triangle
# For sample i compute votes against all j references
if j > i:
scr[i][j] = float(
self.compare_strings(results[j][0], results[i][0])
)
continue

else:
if self.symmetric:
scr[i][j] = scr[j][i]
else:
scr[i][j] = float(
self.compare_strings(results[j][0], results[i][0])
)
continue

# count votes
scr = scr.sum(axis=0)

# Apply weights
scr = self.maybe_apply_weighted(scr)

maxR = int(scr.argmax())

return results[maxR][1] # return one of the MV answers


class MajorityVotingStrategyForMath(BaseMBRDSampling):
number_of_samples: int
match_types: list[str]
float_rounding: int
strict: bool
allow_set_relation_comp: bool
weighted: bool
symmetric: bool

def __init__(
self,
*,
number_of_samples: int = 8,
float_rounding: int = 6,
strict: bool = True,
allow_set_relation_comp: bool = False,
weighted: bool = False,
loop_budget: int = 1,
validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
| None = None,
generate: (
Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
| None
) = None,
requirements: list[Requirement] | None = None,
):
"""Initialize a new instance of the class with default parameters.

Args:
number_of_samples: Number of samples to generate and use for majority voting
float_rounding: Number of decimal places to round floats to. Defaults to 6.
strict: Whether to enforce strict comparison mode. Defaults to True.
- In strict mode: Variables matter and sets are not comparable with tuples
- In non-strict mode: Variables are matched by position and sets can be compared with tuples
allow_set_relation_comp: Whether to allow set - relation (e.g 1 < x < 2 and (1, 2)) comparison. Defaults to False.
- If True, set - relation comparison will be allowed in all cases.
- If False, set - relation comparison will be allowed only if the prediction is a set.
loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0.
validate: Function to validate the results against requirements. If None, validation is provided later through setter.
generate: Function to generate new model output thunks. If None, generate is provided later through setter.
requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.

Raises:
AssertionError: If loop_budget is not greater than 0.
"""
super().__init__(
number_of_samples=number_of_samples,
weighted=weighted,
loop_budget=loop_budget,
validate=validate,
generate=generate,
requirements=requirements,
)
self.number_of_samples = number_of_samples
# match_type: type of match latex, expr (match only so far)
# - For math use "latex" or "expr" or both
# - For general text similarity use "rougel"
MATCH_TYPES = ["latex", "axpr"]
self.match_types = MATCH_TYPES
self.float_rounding = float_rounding
self.strict = strict
self.allow_set_relation_comp = allow_set_relation_comp
self.weighted = weighted

# Note: symmetry is not implied for certain expressions, see: https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/README.md?plain=1#L183
self.symmetric = False

# https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36
def compare_strings(self, ref: str, pred: str):
"""Helper function to compare strings using the math extraction metrics"""
# Convert string match_types to ExtractionTarget objects
extraction_targets = []
for match_type in self.match_types:
if match_type == "latex":
extraction_targets.append(LatexExtractionConfig(boxed_match_priority=0))
elif match_type == "expr":
extraction_targets.append(ExprExtractionConfig())

gold_parsed = parse(ref, extraction_targets)
pred_parsed = parse(pred, extraction_targets)
return verify(
gold_parsed,
pred_parsed,
float_rounding=self.float_rounding,
strict=self.strict,
allow_set_relation_comp=self.allow_set_relation_comp,
)


class MBRDRougeLStrategy(BaseMBRDSampling):
number_of_samples: int
match_types: list[str]
weighted: bool
symmetric: bool
scorer: RougeScorer

def __init__(
self,
*,
number_of_samples: int = 8,
weighted: bool = False,
loop_budget: int = 1,
validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
| None = None,
generate: (
Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
| None
) = None,
requirements: list[Requirement] | None = None,
):
"""Initialize a new instance of the class with default parameters.

Args:
number_of_samples: Number of samples to generate and use for majority voting
match_type: type of match latex, expr (match only so far)
- For math use "latex" or "expr" or both
- For general text similarity use "rougel"
loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0.
validate: Function to validate the results against requirements. If None, validation is provided later through setter.
generate: Function to generate new model output thunks. If None, generate is provided later through setter.
requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.

Raises:
AssertionError: If loop_budget is not greater than 0.
"""
super().__init__(
number_of_samples=number_of_samples,
weighted=weighted,
loop_budget=loop_budget,
validate=validate,
generate=generate,
requirements=requirements,
)
self.number_of_samples = number_of_samples
self.match_types = ["rougeL"]
self.weighted = weighted
self.symmetric = True
self.scorer = RougeScorer(self.match_types, use_stemmer=True)

# https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36
def compare_strings(self, ref: str, pred: str):
"""Helper function to compare strings using the math extraction metrics"""

scr = self.scorer.score(ref, pred)[self.match_types[-1]].fmeasure
return scr
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ combine-as-imports = true
split-on-trailing-comma = false

[tool.codespell]
ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd'
ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,rouge,Rouge'
check-filenames = true
check-hidden = false
regex = "(?<![a-z])[a-z'`]+|[A-Z][a-z'`]*|[a-z]+'[a-z]*|[a-z]+(?=[_-])|[a-z]+(?=[A-Z])|\\d+"
Expand Down
2 changes: 2 additions & 0 deletions test/stdlib_basics/test_majority_voting/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
vllm.err
vllm.log
22 changes: 22 additions & 0 deletions test/stdlib_basics/test_majority_voting/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Test for OpenAI API served by VLLM

## Requirement

anaconda / miniconda / miniforge.

Make sure to run the test with multiple cores available (e.g. in a cloud instance / cluster job).
Although you may think 1 core is enough,
vllm could get stuck due to deadlock if so.

## Installation

Run the `install.sh` script, which needs to be done only once.
The script creates a new conda environment named "mellea_tbf" only for the purposes of testing or contributing to the think budget-forcing feature.

Run `./install.sh`

## Testing

``` shell
./run_test.sh
```
7 changes: 7 additions & 0 deletions test/stdlib_basics/test_majority_voting/environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

name: mellea_mv
channels:
- conda-forge
dependencies:
- python=3.12 # note: at the time of writing, xformer (< vllm) has a broken wheel for 3.13. https://github.com/facebookresearch/xformers/issues/740#issuecomment-2753869337
- uv
10 changes: 10 additions & 0 deletions test/stdlib_basics/test_majority_voting/exec_sampling_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash

source set_variables.sh

eval "$(conda shell.bash hook)"
conda activate $ENV_NAME

export LOCAL_TEST_MODEL

python test_majority_voting.py
24 changes: 24 additions & 0 deletions test/stdlib_basics/test_majority_voting/install.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash -xe

source set_variables.sh

conda env remove -y -n $ENV_NAME || true
conda env create -f $(readlink -f $(dirname $0))/environment.yml

in-conda (){
conda run -n $ENV_NAME $@
}


cd ../../../
in-conda uv pip install -e .
cd -
in-conda uv pip install pre-commit
in-conda uv pip install pytest
in-conda uv pip install vllm==0.10.0
in-conda uv pip install outlines
# in-conda uv pip install unsloth
in-conda uv pip install ipdb
in-conda uv pip install math-verify[antlr4_13_2]
in-conda uv pip install rouge-score

24 changes: 24 additions & 0 deletions test/stdlib_basics/test_majority_voting/run_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash

source set_variables.sh

eval "$(conda shell.bash hook)"
conda activate $ENV_NAME

rm $VLLM_LOG $VLLM_ERR

bash ./serve.sh &
VLLM_PID=$!

trap "kill -SIGINT $VLLM_PID ; wait" EXIT

while sleep 1 ; do
if grep -q "Application startup complete." $VLLM_ERR
then
break
fi
done

bash exec_sampling_test.sh


Loading
Loading