|
| 1 | +"""Sampling Strategies for Minimum Bayes Risk Decoding (MBRD).""" |
| 2 | + |
| 3 | +import abc |
| 4 | +import asyncio |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify |
| 8 | +from rouge_score.rouge_scorer import RougeScorer # codespell:ignore |
| 9 | + |
| 10 | +from mellea.backends import Backend, BaseModelSubclass |
| 11 | +from mellea.stdlib.requirement import Requirement |
| 12 | +from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult |
| 13 | +from mellea.stdlib.sampling.base import Component, Context |
| 14 | + |
| 15 | + |
| 16 | +class BaseMBRDSampling(RejectionSamplingStrategy): |
| 17 | + """Abstract Minimum Bayes Risk Decoding (MBRD) Sampling Strategy.""" |
| 18 | + |
| 19 | + number_of_samples: int |
| 20 | + weighted: bool |
| 21 | + symmetric: bool |
| 22 | + |
| 23 | + def __init__( |
| 24 | + self, |
| 25 | + *, |
| 26 | + number_of_samples: int = 8, |
| 27 | + weighted: bool = False, |
| 28 | + loop_budget: int = 1, |
| 29 | + requirements: list[Requirement] | None = None, |
| 30 | + ): |
| 31 | + """Initialize a new abstract Minimum Bayes Risk Decoding (MBRD) Sampling Strategy with default parameters. |
| 32 | +
|
| 33 | + Inherits from RejectionSamplingStrategy. Will generate up to loop_budget x number_of_samples requests. If no |
| 34 | + requirements are provided here or in sample(...), will only generate number_of_samples requests. |
| 35 | +
|
| 36 | + Classes that inherit from this must implement the `compare_strings` function. |
| 37 | +
|
| 38 | + Args: |
| 39 | + number_of_samples: Number of samples to generate and use for majority voting |
| 40 | + weighted: Not Implemented. If True, weights the score before getting the final majority vote |
| 41 | + loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0. |
| 42 | + requirements: List of requirements to test against. If None, test all requirements attached to the given instruction. |
| 43 | +
|
| 44 | + Raises: |
| 45 | + AssertionError: If loop_budget is not greater than 0. |
| 46 | + """ |
| 47 | + super().__init__(loop_budget=loop_budget, requirements=requirements) |
| 48 | + self.number_of_samples = number_of_samples |
| 49 | + self.weighted = weighted |
| 50 | + self.symmetric = True |
| 51 | + |
| 52 | + @abc.abstractmethod |
| 53 | + def compare_strings(self, ref: str, pred: str) -> float: |
| 54 | + """This method is the abstract method for MBRD similarity metric.""" |
| 55 | + |
| 56 | + def maybe_apply_weighted(self, scr: np.ndarray): |
| 57 | + """Applies weights if self.weighted is True. Not Implemented.""" |
| 58 | + # TODO: not implemented yet |
| 59 | + if self.weighted: |
| 60 | + weights = np.asarray([1.0 for _ in range(len(scr))]) |
| 61 | + scr = scr * weights |
| 62 | + |
| 63 | + return scr |
| 64 | + |
| 65 | + async def sample( |
| 66 | + self, |
| 67 | + action: Component, |
| 68 | + context: Context, |
| 69 | + backend: Backend, |
| 70 | + requirements: list[Requirement] | None, |
| 71 | + *, |
| 72 | + validation_ctx: Context | None = None, |
| 73 | + format: type[BaseModelSubclass] | None = None, |
| 74 | + model_options: dict | None = None, |
| 75 | + tool_calls: bool = False, |
| 76 | + show_progress: bool = True, |
| 77 | + ) -> SamplingResult: |
| 78 | + """Samples using majority voting. |
| 79 | +
|
| 80 | + Args: |
| 81 | + action : The action object to be sampled. |
| 82 | + context: The context to be passed to the sampling strategy. |
| 83 | + backend: The backend used for generating samples. |
| 84 | + requirements: List of requirements to test against (merged with global requirements). |
| 85 | + validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. |
| 86 | + format: output format for structured outputs; ignored for this sampling strategy. |
| 87 | + model_options: model options to pass to the backend during generation / validation. |
| 88 | + tool_calls: True if tool calls should be used during this sampling strategy. |
| 89 | + show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. |
| 90 | +
|
| 91 | + Returns: |
| 92 | + SamplingResult: A result object indicating the success or failure of the sampling process. |
| 93 | + """ |
| 94 | + # execute sampling concurrently |
| 95 | + tasks: list[asyncio.Task[SamplingResult]] = [] |
| 96 | + for i in range(self.number_of_samples): |
| 97 | + task = asyncio.create_task( |
| 98 | + super().sample( |
| 99 | + action, |
| 100 | + context, |
| 101 | + backend, |
| 102 | + requirements, |
| 103 | + validation_ctx=validation_ctx, |
| 104 | + model_options=model_options, |
| 105 | + tool_calls=tool_calls, |
| 106 | + show_progress=show_progress, |
| 107 | + ) |
| 108 | + ) |
| 109 | + tasks.append(task) |
| 110 | + |
| 111 | + sampling_results = await asyncio.gather(*tasks) |
| 112 | + |
| 113 | + # collect results |
| 114 | + results: list[tuple[str, SamplingResult]] = [] |
| 115 | + for result in sampling_results: |
| 116 | + output = str(result.result) |
| 117 | + results.append((output, result)) |
| 118 | + assert len(results) > 0 |
| 119 | + |
| 120 | + # Create an array of len(results) x len(results) initialized to 0.0. |
| 121 | + scr = np.asarray( |
| 122 | + [[0.0 for _ in range(len(results))] for _ in range(len(results))] |
| 123 | + ) |
| 124 | + for i in range(len(results)): |
| 125 | + for j in range(len(results)): |
| 126 | + if j == i: |
| 127 | + scr[i][j] = 1.0 # self voting is 1. |
| 128 | + continue |
| 129 | + |
| 130 | + # upper triangle |
| 131 | + # For sample i compute votes against all j references |
| 132 | + if j > i: |
| 133 | + scr[i][j] = float( |
| 134 | + self.compare_strings(results[j][0], results[i][0]) |
| 135 | + ) |
| 136 | + continue |
| 137 | + |
| 138 | + else: |
| 139 | + if self.symmetric: |
| 140 | + scr[i][j] = scr[j][i] |
| 141 | + else: |
| 142 | + scr[i][j] = float( |
| 143 | + self.compare_strings(results[j][0], results[i][0]) |
| 144 | + ) |
| 145 | + continue |
| 146 | + |
| 147 | + # count votes |
| 148 | + summed_scr: np.ndarray = scr.sum(axis=0) |
| 149 | + |
| 150 | + # Apply weights |
| 151 | + weighed_scr = self.maybe_apply_weighted(summed_scr) |
| 152 | + |
| 153 | + maxR = int(weighed_scr.argmax()) |
| 154 | + |
| 155 | + return results[maxR][1] # return one of the MV answers |
| 156 | + |
| 157 | + |
| 158 | +class MajorityVotingStrategyForMath(BaseMBRDSampling): |
| 159 | + """MajorityVoting Sampling Strategy for Math Expressions.""" |
| 160 | + |
| 161 | + number_of_samples: int |
| 162 | + match_types: list[str] |
| 163 | + float_rounding: int |
| 164 | + strict: bool |
| 165 | + allow_set_relation_comp: bool |
| 166 | + |
| 167 | + def __init__( |
| 168 | + self, |
| 169 | + *, |
| 170 | + number_of_samples: int = 8, |
| 171 | + float_rounding: int = 6, |
| 172 | + strict: bool = True, |
| 173 | + allow_set_relation_comp: bool = False, |
| 174 | + weighted: bool = False, |
| 175 | + loop_budget: int = 1, |
| 176 | + requirements: list[Requirement] | None = None, |
| 177 | + ): |
| 178 | + """Initialize a new instance of MajorityVoting Sampling Strategy for Math with default parameters. |
| 179 | +
|
| 180 | + Will generate up to loop_budget x number_of_samples requests. If no |
| 181 | + requirements are provided here or in sample(...), will only generate number_of_samples requests. |
| 182 | +
|
| 183 | + Args: |
| 184 | + number_of_samples: Number of samples to generate and use for majority voting |
| 185 | + float_rounding: Number of decimal places to round floats to. Defaults to 6. |
| 186 | + strict: Whether to enforce strict comparison mode. Defaults to True. |
| 187 | + - In strict mode: Variables matter and sets are not comparable with tuples |
| 188 | + - In non-strict mode: Variables are matched by position and sets can be compared with tuples |
| 189 | + allow_set_relation_comp: Whether to allow set - relation (e.g 1 < x < 2 and (1, 2)) comparison. Defaults to False. |
| 190 | + - If True, set - relation comparison will be allowed in all cases. |
| 191 | + - If False, set - relation comparison will be allowed only if the prediction is a set. |
| 192 | + weighted: Not Implemented. If True, weights the score before getting the final majority vote |
| 193 | + loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0. |
| 194 | + requirements: List of requirements to test against. If None, test all requirements attached to the given instruction. |
| 195 | +
|
| 196 | + Raises: |
| 197 | + AssertionError: If loop_budget is not greater than 0. |
| 198 | + """ |
| 199 | + super().__init__( |
| 200 | + number_of_samples=number_of_samples, |
| 201 | + weighted=weighted, |
| 202 | + loop_budget=loop_budget, |
| 203 | + requirements=requirements, |
| 204 | + ) |
| 205 | + self.number_of_samples = number_of_samples |
| 206 | + # match_type: type of match latex, expr (match only so far) |
| 207 | + # - For math use "latex" or "expr" or both |
| 208 | + # - For general text similarity use "rougel" |
| 209 | + MATCH_TYPES = ["latex", "axpr"] |
| 210 | + self.match_types = MATCH_TYPES |
| 211 | + self.float_rounding = float_rounding |
| 212 | + self.strict = strict |
| 213 | + self.allow_set_relation_comp = allow_set_relation_comp |
| 214 | + |
| 215 | + # Note: symmetry is not implied for certain expressions, see: https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/README.md?plain=1#L183 |
| 216 | + self.symmetric = True |
| 217 | + |
| 218 | + # https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36 |
| 219 | + def compare_strings(self, ref: str, pred: str) -> float: |
| 220 | + """Helper function to compare strings using the math extraction metrics.""" |
| 221 | + # Convert string match_types to ExtractionTarget objects |
| 222 | + extraction_targets = [] |
| 223 | + for match_type in self.match_types: |
| 224 | + if match_type == "latex": |
| 225 | + extraction_targets.append(LatexExtractionConfig(boxed_match_priority=0)) |
| 226 | + elif match_type == "expr": |
| 227 | + extraction_targets.append(ExprExtractionConfig()) |
| 228 | + |
| 229 | + # NOTE: Math-Verify parse and verify functions don't support threaded environment due to usage of signal.alarm() in timeout mechanism. If you need to run in multithreaded environment it's recommended to set the parsing_timeout=None |
| 230 | + gold_parsed = parse(ref, extraction_targets, parsing_timeout=None) # type: ignore |
| 231 | + pred_parsed = parse(pred, extraction_targets, parsing_timeout=None) # type: ignore |
| 232 | + return float( |
| 233 | + verify( |
| 234 | + gold_parsed, |
| 235 | + pred_parsed, |
| 236 | + float_rounding=self.float_rounding, |
| 237 | + strict=self.strict, |
| 238 | + allow_set_relation_comp=self.allow_set_relation_comp, |
| 239 | + timeout_seconds=None, |
| 240 | + ) |
| 241 | + ) |
| 242 | + |
| 243 | + |
| 244 | +class MBRDRougeLStrategy(BaseMBRDSampling): |
| 245 | + """Sampling Strategy that uses RougeL to compute symbol-level distances for majority voting.""" |
| 246 | + |
| 247 | + match_types: list[str] |
| 248 | + scorer: RougeScorer |
| 249 | + |
| 250 | + def __init__( |
| 251 | + self, |
| 252 | + *, |
| 253 | + number_of_samples: int = 8, |
| 254 | + weighted: bool = False, |
| 255 | + loop_budget: int = 1, |
| 256 | + requirements: list[Requirement] | None = None, |
| 257 | + ): |
| 258 | + """Initialize a new instance of MBRDRougeL Sampling Strategy with default parameters. |
| 259 | +
|
| 260 | + Will generate up to loop_budget x number_of_samples requests. If no |
| 261 | + requirements are provided here or in sample(...), will only generate number_of_samples requests. |
| 262 | +
|
| 263 | + Args: |
| 264 | + number_of_samples: Number of samples to generate and use for majority voting |
| 265 | + weighted: Not Implemented. If True, weights the score before getting the final majority vote |
| 266 | + loop_budget: Inner rejection sampling number of times to iterate through the process. Must be greater than 0. |
| 267 | + requirements: List of requirements to test against. If None, test all requirements attached to the given instruction. |
| 268 | +
|
| 269 | + Raises: |
| 270 | + AssertionError: If loop_budget is not greater than 0. |
| 271 | + """ |
| 272 | + super().__init__( |
| 273 | + number_of_samples=number_of_samples, |
| 274 | + weighted=weighted, |
| 275 | + loop_budget=loop_budget, |
| 276 | + requirements=requirements, |
| 277 | + ) |
| 278 | + self.match_types = ["rougeL"] |
| 279 | + self.symmetric = True |
| 280 | + self.scorer = RougeScorer(self.match_types, use_stemmer=True) |
| 281 | + |
| 282 | + # https://github.com/huggingface/Math-Verify/blob/5d148cfaaf99214c2e4ffb4bc497ab042c592a7a/tests/test_all.py#L36 |
| 283 | + def compare_strings(self, ref: str, pred: str) -> float: |
| 284 | + """Helper function to compare strings using the math extraction metrics.""" |
| 285 | + scr: float = self.scorer.score(ref, pred)[self.match_types[-1]].fmeasure |
| 286 | + return scr |
0 commit comments