From 1cfb51613be19400b51271f15d4aeb93fd6f6ea7 Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Mon, 15 Sep 2025 14:07:16 +0000 Subject: [PATCH 01/10] initial commit for majority voting --- mellea/stdlib/sampling.py | 93 +++++++++++++++ .../test_majority_voting/README.md | 22 ++++ .../test_majority_voting/environment.yml | 7 ++ .../exec_sampling_test.sh | 10 ++ .../test_majority_voting/install.sh | 22 ++++ .../test_majority_voting/run_test.sh | 24 ++++ .../test_majority_voting/serve.sh | 16 +++ .../test_majority_voting/set_variables.sh | 8 ++ .../test_majority_voting.py | 107 ++++++++++++++++++ .../test_majority_voting/vllm.err | 24 ++++ 10 files changed, 333 insertions(+) create mode 100644 test/stdlib_basics/test_majority_voting/README.md create mode 100644 test/stdlib_basics/test_majority_voting/environment.yml create mode 100644 test/stdlib_basics/test_majority_voting/exec_sampling_test.sh create mode 100755 test/stdlib_basics/test_majority_voting/install.sh create mode 100755 test/stdlib_basics/test_majority_voting/run_test.sh create mode 100755 test/stdlib_basics/test_majority_voting/serve.sh create mode 100644 test/stdlib_basics/test_majority_voting/set_variables.sh create mode 100644 test/stdlib_basics/test_majority_voting/test_majority_voting.py create mode 100644 test/stdlib_basics/test_majority_voting/vllm.err diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index ee6ab431..318d20fd 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -1,6 +1,8 @@ """sampling methods go here.""" import abc +import re +from collections import Counter from collections.abc import Callable from copy import deepcopy from typing import Any @@ -381,3 +383,94 @@ def repair( ) return next_action + + +class MajorityVotingStrategyForMath(RejectionSamplingStrategy): + number_of_samples: int + answer_extraction_regex: str + + def __init__( + self, + *, + number_of_samples: int = 8, + answer_extraction_regex: str = r"\\boxed{(.*?)}", + loop_budget: int = 1, + validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] + | None = None, + generate: ( + Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] + | None + ) = None, + requirements: list[Requirement] | None = None, + ): + """Initialize a new instance of the class with default parameters. + + Args: + number_of_samples: Number of samples to generate and use for majority voting + loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0. + validate: Function to validate the results against requirements. If None, validation is provided later through setter. + generate: Function to generate new model output thunks. If None, generate is provided later through setter. + requirements: List of requirements to test against. If None, test all requirements attached to the given instruction. + + Raises: + AssertionError: If loop_budget is not greater than 0. 
+        """
+        super().__init__(
+            loop_budget=loop_budget,
+            validate=validate,
+            generate=generate,
+            requirements=requirements,
+        )
+        self.number_of_samples = number_of_samples
+        self.answer_extraction_regex = answer_extraction_regex
+
+    def answer_extraction(self, response):
+        matches = re.findall(self.answer_extraction_regex, response, re.DOTALL)
+        if len(matches) > 0:
+            return matches[-1]  # return the last match
+        else:
+            return ""
+
+    def format_math(self, response):
+        # TODO implement
+        return response
+
+    def sample(
+        self,
+        action: Component,
+        context: Context,
+        requirements: list[Requirement],
+        *,
+        show_progress: bool = True,
+        generate_logs: list[GenerateLog] | None = None,
+        validation_ctx: Context | None = None,
+    ) -> SamplingResult:
+        results = dict()
+        # Generate samples
+        for i in range(self.number_of_samples):
+            result = super().sample(
+                action,
+                context,
+                requirements,
+                show_progress=show_progress,
+                generate_logs=generate_logs,
+                validation_ctx=validation_ctx,
+            )
+            if result.success:
+                output = str(result.result)
+            else:
+                output = result.sample_generations[0].value
+
+            answer = self.answer_extraction(output)
+            answer = self.format_math(answer)
+            if answer in results:
+                results[answer].append(result)
+            else:
+                results[answer] = [result]
+
+        assert len(results) > 0
+
+        # obtain majority voting answer
+        counts = Counter(results.keys())
+        ans, cnt = counts.most_common(1)[0]
+        return results[ans][0]  # return one of the MV answers
diff --git a/test/stdlib_basics/test_majority_voting/README.md b/test/stdlib_basics/test_majority_voting/README.md
new file mode 100644
index 00000000..c904b0d6
--- /dev/null
+++ b/test/stdlib_basics/test_majority_voting/README.md
@@ -0,0 +1,22 @@
+# Test for the OpenAI API served by vLLM
+
+## Requirement
+
+anaconda / miniconda / miniforge.
+
+Make sure to run the test with multiple cores available (e.g. in a cloud instance / cluster job).
+Although 1 core may seem enough,
+vLLM can deadlock when limited to a single core.
+
+## Installation
+
+Run the `install.sh` script, which needs to be done only once.
+The script creates a new conda environment named "mellea_mv" only for the purposes of testing or contributing to the majority-voting feature.
+
+Run `./install.sh`
+
+## Testing
+
+``` shell
+./run_test.sh
+```
diff --git a/test/stdlib_basics/test_majority_voting/environment.yml b/test/stdlib_basics/test_majority_voting/environment.yml
new file mode 100644
index 00000000..626ad888
--- /dev/null
+++ b/test/stdlib_basics/test_majority_voting/environment.yml
@@ -0,0 +1,7 @@
+
+name: mellea_mv
+channels:
+  - conda-forge
+dependencies:
+  - python=3.12 # note: at the time of writing, xformer (< vllm) has a broken wheel for 3.13. https://github.com/facebookresearch/xformers/issues/740#issuecomment-2753869337
+  - uv
diff --git a/test/stdlib_basics/test_majority_voting/exec_sampling_test.sh b/test/stdlib_basics/test_majority_voting/exec_sampling_test.sh
new file mode 100644
index 00000000..d543632c
--- /dev/null
+++ b/test/stdlib_basics/test_majority_voting/exec_sampling_test.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+source set_variables.sh
+
+eval "$(conda shell.bash hook)"
+conda activate $ENV_NAME
+
+export LOCAL_TEST_MODEL
+
+python test_majority_voting.py
diff --git a/test/stdlib_basics/test_majority_voting/install.sh b/test/stdlib_basics/test_majority_voting/install.sh
new file mode 100755
index 00000000..0ce37ba1
--- /dev/null
+++ b/test/stdlib_basics/test_majority_voting/install.sh
@@ -0,0 +1,22 @@
+#!/bin/bash -xe
+
+source set_variables.sh
+
+conda env remove -y -n $ENV_NAME || true
+conda env create -f $(readlink -f $(dirname $0))/environment.yml
+
+in-conda (){
+    conda run -n $ENV_NAME $@
+}
+
+
+cd ../../../
+in-conda uv pip install -e .
+cd -
+in-conda uv pip install pre-commit
+in-conda uv pip install pytest
+in-conda uv pip install vllm==0.10.0
+in-conda uv pip install outlines
+# in-conda uv pip install unsloth
+in-conda uv pip install ipdb
+
diff --git a/test/stdlib_basics/test_majority_voting/run_test.sh b/test/stdlib_basics/test_majority_voting/run_test.sh
new file mode 100755
index 00000000..9e17b492
--- /dev/null
+++ b/test/stdlib_basics/test_majority_voting/run_test.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+source set_variables.sh
+
+eval "$(conda shell.bash hook)"
+conda activate $ENV_NAME
+
+rm $VLLM_LOG $VLLM_ERR
+
+bash ./serve.sh &
+VLLM_PID=$!
+
+trap "kill -SIGINT $VLLM_PID ; wait" EXIT
+
+while sleep 1 ; do
+    if grep -q "Application startup complete." $VLLM_ERR
+    then
+        break
+    fi
+done
+
+bash exec_sampling_test.sh
+
+
diff --git a/test/stdlib_basics/test_majority_voting/serve.sh b/test/stdlib_basics/test_majority_voting/serve.sh
new file mode 100755
index 00000000..a890a211
--- /dev/null
+++ b/test/stdlib_basics/test_majority_voting/serve.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+source set_variables.sh
+eval "$(conda shell.bash hook)"
+conda activate $ENV_NAME
+export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
+
+echo "launching a vllm server. Logs are found in $(readlink -ef $(dirname $0))/vllm.log"
+    # At the time of writing this code, Granite 4.4 vLLM serving did not support prefix-caching
+    # --enable-prefix-caching \
+vllm serve $LOCAL_TEST_MODEL \
+     --dtype bfloat16 \
+     > $VLLM_LOG \
+     2> $VLLM_ERR
+
+
diff --git a/test/stdlib_basics/test_majority_voting/set_variables.sh b/test/stdlib_basics/test_majority_voting/set_variables.sh
new file mode 100644
index 00000000..6971f487
--- /dev/null
+++ b/test/stdlib_basics/test_majority_voting/set_variables.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+PYTHONBREAKPOINT="ipdb.set_trace"
+LOCAL_TEST_MODEL="ibm-granite/granite-4.0-tiny-preview"
+ENV_NAME=mellea_mv
+DIR=$(readlink -ef $(dirname $0))
+VLLM_LOG=$DIR/vllm.log
+VLLM_ERR=$DIR/vllm.err
diff --git a/test/stdlib_basics/test_majority_voting/test_majority_voting.py b/test/stdlib_basics/test_majority_voting/test_majority_voting.py
new file mode 100644
index 00000000..6f0804c1
--- /dev/null
+++ b/test/stdlib_basics/test_majority_voting/test_majority_voting.py
@@ -0,0 +1,107 @@
+import mellea
+from mellea import MelleaSession
+from mellea.backends.model_ids import OPENAI_GPT_OSS_20B, META_LLAMA_3_2_1B, IBM_GRANITE_4_TINY_PREVIEW_7B
+from mellea.stdlib.base import CBlock, SimpleContext
+from mellea.stdlib.requirement import check, req, simple_validate
+from mellea.stdlib.sampling import RejectionSamplingStrategy, MajorityVotingStrategyForMath
+from mellea.backends.openai import OpenAIBackend
+from mellea.backends.formatter import TemplateFormatter
+from transformers import AutoTokenizer
+import pytest
+import os
+
+
+class TestMajorityVoting:
+    MODEL_ID = os.environ.get("LOCAL_TEST_MODEL", None)
+    if MODEL_ID is None:
+        raise RuntimeError(f"Must set environment variable `LOCAL_TEST_MODEL` to a HF model id")
+
+    # Local testing mode
+    if MODEL_ID == "ibm-granite/granite-4.0-tiny-preview":
+        MODEL_ID = IBM_GRANITE_4_TINY_PREVIEW_7B
+
+    elif MODEL_ID == "unsloth/Llama-3.2-1B":
+        MODEL_ID = META_LLAMA_3_2_1B
+
+    else:
+        raise RuntimeError(f"Unsupported model-id:{MODEL_ID}")
+
+    model_id = "ibm-granite/granite-4.0-tiny-preview"
+    backend = OpenAIBackend(
+        model_id=MODEL_ID,
+        formatter=TemplateFormatter(model_id=MODEL_ID),
+        base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:8000')}/v1",
+        api_key="ollama",
+    )
+
+    m = MelleaSession(backend, ctx=SimpleContext())
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID.hf_model_name, trust_remote_code=True)
+
+
+    def prepare_prmpt_for_math(self, query):
+        # Preparing prompt for math reasoning tasks
+        system_prompt = None  # Use default of chat template
+        prompt_suffix = "\nPlease reason step by step, use \n\n to end each step, and put your final answer within \\boxed{}."
+ + if prompt_suffix: + query += prompt_suffix + + msg = [] + if system_prompt is not None: + msg.append({"role": "system", "content": system_prompt}) + + msg.append({"role": "user", "content": query}) + if self.tokenizer.chat_template is None: + raise RuntimeError(f"No explicit chat template is defined for model-id: ") + + else: + prompt = self.tokenizer.apply_chat_template( + msg, + tokenize=False, + thinking=True, + add_generation_prompt=True, + ) + + return prompt + + def test_majority_voting(self): + + query = "Compute 1+1" + prompt = self.prepare_prmpt_for_math(query) + + # requirements = [ + # req("The email should have a salutation"), # == r1 + # req( + # "Use only lower-case letters", + # validation_fn=simple_validate(lambda x: x.lower() == x), + # ), # == r2 + # check("Do not mention purple elephants."), # == r3 + # ] + + # def write_email(m: mellea.MelleaSession, name: str, notes: str) -> str: + # breakpoint() + # email_candidate = m.instruct( + # "Write an email to {{name}} using the notes following: {{notes}}.", + # # requirements=requirements, + # strategy=RejectionSamplingStrategy(loop_budget=5), + # user_variables={"name": name, "notes": notes}, + # return_sampling_results=True, + # ) + result = self.m.instruct( + prompt, + # requirements=requirements, + strategy=MajorityVotingStrategyForMath(number_of_samples=8, loop_budget=1), + # user_variables={"name": name, "notes": notes}, + return_sampling_results=True, + ) + if result.success: + output = str(result.result) + else: + output = result.sample_generations[0].value + print(output) + + + # assert gen_tok_cnt <= 2 * THINK_MAX_TOKENS + +if __name__ == "__main__": + pytest.main(["-s", __file__]) diff --git a/test/stdlib_basics/test_majority_voting/vllm.err b/test/stdlib_basics/test_majority_voting/vllm.err new file mode 100644 index 00000000..2ed420a6 --- /dev/null +++ b/test/stdlib_basics/test_majority_voting/vllm.err @@ -0,0 +1,24 @@ +`torch_dtype` is deprecated! Use `dtype` instead! + Parse safetensors files: 0%| | 0/3 [00:00 +Traceback (most recent call last): + File "/proj/rh-inf-scaling/yelkurdi/py_envs/mellea_mv/lib/python3.12/site-packages/zmq/sugar/socket.py", line 184, in __del__ + def __del__(self): + + File "/proj/rh-inf-scaling/yelkurdi/py_envs/mellea_mv/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 434, in signal_handler + raise KeyboardInterrupt("MQLLMEngine terminated") +KeyboardInterrupt: MQLLMEngine terminated +[rank0]:[W915 07:15:22.542586008 ProcessGroupNCCL.cpp:1479] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator()) +INFO: Shutting down +INFO: Waiting for application shutdown. +INFO: Application shutdown complete. 
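For readers who want the voting mechanics of PATCH 01 in isolation: each sample's last `\boxed{...}` answer is extracted, answers are bucketed, and one result from the winning bucket is returned. A minimal standalone sketch of the intended selection — one vote per sample, majority answer wins (the sample strings below are illustrative, not real model output):

```python
import re
from collections import Counter

def extract_boxed_answer(response: str) -> str:
    # Mirror answer_extraction above: keep the last \boxed{...} match.
    matches = re.findall(r"\\boxed{(.*?)}", response, re.DOTALL)
    return matches[-1] if matches else ""

samples = [
    "1 + 1 = 2, so the answer is \\boxed{2}.",
    "Adding one and one yields \\boxed{2}.",
    "A slip in the arithmetic gives \\boxed{3}.",
]
votes = Counter(extract_boxed_answer(s) for s in samples)
answer, count = votes.most_common(1)[0]
print(answer, count)  # -> 2 2
```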
From a7f3a7d81971a9fda7f00cc2a3f7b1cd1e692427 Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Mon, 15 Sep 2025 14:42:07 +0000 Subject: [PATCH 02/10] adds .gitignore to test dir --- .../test_majority_voting/.gitignore | 2 ++ .../test_majority_voting/vllm.err | 24 ------------------- 2 files changed, 2 insertions(+), 24 deletions(-) create mode 100644 test/stdlib_basics/test_majority_voting/.gitignore delete mode 100644 test/stdlib_basics/test_majority_voting/vllm.err diff --git a/test/stdlib_basics/test_majority_voting/.gitignore b/test/stdlib_basics/test_majority_voting/.gitignore new file mode 100644 index 00000000..baf1455b --- /dev/null +++ b/test/stdlib_basics/test_majority_voting/.gitignore @@ -0,0 +1,2 @@ +vllm.err +vllm.log diff --git a/test/stdlib_basics/test_majority_voting/vllm.err b/test/stdlib_basics/test_majority_voting/vllm.err deleted file mode 100644 index 2ed420a6..00000000 --- a/test/stdlib_basics/test_majority_voting/vllm.err +++ /dev/null @@ -1,24 +0,0 @@ -`torch_dtype` is deprecated! Use `dtype` instead! - Parse safetensors files: 0%| | 0/3 [00:00 -Traceback (most recent call last): - File "/proj/rh-inf-scaling/yelkurdi/py_envs/mellea_mv/lib/python3.12/site-packages/zmq/sugar/socket.py", line 184, in __del__ - def __del__(self): - - File "/proj/rh-inf-scaling/yelkurdi/py_envs/mellea_mv/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 434, in signal_handler - raise KeyboardInterrupt("MQLLMEngine terminated") -KeyboardInterrupt: MQLLMEngine terminated -[rank0]:[W915 07:15:22.542586008 ProcessGroupNCCL.cpp:1479] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator()) -INFO: Shutting down -INFO: Waiting for application shutdown. -INFO: Application shutdown complete. 
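PATCH 03 below swaps this exact-match comparison for Math-Verify, which parses both strings and checks mathematical equivalence rather than surface equality. A minimal sketch of the parse/verify pair it builds on (assuming the `math-verify` package is installed):

```python
from math_verify import parse, verify

# "0.5" and "\frac{1}{2}" differ as strings but denote the same number.
gold = parse("$\\frac{1}{2}$")
pred = parse("0.5")
print(verify(gold, pred))  # -> True
```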
From 19e36e5df089841839e17077434a13ae8c01ce6e Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Tue, 16 Sep 2025 04:37:32 +0000 Subject: [PATCH 03/10] removed naive exact-match and replaced it with Math-Verify --- mellea/stdlib/sampling.py | 118 ++++++++++++++---- .../test_majority_voting/install.sh | 1 + 2 files changed, 93 insertions(+), 26 deletions(-) diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index 318d20fd..da319357 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -7,7 +7,9 @@ from copy import deepcopy from typing import Any +import numpy as np import tqdm +from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify from mellea import LinearContext from mellea.helpers.fancy_logger import FancyLogger @@ -387,13 +389,22 @@ def repair( class MajorityVotingStrategyForMath(RejectionSamplingStrategy): number_of_samples: int - answer_extraction_regex: str + match_types: list[str] + float_rounding: int + strict: bool + allow_set_relation_comp: bool + weighted: bool + symmetric: bool def __init__( self, *, number_of_samples: int = 8, - answer_extraction_regex: str = r"\\boxed{(.*?)}", + match_types: list[str] = ["latex", "expr"], + float_rounding: int = 6, + strict: bool = True, + allow_set_relation_comp: bool = False, + weighted: bool = False, loop_budget: int = 1, validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] | None = None, @@ -407,6 +418,14 @@ def __init__( Args: number_of_samples: Number of samples to generate and use for majority voting + match_type: type of match latex, expr (match only so far) + float_rounding: Number of decimal places to round floats to. Defaults to 6. + strict: Whether to enforce strict comparison mode. Defaults to True. + - In strict mode: Variables matter and sets are not comparable with tuples + - In non-strict mode: Variables are matched by position and sets can be compared with tuples + allow_set_relation_comp: Whether to allow set - relation (e.g 1 < x < 2 and (1, 2)) comparison. Defaults to False. + - If True, set - relation comparison will be allowed in all cases. + - If False, set - relation comparison will be allowed only if the prediction is a set. loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0. validate: Function to validate the results against requirements. If None, validation is provided later through setter. generate: Function to generate new model output thunks. If None, generate is provided later through setter. 
@@ -422,18 +441,43 @@ def __init__( requirements=requirements, ) self.number_of_samples = number_of_samples - self.answer_extraction_regex = answer_extraction_regex + self.match_types = match_types + self.float_rounding = float_rounding + self.strict = strict + self.allow_set_relation_comp = allow_set_relation_comp + self.weighted = weighted + + # Note: symmetry is not implied for certain expressions, see: https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/README.md?plain=1#L183 + self.symmetric = False + + # https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36 + def compare_strings(self, gold: str, pred: str): + """Helper function to compare strings using the math extraction metrics""" + # Convert string match_types to ExtractionTarget objects + extraction_targets = [] + for match_type in self.match_types: + if match_type == "latex": + extraction_targets.append(LatexExtractionConfig(boxed_match_priority=0)) + elif match_type == "expr": + extraction_targets.append(ExprExtractionConfig()) + + gold_parsed = parse(gold, extraction_targets) + pred_parsed = parse(pred, extraction_targets) + return verify( + gold_parsed, + pred_parsed, + float_rounding=self.float_rounding, + strict=self.strict, + allow_set_relation_comp=self.allow_set_relation_comp, + ) - def answer_extraction(self, response): - matches = re.findall(self.answer_extraction_regex, response, re.DOTALL) - if len(matches) > 0: - return matches[-1] # return the last match - else: - return "" + def maybe_apply_weighted(self, scr): + # TODO not implemented yet + if self.weighted: + weights = np.asarray([1.0 for _ in range(len(scr))]) + scr = scr * weights - def format_math(self, response): - # TODO implement - return response + return scr def sample( self, @@ -445,7 +489,7 @@ def sample( generate_logs: list[GenerateLog] | None = None, validation_ctx: Context | None = None, ) -> SamplingResult: - results = dict() + results = [] # Generate samples for i in range(self.number_of_samples): result = super().sample( @@ -457,20 +501,42 @@ def sample( validation_ctx=validation_ctx, ) if result.success: - output = str(result.result) - else: - output = result.sample_generations[0].value - - answer = self.answer_extraction(output) - answer = self.format_math(answer) - if answer in results: - results[answer].append(result) + results.append((str(result.result), result)) else: - results[answer] = [result] + results.append((result.sample_generations[0].value, result)) assert len(results) > 0 - # obtain majority voting answer - counts = Counter(results.keys()) - ans, cnt = counts.most_common(1)[0] - return results[ans][0] # return one of the MV answers + scr = [[0.0 for _ in range(len(results))] for _ in range(len(results))] + scr = np.asarray(scr) + for i in range(len(results)): + for j in range(len(results)): + if j == i: + scr[i][j] = 0.0 # self voting is 0. 
+ continue + + # upper triangle + if j > i: + scr[i][j] = float( + self.compare_strings(results[i][0], results[j][0]) + ) + continue + + else: + if self.symmetric: + scr[i][j] = scr[j][i] + else: + scr[i][j] = float( + self.compare_strings(results[j][0], results[i][0]) + ) + continue + + # count votes + scr = scr.sum(axis=0) + + # Apply weights + scr = self.maybe_apply_weighted(scr) + + maxR = int(scr.argmax()) + + return results[maxR][1] # return one of the MV answers diff --git a/test/stdlib_basics/test_majority_voting/install.sh b/test/stdlib_basics/test_majority_voting/install.sh index 0ce37ba1..0f703bad 100755 --- a/test/stdlib_basics/test_majority_voting/install.sh +++ b/test/stdlib_basics/test_majority_voting/install.sh @@ -19,4 +19,5 @@ in-conda uv pip install vllm==0.10.0 in-conda uv pip install outlines # in-conda uv pip install unsloth in-conda uv pip install ipdb +in-conda uv pip install math-verify[antlr4_13_2] From dbdb518f2079f5403f8e8b8d6fefb078667897a2 Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Tue, 16 Sep 2025 15:16:00 +0000 Subject: [PATCH 04/10] adjusted string comparison ref vs pred relation - fixed some type checking --- mellea/stdlib/sampling.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index da319357..4251b972 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -451,7 +451,7 @@ def __init__( self.symmetric = False # https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36 - def compare_strings(self, gold: str, pred: str): + def compare_strings(self, ref: str, pred: str): """Helper function to compare strings using the math extraction metrics""" # Convert string match_types to ExtractionTarget objects extraction_targets = [] @@ -461,7 +461,7 @@ def compare_strings(self, gold: str, pred: str): elif match_type == "expr": extraction_targets.append(ExprExtractionConfig()) - gold_parsed = parse(gold, extraction_targets) + gold_parsed = parse(ref, extraction_targets) pred_parsed = parse(pred, extraction_targets) return verify( gold_parsed, @@ -501,14 +501,19 @@ def sample( validation_ctx=validation_ctx, ) if result.success: - results.append((str(result.result), result)) + output = str(result.result) else: - results.append((result.sample_generations[0].value, result)) + # avoid type checker error + assert isinstance(result.sample_generations, list) + output = str(result.sample_generations[0].value) + + results.append((output, result)) assert len(results) > 0 - scr = [[0.0 for _ in range(len(results))] for _ in range(len(results))] - scr = np.asarray(scr) + scr = np.asarray( + [[0.0 for _ in range(len(results))] for _ in range(len(results))] + ) for i in range(len(results)): for j in range(len(results)): if j == i: @@ -516,9 +521,10 @@ def sample( continue # upper triangle + # For sample i compute votes against all j references if j > i: scr[i][j] = float( - self.compare_strings(results[i][0], results[j][0]) + self.compare_strings(results[j][0], results[i][0]) ) continue From 18c609d3d88011e3352eb9520aeafac7b7e784ca Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Thu, 18 Sep 2025 17:09:32 +0000 Subject: [PATCH 05/10] adds MBRD rougeL similarity - refactors MBRD base class --- mellea/stdlib/sampling.py | 195 ++++++++++++++---- pyproject.toml | 2 +- .../test_majority_voting/install.sh | 1 + .../test_majority_voting.py | 53 +++-- 4 files changed, 186 insertions(+), 65 deletions(-) diff 
--git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index 4251b972..79cf32f6 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -10,6 +10,7 @@ import numpy as np import tqdm from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify +from rouge_score.rouge_scorer import RougeScorer # codespell:ignore from mellea import LinearContext from mellea.helpers.fancy_logger import FancyLogger @@ -387,12 +388,8 @@ def repair( return next_action -class MajorityVotingStrategyForMath(RejectionSamplingStrategy): +class BaseMBRDSampling(RejectionSamplingStrategy): number_of_samples: int - match_types: list[str] - float_rounding: int - strict: bool - allow_set_relation_comp: bool weighted: bool symmetric: bool @@ -400,10 +397,6 @@ def __init__( self, *, number_of_samples: int = 8, - match_types: list[str] = ["latex", "expr"], - float_rounding: int = 6, - strict: bool = True, - allow_set_relation_comp: bool = False, weighted: bool = False, loop_budget: int = 1, validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] @@ -418,14 +411,6 @@ def __init__( Args: number_of_samples: Number of samples to generate and use for majority voting - match_type: type of match latex, expr (match only so far) - float_rounding: Number of decimal places to round floats to. Defaults to 6. - strict: Whether to enforce strict comparison mode. Defaults to True. - - In strict mode: Variables matter and sets are not comparable with tuples - - In non-strict mode: Variables are matched by position and sets can be compared with tuples - allow_set_relation_comp: Whether to allow set - relation (e.g 1 < x < 2 and (1, 2)) comparison. Defaults to False. - - If True, set - relation comparison will be allowed in all cases. - - If False, set - relation comparison will be allowed only if the prediction is a set. loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0. validate: Function to validate the results against requirements. If None, validation is provided later through setter. generate: Function to generate new model output thunks. If None, generate is provided later through setter. 
@@ -441,35 +426,12 @@ def __init__( requirements=requirements, ) self.number_of_samples = number_of_samples - self.match_types = match_types - self.float_rounding = float_rounding - self.strict = strict - self.allow_set_relation_comp = allow_set_relation_comp self.weighted = weighted - - # Note: symmetry is not implied for certain expressions, see: https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/README.md?plain=1#L183 self.symmetric = False - # https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36 - def compare_strings(self, ref: str, pred: str): - """Helper function to compare strings using the math extraction metrics""" - # Convert string match_types to ExtractionTarget objects - extraction_targets = [] - for match_type in self.match_types: - if match_type == "latex": - extraction_targets.append(LatexExtractionConfig(boxed_match_priority=0)) - elif match_type == "expr": - extraction_targets.append(ExprExtractionConfig()) - - gold_parsed = parse(ref, extraction_targets) - pred_parsed = parse(pred, extraction_targets) - return verify( - gold_parsed, - pred_parsed, - float_rounding=self.float_rounding, - strict=self.strict, - allow_set_relation_comp=self.allow_set_relation_comp, - ) + @abc.abstractmethod + def compare_strings(self, ref: str, pred: str) -> float: + """This method is the abstract method for MBRD similarity metric.""" def maybe_apply_weighted(self, scr): # TODO not implemented yet @@ -546,3 +508,150 @@ def sample( maxR = int(scr.argmax()) return results[maxR][1] # return one of the MV answers + + +class MajorityVotingStrategyForMath(BaseMBRDSampling): + number_of_samples: int + match_types: list[str] + float_rounding: int + strict: bool + allow_set_relation_comp: bool + weighted: bool + symmetric: bool + + def __init__( + self, + *, + number_of_samples: int = 8, + float_rounding: int = 6, + strict: bool = True, + allow_set_relation_comp: bool = False, + weighted: bool = False, + loop_budget: int = 1, + validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] + | None = None, + generate: ( + Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] + | None + ) = None, + requirements: list[Requirement] | None = None, + ): + """Initialize a new instance of the class with default parameters. + + Args: + number_of_samples: Number of samples to generate and use for majority voting + float_rounding: Number of decimal places to round floats to. Defaults to 6. + strict: Whether to enforce strict comparison mode. Defaults to True. + - In strict mode: Variables matter and sets are not comparable with tuples + - In non-strict mode: Variables are matched by position and sets can be compared with tuples + allow_set_relation_comp: Whether to allow set - relation (e.g 1 < x < 2 and (1, 2)) comparison. Defaults to False. + - If True, set - relation comparison will be allowed in all cases. + - If False, set - relation comparison will be allowed only if the prediction is a set. + loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0. + validate: Function to validate the results against requirements. If None, validation is provided later through setter. + generate: Function to generate new model output thunks. If None, generate is provided later through setter. + requirements: List of requirements to test against. If None, test all requirements attached to the given instruction. 
+
+        Raises:
+            AssertionError: If loop_budget is not greater than 0.
+        """
+        super().__init__(
+            number_of_samples=number_of_samples,
+            weighted=weighted,
+            loop_budget=loop_budget,
+            validate=validate,
+            generate=generate,
+            requirements=requirements,
+        )
+        self.number_of_samples = number_of_samples
+        # match_type: type of match latex, expr (match only so far)
+        # - For math use "latex" or "expr" or both
+        # - For general text similarity use "rougel"
+        MATCH_TYPES = ["latex", "expr"]
+        self.match_types = MATCH_TYPES
+        self.float_rounding = float_rounding
+        self.strict = strict
+        self.allow_set_relation_comp = allow_set_relation_comp
+        self.weighted = weighted
+
+        # Note: symmetry is not implied for certain expressions, see: https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/README.md?plain=1#L183
+        self.symmetric = False
+
+    # https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36
+    def compare_strings(self, ref: str, pred: str):
+        """Helper function to compare strings using the math extraction metrics"""
+        # Convert string match_types to ExtractionTarget objects
+        extraction_targets = []
+        for match_type in self.match_types:
+            if match_type == "latex":
+                extraction_targets.append(LatexExtractionConfig(boxed_match_priority=0))
+            elif match_type == "expr":
+                extraction_targets.append(ExprExtractionConfig())
+
+        gold_parsed = parse(ref, extraction_targets)
+        pred_parsed = parse(pred, extraction_targets)
+        return verify(
+            gold_parsed,
+            pred_parsed,
+            float_rounding=self.float_rounding,
+            strict=self.strict,
+            allow_set_relation_comp=self.allow_set_relation_comp,
+        )
+
+
+class MBRDRougeLStrategy(BaseMBRDSampling):
+    number_of_samples: int
+    match_types: list[str]
+    weighted: bool
+    symmetric: bool
+    scorer: RougeScorer
+
+    def __init__(
+        self,
+        *,
+        number_of_samples: int = 8,
+        weighted: bool = False,
+        loop_budget: int = 1,
+        validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
+        | None = None,
+        generate: (
+            Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
+            | None
+        ) = None,
+        requirements: list[Requirement] | None = None,
+    ):
+        """Initialize a new instance of the class with default parameters.
+
+        Args:
+            number_of_samples: Number of samples to generate and use for majority voting
+            match_type: type of match latex, expr (match only so far)
+              - For math use "latex" or "expr" or both
+              - For general text similarity use "rougel"
+            loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0.
+            validate: Function to validate the results against requirements. If None, validation is provided later through setter.
+            generate: Function to generate new model output thunks. If None, generate is provided later through setter.
+            requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.
+
+        Raises:
+            AssertionError: If loop_budget is not greater than 0.
+ """ + super().__init__( + number_of_samples=number_of_samples, + weighted=weighted, + loop_budget=loop_budget, + validate=validate, + generate=generate, + requirements=requirements, + ) + self.number_of_samples = number_of_samples + self.match_types = ["rougeL"] + self.weighted = weighted + self.symmetric = True + self.scorer = RougeScorer(self.match_types, use_stemmer=True) + + # https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36 + def compare_strings(self, ref: str, pred: str): + """Helper function to compare strings using the math extraction metrics""" + + scr = self.scorer.score(ref, pred)[self.match_types[-1]].fmeasure + return scr diff --git a/pyproject.toml b/pyproject.toml index 7e1989ef..d9d2a255 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,7 +158,7 @@ combine-as-imports = true split-on-trailing-comma = false [tool.codespell] -ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd' +ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,rouge,Rouge' check-filenames = true check-hidden = false regex = "(? str: - # breakpoint() - # email_candidate = m.instruct( - # "Write an email to {{name}} using the notes following: {{notes}}.", - # # requirements=requirements, - # strategy=RejectionSamplingStrategy(loop_budget=5), - # user_variables={"name": name, "notes": notes}, - # return_sampling_results=True, - # ) result = self.m.instruct( prompt, # requirements=requirements, @@ -100,8 +82,37 @@ def test_majority_voting(self): output = result.sample_generations[0].value print(output) + assert output - # assert gen_tok_cnt <= 2 * THINK_MAX_TOKENS + + def test_MBRDRougeL(self): + + requirements = [ + req("The email should have a salutation"), # == r1 + req( + "Use only lower-case letters", + validation_fn=simple_validate(lambda x: x.lower() == x), + ), # == r2 + check("Do not mention purple elephants."), # == r3 + ] + + name = "Olivia" + notes = "Olivia helped the lab over the last few weeks by organizing intern events, advertising the speaker series, and handling issues with snack delivery." 
+ email_candidate = self.m.instruct( + "Write an email to {{name}} using the notes following: {{notes}}.", + requirements=requirements, + strategy=MBRDRougeLStrategy(loop_budget=1), + user_variables={"name": name, "notes": notes}, + return_sampling_results=True, + ) + + if email_candidate.success: + output = str(email_candidate.result) + else: + output = email_candidate.sample_generations[0].value + print(output) + + assert output if __name__ == "__main__": pytest.main(["-s", __file__]) From 257754666774ecd4721d2f747d6cac8570ef20ec Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Thu, 25 Sep 2025 15:16:59 -0400 Subject: [PATCH 06/10] fixes tests removes dependency on vLLM and uses default ollama --- .../test_majority_voting.py | 55 +------------------ .../test_majority_voting/.gitignore | 2 - .../test_majority_voting/README.md | 22 -------- .../test_majority_voting/environment.yml | 7 --- .../exec_sampling_test.sh | 10 ---- .../test_majority_voting/install.sh | 24 -------- .../test_majority_voting/run_test.sh | 24 -------- .../test_majority_voting/serve.sh | 16 ------ .../test_majority_voting/set_variables.sh | 8 --- 9 files changed, 3 insertions(+), 165 deletions(-) rename test/stdlib_basics/{test_majority_voting => }/test_majority_voting.py (58%) delete mode 100644 test/stdlib_basics/test_majority_voting/.gitignore delete mode 100644 test/stdlib_basics/test_majority_voting/README.md delete mode 100644 test/stdlib_basics/test_majority_voting/environment.yml delete mode 100644 test/stdlib_basics/test_majority_voting/exec_sampling_test.sh delete mode 100755 test/stdlib_basics/test_majority_voting/install.sh delete mode 100755 test/stdlib_basics/test_majority_voting/run_test.sh delete mode 100755 test/stdlib_basics/test_majority_voting/serve.sh delete mode 100644 test/stdlib_basics/test_majority_voting/set_variables.sh diff --git a/test/stdlib_basics/test_majority_voting/test_majority_voting.py b/test/stdlib_basics/test_majority_voting.py similarity index 58% rename from test/stdlib_basics/test_majority_voting/test_majority_voting.py rename to test/stdlib_basics/test_majority_voting.py index e35fc05a..88fb014c 100644 --- a/test/stdlib_basics/test_majority_voting/test_majority_voting.py +++ b/test/stdlib_basics/test_majority_voting.py @@ -12,62 +12,13 @@ class TestMajorityVoting: - MODEL_ID = os.environ.get("LOCAL_TEST_MODEL", None) - if MODEL_ID is None: - raise RuntimeError(f"Must set environment variable `LOCAL_TEST_MODEL` to a HF model id") - - # Local testing mode - if MODEL_ID == "ibm-granite/granite-4.0-tiny-preview": - MODEL_ID = IBM_GRANITE_4_TINY_PREVIEW_7B - - elif MODEL_ID == "unsloth/Llama-3.2-1B": - MODEL_ID = META_LLAMA_3_2_1B - - else: - raise RuntimeError(f"Unsupported model-id:{MODEL_ID}") - - model_id = "ibm-granite/granite-4.0-tiny-preview" - backend = OpenAIBackend( - model_id=MODEL_ID, - formatter=TemplateFormatter(model_id=MODEL_ID), - base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:8000')}/v1", - api_key="ollama", - ) - - m = MelleaSession(backend, ctx=SimpleContext()) - tokenizer = AutoTokenizer.from_pretrained(MODEL_ID.hf_model_name, trust_remote_code=True) - - - def prepare_prmpt_for_math(self, query): - # Preparing prompt for math reasoning tasks - system_prompt = None # Use default of chat template - prompt_suffix = "\nPlease reason step by step, use \n\n to end each step, and put your final answer within \\boxed{}." 
-
-        if prompt_suffix:
-            query += prompt_suffix
-
-        msg = []
-        if system_prompt is not None:
-            msg.append({"role": "system", "content": system_prompt})
-
-        msg.append({"role": "user", "content": query})
-        if self.tokenizer.chat_template is None:
-            raise RuntimeError(f"No explicit chat template is defined for model-id: ")
-
-        else:
-            prompt = self.tokenizer.apply_chat_template(
-                msg,
-                tokenize=False,
-                thinking=True,
-                add_generation_prompt=True,
-            )
-
-        return prompt
+    m = mellea.start_session(ctx=SimpleContext())
 
     def test_majority_voting_for_math(self):
 
         query = "Compute 1+1"
-        prompt = self.prepare_prmpt_for_math(query)
+        prompt_suffix = "\nPlease reason step by step, use \n\n to end each step, and put your final answer within \\boxed{}."
+        prompt = query + prompt_suffix
 
         result = self.m.instruct(
             prompt,
diff --git a/test/stdlib_basics/test_majority_voting/.gitignore b/test/stdlib_basics/test_majority_voting/.gitignore
deleted file mode 100644
index baf1455b..00000000
--- a/test/stdlib_basics/test_majority_voting/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-vllm.err
-vllm.log
diff --git a/test/stdlib_basics/test_majority_voting/README.md b/test/stdlib_basics/test_majority_voting/README.md
deleted file mode 100644
index c904b0d6..00000000
--- a/test/stdlib_basics/test_majority_voting/README.md
+++ /dev/null
@@ -1,22 +0,0 @@
-# Test for the OpenAI API served by vLLM
-
-## Requirement
-
-anaconda / miniconda / miniforge.
-
-Make sure to run the test with multiple cores available (e.g. in a cloud instance / cluster job).
-Although 1 core may seem enough,
-vLLM can deadlock when limited to a single core.
-
-## Installation
-
-Run the `install.sh` script, which needs to be done only once.
-The script creates a new conda environment named "mellea_mv" only for the purposes of testing or contributing to the majority-voting feature.
-
-Run `./install.sh`
-
-## Testing
-
-``` shell
-./run_test.sh
-```
diff --git a/test/stdlib_basics/test_majority_voting/environment.yml b/test/stdlib_basics/test_majority_voting/environment.yml
deleted file mode 100644
index 626ad888..00000000
--- a/test/stdlib_basics/test_majority_voting/environment.yml
+++ /dev/null
@@ -1,7 +0,0 @@
-
-name: mellea_mv
-channels:
-  - conda-forge
-dependencies:
-  - python=3.12 # note: at the time of writing, xformer (< vllm) has a broken wheel for 3.13. https://github.com/facebookresearch/xformers/issues/740#issuecomment-2753869337
-  - uv
diff --git a/test/stdlib_basics/test_majority_voting/exec_sampling_test.sh b/test/stdlib_basics/test_majority_voting/exec_sampling_test.sh
deleted file mode 100644
index d543632c..00000000
--- a/test/stdlib_basics/test_majority_voting/exec_sampling_test.sh
+++ /dev/null
@@ -1,10 +0,0 @@
-#!/bin/bash
-
-source set_variables.sh
-
-eval "$(conda shell.bash hook)"
-conda activate $ENV_NAME
-
-export LOCAL_TEST_MODEL
-
-python test_majority_voting.py
diff --git a/test/stdlib_basics/test_majority_voting/install.sh b/test/stdlib_basics/test_majority_voting/install.sh
deleted file mode 100755
index 15194d16..00000000
--- a/test/stdlib_basics/test_majority_voting/install.sh
+++ /dev/null
@@ -1,24 +0,0 @@
-#!/bin/bash -xe
-
-source set_variables.sh
-
-conda env remove -y -n $ENV_NAME || true
-conda env create -f $(readlink -f $(dirname $0))/environment.yml
-
-in-conda (){
-    conda run -n $ENV_NAME $@
-}
-
-
-cd ../../../
-in-conda uv pip install -e .
-cd - -in-conda uv pip install pre-commit -in-conda uv pip install pytest -in-conda uv pip install vllm==0.10.0 -in-conda uv pip install outlines -# in-conda uv pip install unsloth -in-conda uv pip install ipdb -in-conda uv pip install math-verify[antlr4_13_2] -in-conda uv pip install rouge-score - diff --git a/test/stdlib_basics/test_majority_voting/run_test.sh b/test/stdlib_basics/test_majority_voting/run_test.sh deleted file mode 100755 index 9e17b492..00000000 --- a/test/stdlib_basics/test_majority_voting/run_test.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/bash - -source set_variables.sh - -eval "$(conda shell.bash hook)" -conda activate $ENV_NAME - -rm $VLLM_LOG $VLLM_ERR - -bash ./serve.sh & -VLLM_PID=$! - -trap "kill -SIGINT $VLLM_PID ; wait" EXIT - -while sleep 1 ; do - if grep -q "Application startup complete." $VLLM_ERR - then - break - fi -done - -bash exec_sampling_test.sh - - diff --git a/test/stdlib_basics/test_majority_voting/serve.sh b/test/stdlib_basics/test_majority_voting/serve.sh deleted file mode 100755 index a890a211..00000000 --- a/test/stdlib_basics/test_majority_voting/serve.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -source set_variables.sh -eval "$(conda shell.bash hook)" -conda activate $ENV_NAME -export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True - -echo "launching a vllm server. Logs are found in $(readlink -ef $(dirname $0))/vllm.log" - # At the time of writing this code, Granite 4.4 vLLM serving did not support prefix-caching - # --enable-prefix-caching \ -vllm serve $LOCAL_TEST_MODEL \ - --dtype bfloat16 \ - > $VLLM_LOG \ - 2> $VLLM_ERR - - diff --git a/test/stdlib_basics/test_majority_voting/set_variables.sh b/test/stdlib_basics/test_majority_voting/set_variables.sh deleted file mode 100644 index 6971f487..00000000 --- a/test/stdlib_basics/test_majority_voting/set_variables.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -PYTHONBREAKPOINT="ipdb.set_trace" -LOCAL_TEST_MODEL="ibm-granite/granite-4.0-tiny-preview" -ENV_NAME=mellea_mv -DIR=$(readlink -ef $(dirname $0)) -VLLM_LOG=$DIR/vllm.log -VLLM_ERR=$DIR/vllm.err From 584044711a85293cdd15913159712b72eda792e9 Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Thu, 25 Sep 2025 21:49:41 -0400 Subject: [PATCH 07/10] minor fix to pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ec636e1f..166c5c53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -160,7 +160,7 @@ combine-as-imports = true split-on-trailing-comma = false [tool.codespell] -ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,rouge,Rouge' +ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,mot,rouge,Rouge' check-filenames = true check-hidden = false regex = "(? Date: Tue, 30 Sep 2025 16:11:59 -0400 Subject: [PATCH 08/10] allow self-voting, and default symmetry, removed redundant variables in devirved classes --- mellea/stdlib/sampling.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index 3d7c9bad..13545736 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -669,7 +669,7 @@ def __init__( ) self.number_of_samples = number_of_samples self.weighted = weighted - self.symmetric = False + self.symmetric = True @abc.abstractmethod def compare_strings(self, ref: str, pred: str) -> float: @@ -728,7 +728,7 @@ async def sample( for i in range(len(results)): for j in range(len(results)): if j == i: - scr[i][j] = 0.0 # self voting is 0. 
+ scr[i][j] = 1.0 # self voting is 1. continue # upper triangle @@ -765,8 +765,6 @@ class MajorityVotingStrategyForMath(BaseMBRDSampling): float_rounding: int strict: bool allow_set_relation_comp: bool - weighted: bool - symmetric: bool def __init__( self, @@ -821,10 +819,9 @@ def __init__( self.float_rounding = float_rounding self.strict = strict self.allow_set_relation_comp = allow_set_relation_comp - self.weighted = weighted # Note: symmetry is not implied for certain expressions, see: https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/README.md?plain=1#L183 - self.symmetric = False + self.symmetric = True # https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36 def compare_strings(self, ref: str, pred: str): @@ -851,10 +848,7 @@ def compare_strings(self, ref: str, pred: str): class MBRDRougeLStrategy(BaseMBRDSampling): - number_of_samples: int match_types: list[str] - weighted: bool - symmetric: bool scorer: RougeScorer def __init__( @@ -894,9 +888,7 @@ def __init__( generate=generate, requirements=requirements, ) - self.number_of_samples = number_of_samples self.match_types = ["rougeL"] - self.weighted = weighted self.symmetric = True self.scorer = RougeScorer(self.match_types, use_stemmer=True) From 7709b126668bef32c4d72258205cbe37262f9796 Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Thu, 2 Oct 2025 13:31:47 -0400 Subject: [PATCH 09/10] Added special ficture for tests to run on github action runners --- test/stdlib_basics/test_majority_voting.py | 108 ++++++++++++--------- 1 file changed, 61 insertions(+), 47 deletions(-) diff --git a/test/stdlib_basics/test_majority_voting.py b/test/stdlib_basics/test_majority_voting.py index 2d4663b0..8d5d4ca9 100644 --- a/test/stdlib_basics/test_majority_voting.py +++ b/test/stdlib_basics/test_majority_voting.py @@ -1,5 +1,5 @@ -import mellea -from mellea.stdlib.base import SimpleContext +from mellea.backends import ModelOption +from mellea import start_session, MelleaSession from mellea.stdlib.requirement import check, req, simple_validate from mellea.stdlib.sampling.majority_voting import ( MBRDRougeLStrategy, @@ -8,55 +8,69 @@ import pytest -class TestMajorityVoting: - - def test_majority_voting_for_math(self): - m = mellea.start_session(ctx=SimpleContext()) - query = "Compute 1+1" - prompt_suffix = "\nPlease reason step by step, use \n\n to end each step, and put your final answer within \\boxed{}." - prompt = query + prompt_suffix - - result = m.instruct( - prompt, - strategy=MajorityVotingStrategyForMath(number_of_samples=8, loop_budget=1), - return_sampling_results=True, +@pytest.fixture(scope="module") +def m_session(gh_run): + if gh_run == 1: + m = start_session( + "ollama", + model_id="llama3.2:1b", + model_options={ModelOption.MAX_NEW_TOKENS: 5}, ) - if result.success: - output = str(result.result) - else: - output = result.sample_generations[0].value - - print(output) - assert output - - def test_MBRDRougeL(self): - m = mellea.start_session(ctx=SimpleContext()) - requirements = [ - req("The email should have a salutation"), # == r1 - req( - "Use only lower-case letters", - validation_fn=simple_validate(lambda x: x.lower() == x), - ), # == r2 - check("Do not mention purple elephants."), # == r3 - ] - - name = "Olivia" - notes = "Olivia helped the lab over the last few weeks by organizing intern events, advertising the speaker series, and handling issues with snack delivery." 
- email_candidate = m.instruct( - "Write an email to {{name}} using the notes following: {{notes}}.", - requirements=requirements, - strategy=MBRDRougeLStrategy(number_of_samples=8, loop_budget=1), - user_variables={"name": name, "notes": notes}, - return_sampling_results=True, + else: + m = start_session( + "ollama", + model_id="llama3.2:1b", ) + yield m + del m + + +def test_majority_voting_for_math(m_session: MelleaSession): + query = "Compute 1+1" + prompt_suffix = "\nPlease reason step by step, use \n\n to end each step, and put your final answer within \\boxed{}." + prompt = query + prompt_suffix + + result = m_session.instruct( + prompt, + strategy=MajorityVotingStrategyForMath(number_of_samples=8, loop_budget=1), + return_sampling_results=True, + ) + if result.success: + output = str(result.result) + else: + output = result.sample_generations[0].value + + print(output) + assert output + + +def test_MBRDRougeL(m_session: MelleaSession): + requirements = [ + req("The email should have a salutation"), # == r1 + req( + "Use only lower-case letters", + validation_fn=simple_validate(lambda x: x.lower() == x), + ), # == r2 + check("Do not mention purple elephants."), # == r3 + ] + + name = "Olivia" + notes = "Olivia helped the lab over the last few weeks by organizing intern events, advertising the speaker series, and handling issues with snack delivery." + email_candidate = m_session.instruct( + "Write an email to {{name}} using the notes following: {{notes}}.", + requirements=requirements, + strategy=MBRDRougeLStrategy(number_of_samples=8, loop_budget=1), + user_variables={"name": name, "notes": notes}, + return_sampling_results=True, + ) - if email_candidate.success: - output = str(email_candidate.result) - else: - output = email_candidate.sample_generations[0].value + if email_candidate.success: + output = str(email_candidate.result) + else: + output = email_candidate.sample_generations[0].value - print(output) - assert output + print(output) + assert output if __name__ == "__main__": From 59a9398f6a63ec46220170bc5a991d766647d2c3 Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Fri, 3 Oct 2025 13:15:41 -0400 Subject: [PATCH 10/10] fix: minor tweaks to get automation to pass, remove redundant mellea checks, and improve readability --- mellea/stdlib/sampling/majority_voting.py | 137 +++++++++++++-------- pyproject.toml | 2 + test/stdlib_basics/test_majority_voting.py | 16 +-- uv.lock | 86 ++++++++++++- 4 files changed, 177 insertions(+), 64 deletions(-) diff --git a/mellea/stdlib/sampling/majority_voting.py b/mellea/stdlib/sampling/majority_voting.py index 8d4fe15b..8ba99798 100644 --- a/mellea/stdlib/sampling/majority_voting.py +++ b/mellea/stdlib/sampling/majority_voting.py @@ -1,5 +1,7 @@ +"""Sampling Strategies for Minimum Bayes Risk Decoding (MBRD).""" + import abc -from asyncio import TaskGroup # type: ignore[attr-defined] +import asyncio import numpy as np from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify @@ -12,6 +14,8 @@ class BaseMBRDSampling(RejectionSamplingStrategy): + """Abstract Minimum Bayes Risk Decoding (MBRD) Sampling Strategy.""" + number_of_samples: int weighted: bool symmetric: bool @@ -24,10 +28,16 @@ def __init__( loop_budget: int = 1, requirements: list[Requirement] | None = None, ): - """Initialize a new instance of the class with default parameters. + """Initialize a new abstract Minimum Bayes Risk Decoding (MBRD) Sampling Strategy with default parameters. + + Inherits from RejectionSamplingStrategy. 
Will generate up to loop_budget x number_of_samples requests. If no + requirements are provided here or in sample(...), will only generate number_of_samples requests. + + Classes that inherit from this must implement the `compare_strings` function. Args: number_of_samples: Number of samples to generate and use for majority voting + weighted: Not Implemented. If True, weights the score before getting the final majority vote loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0. requirements: List of requirements to test against. If None, test all requirements attached to the given instruction. @@ -43,8 +53,9 @@ def __init__( def compare_strings(self, ref: str, pred: str) -> float: """This method is the abstract method for MBRD similarity metric.""" - def maybe_apply_weighted(self, scr): - # TODO not implemented yet + def maybe_apply_weighted(self, scr: np.ndarray): + """Applies weights if self.weighted is True. Not Implemented.""" + # TODO: not implemented yet if self.weighted: weights = np.asarray([1.0 for _ in range(len(scr))]) scr = scr * weights @@ -64,39 +75,49 @@ async def sample( tool_calls: bool = False, show_progress: bool = True, ) -> SamplingResult: + """Samples using majority voting. + + Args: + action : The action object to be sampled. + context: The context to be passed to the sampling strategy. + backend: The backend used for generating samples. + requirements: List of requirements to test against (merged with global requirements). + validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. + format: output format for structured outputs; ignored for this sampling strategy. + model_options: model options to pass to the backend during generation / validation. + tool_calls: True if tool calls should be used during this sampling strategy. + show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. + + Returns: + SamplingResult: A result object indicating the success or failure of the sampling process. + """ # execute sampling concurrently - tasks = [] - async with TaskGroup() as tg: - for i in range(self.number_of_samples): - task = tg.create_task( - super().sample( - action, - context, - backend, - requirements, - validation_ctx=validation_ctx, - model_options=model_options, - tool_calls=tool_calls, - show_progress=show_progress, - ) + tasks: list[asyncio.Task[SamplingResult]] = [] + for i in range(self.number_of_samples): + task = asyncio.create_task( + super().sample( + action, + context, + backend, + requirements, + validation_ctx=validation_ctx, + model_options=model_options, + tool_calls=tool_calls, + show_progress=show_progress, ) - tasks.append(task) + ) + tasks.append(task) - # collect results - results = [] - for task in tasks: - result = task.result() - if result.success: - output = str(result.result) - else: - # avoid type checker error - assert isinstance(result.sample_generations, list) - output = str(result.sample_generations[0].value) + sampling_results = await asyncio.gather(*tasks) + # collect results + results: list[tuple[str, SamplingResult]] = [] + for result in sampling_results: + output = str(result.result) results.append((output, result)) - assert len(results) > 0 + # Create an array of len(results) x len(results) initialized to 0.0. 
scr = np.asarray( [[0.0 for _ in range(len(results))] for _ in range(len(results))] ) @@ -124,17 +145,19 @@ async def sample( continue # count votes - scr = scr.sum(axis=0) + summed_scr: np.ndarray = scr.sum(axis=0) # Apply weights - scr = self.maybe_apply_weighted(scr) + weighed_scr = self.maybe_apply_weighted(summed_scr) - maxR = int(scr.argmax()) + maxR = int(weighed_scr.argmax()) return results[maxR][1] # return one of the MV answers class MajorityVotingStrategyForMath(BaseMBRDSampling): + """MajorityVoting Sampling Strategy for Math Expressions.""" + number_of_samples: int match_types: list[str] float_rounding: int @@ -152,7 +175,10 @@ def __init__( loop_budget: int = 1, requirements: list[Requirement] | None = None, ): - """Initialize a new instance of the class with default parameters. + """Initialize a new instance of MajorityVoting Sampling Strategy for Math with default parameters. + + Will generate up to loop_budget x number_of_samples requests. If no + requirements are provided here or in sample(...), will only generate number_of_samples requests. Args: number_of_samples: Number of samples to generate and use for majority voting @@ -163,6 +189,7 @@ def __init__( allow_set_relation_comp: Whether to allow set - relation (e.g 1 < x < 2 and (1, 2)) comparison. Defaults to False. - If True, set - relation comparison will be allowed in all cases. - If False, set - relation comparison will be allowed only if the prediction is a set. + weighted: Not Implemented. If True, weights the score before getting the final majority vote loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0. requirements: List of requirements to test against. If None, test all requirements attached to the given instruction. @@ -189,8 +216,8 @@ def __init__( self.symmetric = True # https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36 - def compare_strings(self, ref: str, pred: str): - """Helper function to compare strings using the math extraction metrics""" + def compare_strings(self, ref: str, pred: str) -> float: + """Helper function to compare strings using the math extraction metrics.""" # Convert string match_types to ExtractionTarget objects extraction_targets = [] for match_type in self.match_types: @@ -200,19 +227,23 @@ def compare_strings(self, ref: str, pred: str): extraction_targets.append(ExprExtractionConfig()) # NOTE: Math-Verify parse and verify functions don't support threaded environment due to usage of signal.alarm() in timeout mechanism. 
         # If you need to run in multithreaded environment it's recommended to set the parsing_timeout=None
-        gold_parsed = parse(ref, extraction_targets, parsing_timeout=None)
-        pred_parsed = parse(pred, extraction_targets, parsing_timeout=None)
-        return verify(
-            gold_parsed,
-            pred_parsed,
-            float_rounding=self.float_rounding,
-            strict=self.strict,
-            allow_set_relation_comp=self.allow_set_relation_comp,
-            timeout_seconds=None,
+        gold_parsed = parse(ref, extraction_targets, parsing_timeout=None)  # type: ignore
+        pred_parsed = parse(pred, extraction_targets, parsing_timeout=None)  # type: ignore
+        return float(
+            verify(
+                gold_parsed,
+                pred_parsed,
+                float_rounding=self.float_rounding,
+                strict=self.strict,
+                allow_set_relation_comp=self.allow_set_relation_comp,
+                timeout_seconds=None,
+            )
         )
 
 
 class MBRDRougeLStrategy(BaseMBRDSampling):
+    """Sampling Strategy that uses ROUGE-L to compute token-level similarity for majority voting."""
+
     match_types: list[str]
     scorer: RougeScorer
 
@@ -224,13 +255,14 @@ def __init__(
         loop_budget: int = 1,
         requirements: list[Requirement] | None = None,
     ):
-        """Initialize a new instance of the class with default parameters.
+        """Initialize a new instance of the MBRDRougeL Sampling Strategy with default parameters.
+
+        Will generate up to loop_budget x number_of_samples requests. If no
+        requirements are provided here or in sample(...), will only generate number_of_samples requests.
 
         Args:
             number_of_samples: Number of samples to generate and use for majority voting
-            match_type: type of match latex, expr (match only so far)
-                - For math use "latex" or "expr" or both
-                - For general text similarity use "rougel"
+            weighted: Not Implemented. If True, weights the score before getting the final majority vote.
             loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0.
             requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.
 
@@ -248,8 +280,7 @@ def __init__(
         self.scorer = RougeScorer(self.match_types, use_stemmer=True)
 
     # https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36
-    def compare_strings(self, ref: str, pred: str):
-        """Helper function to compare strings using the math extraction metrics"""
-
-        scr = self.scorer.score(ref, pred)[self.match_types[-1]].fmeasure
+    def compare_strings(self, ref: str, pred: str) -> float:
+        """Helper function to compare strings using the ROUGE-L F-measure."""
+        scr: float = self.scorer.score(ref, pred)[self.match_types[-1]].fmeasure
 
         return scr
diff --git a/pyproject.toml b/pyproject.toml
index 9081e4cf..1716ba48 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,8 @@ dependencies = [
     "mistletoe>=1.4.0",
     "huggingface-hub>=0.33.4",
     "pillow",
+    "math_verify",  # Needed for Majority Voting Sampling Strategies.
+    "rouge_score"  # Needed for Majority Voting Sampling Strategies.
] [project.scripts] diff --git a/test/stdlib_basics/test_majority_voting.py b/test/stdlib_basics/test_majority_voting.py index 8d5d4ca9..05acda00 100644 --- a/test/stdlib_basics/test_majority_voting.py +++ b/test/stdlib_basics/test_majority_voting.py @@ -7,6 +7,8 @@ ) import pytest +from mellea.stdlib.sampling.types import SamplingResult + @pytest.fixture(scope="module") def m_session(gh_run): @@ -35,10 +37,7 @@ def test_majority_voting_for_math(m_session: MelleaSession): strategy=MajorityVotingStrategyForMath(number_of_samples=8, loop_budget=1), return_sampling_results=True, ) - if result.success: - output = str(result.result) - else: - output = result.sample_generations[0].value + output = str(result.result) print(output) assert output @@ -56,18 +55,15 @@ def test_MBRDRougeL(m_session: MelleaSession): name = "Olivia" notes = "Olivia helped the lab over the last few weeks by organizing intern events, advertising the speaker series, and handling issues with snack delivery." - email_candidate = m_session.instruct( + email_candidate: SamplingResult = m_session.instruct( "Write an email to {{name}} using the notes following: {{notes}}.", - requirements=requirements, + requirements=requirements, # type: ignore strategy=MBRDRougeLStrategy(number_of_samples=8, loop_budget=1), user_variables={"name": name, "notes": notes}, return_sampling_results=True, ) - if email_candidate.success: - output = str(email_candidate.result) - else: - output = email_candidate.sample_generations[0].value + output = str(email_candidate.result) print(output) assert output diff --git a/uv.lock b/uv.lock index 65282d95..430fe83d 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'darwin'", @@ -16,6 +16,15 @@ resolution-markers = [ "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", ] +[[package]] +name = "absl-py" +version = "2.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/10/2a/c93173ffa1b39c1d0395b7e842bbdc62e556ca9d8d3b5572926f3e4ca752/absl_py-2.3.1.tar.gz", hash = "sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9", size = 116588, upload-time = "2025-07-03T09:31:44.05Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/aa/ba0014cc4659328dc818a28827be78e6d97312ab0cb98105a770924dc11e/absl_py-2.3.1-py3-none-any.whl", hash = "sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d", size = 135811, upload-time = "2025-07-03T09:31:42.253Z" }, +] + [[package]] name = "accelerate" version = "1.10.1" @@ -193,6 +202,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/53/18/a56e2fe47b259bb52201093a3a9d4a32014f9d85071ad07e9d60600890ca/ansicolors-1.1.8-py2.py3-none-any.whl", hash = "sha256:00d2dde5a675579325902536738dd27e4fac1fd68f773fe36c21044eb559e187", size = 13847, upload-time = "2017-06-02T21:22:12.67Z" }, ] +[[package]] +name = "antlr4-python3-runtime" +version = "4.13.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/33/5f/2cdf6f7aca3b20d3f316e9f505292e1f256a32089bd702034c29ebde6242/antlr4_python3_runtime-4.13.2.tar.gz", hash = "sha256:909b647e1d2fc2b70180ac586df3933e38919c85f98ccc656a96cd3f25ef3916", size = 117467, upload-time = "2024-08-03T19:00:12.757Z" 
} +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/03/a851e84fcbb85214dc637b6378121ef9a0dd61b4c65264675d8a5c9b1ae7/antlr4_python3_runtime-4.13.2-py3-none-any.whl", hash = "sha256:fe3835eb8d33daece0e799090eda89719dbccee7aa39ef94eed3818cafa5a7e8", size = 144462, upload-time = "2024-08-03T19:00:11.134Z" }, +] + [[package]] name = "anyio" version = "4.10.0" @@ -1753,6 +1771,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" }, ] +[[package]] +name = "joblib" +version = "1.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/5d/447af5ea094b9e4c4054f82e223ada074c552335b9b4b2d14bd9b35a67c4/joblib-1.5.2.tar.gz", hash = "sha256:3faa5c39054b2f03ca547da9b2f52fde67c06240c31853f306aea97f13647b55", size = 331077, upload-time = "2025-08-27T12:15:46.575Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" }, +] + [[package]] name = "json5" version = "0.12.1" @@ -2070,6 +2097,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3e/76/d661ea2e529c3d464f9efd73f9ac31626b45279eb4306e684054ea20e3d4/latex2mathml-3.78.1-py3-none-any.whl", hash = "sha256:f089b6d75e85b937f99693c93e8c16c0804008672c3dd2a3d25affd36f238100", size = 73892, upload-time = "2025-08-29T23:34:21.98Z" }, ] +[[package]] +name = "latex2sympy2-extended" +version = "1.10.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "antlr4-python3-runtime" }, + { name = "sympy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/de/472f9115c14c6f6d8a5889cabe3418283d708bde62ce00402c29441deed4/latex2sympy2_extended-1.10.2.tar.gz", hash = "sha256:41a517ffcc5a140e910a7d1646ce6ff440817e5f9d48fc8279d88bd0925bc389", size = 206188, upload-time = "2025-07-02T15:26:06.225Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/60/dfbbf40e3a371388c0e03ff65b01319b7d4023e883df6d7261125772ffdc/latex2sympy2_extended-1.10.2-py3-none-any.whl", hash = "sha256:f910442c5b02a466c1046f47d05cc5285181068b882399281f30102715337fb7", size = 207855, upload-time = "2025-07-02T15:26:04.88Z" }, +] + [[package]] name = "lazy-loader" version = "0.4" @@ -2278,6 +2318,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4f/65/6079a46068dfceaeabb5dcad6d674f5f5c61a6fa5673746f42a9f4c233b3/MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", size = 15739, upload-time = "2024-10-18T15:21:42.784Z" }, ] +[[package]] +name = "math-verify" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "latex2sympy2-extended" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/35/b5/b1db6fa6b6c28ebbe1889ee11a4703a72a2ca7750ec415f4559c758cf01a/math_verify-0.8.0.tar.gz", hash = "sha256:3295e0adb94bfe553ff6e3189c44f1916a85aa24ab5d1900f2086a706e28f7c4", size = 60191, upload-time = "2025-07-02T15:52:07.209Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/fe/9f/59979f699b5c97334298f1295bc9fcdc9904d98d2276479bffff863d23b1/math_verify-0.8.0-py3-none-any.whl", hash = "sha256:31ca651296d817a9bb3fd58ca1fd0d192dcea709b1e5ecf2d0a4514c16f89087", size = 29994, upload-time = "2025-07-02T15:52:05.023Z" }, +] + [[package]] name = "matplotlib-inline" version = "0.1.7" @@ -2319,12 +2371,14 @@ dependencies = [ { name = "huggingface-hub" }, { name = "jinja2" }, { name = "json5" }, + { name = "math-verify" }, { name = "mistletoe" }, { name = "ollama" }, { name = "openai" }, { name = "pillow" }, { name = "pydantic" }, { name = "requests" }, + { name = "rouge-score" }, { name = "typer" }, { name = "types-requests" }, { name = "types-tqdm" }, @@ -2401,6 +2455,7 @@ requires-dist = [ { name = "jinja2" }, { name = "json5" }, { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.76" }, + { name = "math-verify" }, { name = "mellea", extras = ["watsonx", "docling", "hf", "litellm"], marker = "extra == 'all'" }, { name = "mistletoe", specifier = ">=1.4.0" }, { name = "ollama", specifier = ">=0.5.1" }, @@ -2410,6 +2465,7 @@ requires-dist = [ { name = "pillow" }, { name = "pydantic" }, { name = "requests", specifier = ">=2.32.3" }, + { name = "rouge-score" }, { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.53.2" }, { name = "trl", marker = "extra == 'hf'", specifier = ">=0.19.0" }, { name = "typer" }, @@ -2830,6 +2886,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/93/a7b983643d1253bb223234b5b226e69de6cda02b76cdca7770f684b795f5/ninja-1.13.0-py3-none-win_arm64.whl", hash = "sha256:3c0b40b1f0bba764644385319028650087b4c1b18cdfa6f45cb39a3669b81aa9", size = 290806, upload-time = "2025-08-11T15:10:18.018Z" }, ] +[[package]] +name = "nltk" +version = "3.9.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "joblib" }, + { name = "regex" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f9/76/3a5e4312c19a028770f86fd7c058cf9f4ec4321c6cf7526bab998a5b683c/nltk-3.9.2.tar.gz", hash = "sha256:0f409e9b069ca4177c1903c3e843eef90c7e92992fa4931ae607da6de49e1419", size = 2887629, upload-time = "2025-10-01T07:19:23.764Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/90/81ac364ef94209c100e12579629dc92bf7a709a84af32f8c551b02c07e94/nltk-3.9.2-py3-none-any.whl", hash = "sha256:1e209d2b3009110635ed9709a67a1a3e33a10f799490fa71cf4bec218c11c88a", size = 1513404, upload-time = "2025-10-01T07:19:21.648Z" }, +] + [[package]] name = "nodeenv" version = "1.9.1" @@ -4574,6 +4645,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/53/97/d2cbbaa10c9b826af0e10fdf836e1bf344d9f0abb873ebc34d1f49642d3f/roman_numerals_py-3.1.0-py3-none-any.whl", hash = "sha256:9da2ad2fb670bcf24e81070ceb3be72f6c11c440d73bd579fbeca1e9f330954c", size = 7742, upload-time = "2025-02-22T07:34:52.422Z" }, ] +[[package]] +name = "rouge-score" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "nltk" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz", hash = 
"sha256:c7d4da2683e68c9abf0135ef915d63a46643666f848e558a1b9f7ead17ff0f04", size = 17400, upload-time = "2022-07-22T22:46:22.909Z" } + [[package]] name = "rpds-py" version = "0.27.1"