"""Sampling Strategies for budget forcing generation."""

from copy import deepcopy

import tqdm

from mellea.backends import Backend, BaseModelSubclass
from mellea.backends.ollama import OllamaModelBackend
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib import funcs as mfuncs
from mellea.stdlib.base import ModelOutputThunk
from mellea.stdlib.requirement import Requirement, ValidationResult
from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult
from mellea.stdlib.sampling.base import Component, Context
from mellea.stdlib.sampling_algos.budget_forcing_alg import think_budget_forcing


class BudgetForcingSamplingStrategy(RejectionSamplingStrategy):
    """Budget forcing sampling strategy.
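
    Budget forcing (in the spirit of s1-style test-time scaling) constrains how
    long the model may "think": the think block is cut off once
    ``think_max_tokens`` is reached and, when ``think_more_suffix`` is set,
    shorter traces are nudged to keep thinking until the budget is spent. Each
    candidate output is then validated against the requirements and, on failure,
    repaired and retried up to ``loop_budget`` times, as in the parent rejection
    sampling strategy.
    """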

    think_max_tokens: int | None
    answer_max_tokens: int | None
    start_think_token: str | None
    end_think_token: str | None
    begin_response_token: str | None
    end_response_token: str
    think_more_suffix: str | None
    answer_suffix: str | None

    def __init__(
        self,
        *,
        think_max_tokens: int | None = 4096,
        answer_max_tokens: int | None = None,
        start_think_token: str | None = "<think>",
        end_think_token: str | None = "</think>",
        begin_response_token: str | None = "",
        end_response_token: str = "",
        think_more_suffix: str | None = "",
        answer_suffix: str | None = "",
        loop_budget: int = 1,
        requirements: list[Requirement] | None = None,
    ):
| 44 | + r"""Initialize class. |
| 45 | +
|
| 46 | + Inherits from RejectionSamplingStrategy. |
| 47 | +
|
| 48 | + Args: |
| 49 | + think_max_tokens: Number of tokens for think block |
| 50 | + answer_max_tokens: Number of tokens allocated for answer portion, if set to None answer tokens will be unlimited |
| 51 | + start_think_token: Special start of think block token defaults to '<think>' |
| 52 | + end_think_token: Special end of think block token defaults to '</think>' |
| 53 | + begin_response_token: Special begin of response block token e.g. '<response>' defaults to "" |
| 54 | + end_response_token: Special end of response block token e.g. '</response>' defaults to "" |
| 55 | + think_more_suffix: Suffix for continue thinking e.g. "\nWait let's think more carefully" to force the model to think more, defaults to "". If set to "", no force thinking will be applied, the token budget will be become an upper bound. |
| 56 | + answer_suffix: Suffix to obtain final answer, default to "\nThe final answer is:" |
| 57 | + loop_budget: Number of times to iterate through the process. Must be greater than 0. |
| 58 | + requirements: List of requirements to test against. If None, test all requirements attached to the given instruction. |
| 59 | +
|
| 60 | + Raises: |
| 61 | + AssertionError: If loop_budget is not greater than 0. |
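
        Example:
            A minimal construction; the values below are illustrative only::

                strategy = BudgetForcingSamplingStrategy(
                    think_max_tokens=2048,
                    think_more_suffix="\nWait, let's think more carefully.",
                    loop_budget=3,
                )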
        """
        super().__init__(loop_budget=loop_budget, requirements=requirements)
        self.think_max_tokens = think_max_tokens
        self.answer_max_tokens = answer_max_tokens
        self.start_think_token = start_think_token
        self.end_think_token = end_think_token
        self.begin_response_token = begin_response_token
        self.end_response_token = end_response_token
        self.think_more_suffix = think_more_suffix
        self.answer_suffix = answer_suffix

    async def sample(
        self,
        action: Component,
        context: Context,
        backend: Backend,
        requirements: list[Requirement] | None,
        *,
        validation_ctx: Context | None = None,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        tool_calls: bool = False,
        show_progress: bool = True,
    ) -> SamplingResult:
        """Performs a sampling operation based on the given action.

        Args:
            action: The action object to be sampled.
            context: The context to be passed to the sampling strategy.
            backend: The backend used for generating samples.
            requirements: List of requirements to test against (merged with global requirements).
            validation_ctx: Optional context to use for validation. If None, validation_ctx = context.
            format: Output format for structured outputs.
            model_options: Model options to pass to the backend during generation / validation.
            tool_calls: True if tool calls should be used during this sampling strategy.
            show_progress: If True, a tqdm progress bar is used. Otherwise, messages will still be sent to flog.

        Returns:
            SamplingResult: A result object indicating the success or failure of the sampling process.

        Raises:
            AssertionError: If tool_calls is True, or if the backend is not an OllamaModelBackend; neither is supported with budget forcing.
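
        Example:
            An illustrative call from an async context (``action``, ``ctx``, and
            ``backend`` are assumed to exist)::

                result = await strategy.sample(action, ctx, backend, requirements=None)
                best = result.sample_generations[result.result_index]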
        """
        validation_ctx = validation_ctx if validation_ctx is not None else context

        flog = FancyLogger.get_logger()

        sampled_results: list[ModelOutputThunk] = []
        sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = []
        sampled_actions: list[Component] = []
        sample_contexts: list[Context] = []

        # The `logging_redirect_tqdm` approach did not work, so instead we use the
        # show_progress flag to determine whether we should show the pbar.
        show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO

        reqs = []
        # Global requirements supersede local requirements (global requirements can be defined by the user).
        # TODO: re-evaluate if this makes sense.
        if self.requirements is not None:
            reqs += self.requirements
        elif requirements is not None:
            reqs += requirements
        reqs = list(set(reqs))  # De-duplicate; note that set() does not preserve order.

        loop_count = 0
        loop_budget_range_iterator = (
            tqdm.tqdm(range(self.loop_budget))  # type: ignore
            if show_progress
            else range(self.loop_budget)  # type: ignore
        )

        next_action = deepcopy(action)
        next_context = context
        for _ in loop_budget_range_iterator:  # type: ignore
            loop_count += 1
            if not show_progress:
                flog.info(f"Running loop {loop_count} of {self.loop_budget}")

            # TODO: tool calls are not yet supported for budget forcing.
            assert tool_calls is False, (
                "tool_calls is not supported with budget forcing"
            )
            # TODO: support additional backends.
            assert isinstance(backend, OllamaModelBackend), (
                "Only the Ollama backend is supported with budget forcing"
            )
            # Run a generation pass with budget forcing, using the (possibly
            # repaired) action and context for this attempt.
            result = think_budget_forcing(
                backend,
                next_action,
                ctx=next_context,
                format=format,
                tool_calls=tool_calls,
                think_max_tokens=self.think_max_tokens,
                answer_max_tokens=self.answer_max_tokens,
                start_think_token=self.start_think_token,
                end_think_token=self.end_think_token,
                think_more_suffix=self.think_more_suffix,
                answer_suffix=self.answer_suffix,
                model_options=model_options,
            )
            result_ctx = next_context

            # Validation pass.
            val_scores_co = mfuncs.avalidate(
                reqs=reqs,
                context=result_ctx,
                backend=backend,
                output=result,
                format=format,
                model_options=model_options,
                # tool_calls=tool_calls  # Tool calls are not supported in validation strategies.
            )
            val_scores = await val_scores_co

            # Match up reqs with scores.
            constraint_scores = list(zip(reqs, val_scores))

            # Collect all data.
            sampled_results.append(result)
            sampled_scores.append(constraint_scores)
            sampled_actions.append(next_action)
            sample_contexts.append(result_ctx)

            # If all validations pass -- return success.
            if all(bool(s[1]) for s in constraint_scores):
                flog.info("SUCCESS")
                assert (
                    result._generate_log is not None
                )  # Cannot be None after generation.
                result._generate_log.is_final_result = True

                # SUCCESS !!!!
                return SamplingResult(
                    result_index=len(sampled_results) - 1,
                    success=True,
                    sample_generations=sampled_results,
                    sample_validations=sampled_scores,
                    sample_contexts=sample_contexts,
                    sample_actions=sampled_actions,
                )

            else:
                # Log the partial success and continue.
                count_valid = len([s for s in constraint_scores if bool(s[1])])
                flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}")

            # If we did not pass all constraints, update the instruction and try again.
            next_action, next_context = self.repair(
                next_context,
                result_ctx,
                sampled_actions,
                sampled_results,
                sampled_scores,
            )

        flog.info(
            f"Invoking select_from_failure after {len(sampled_results)} failed attempts."
        )

        # If no valid result could be determined, find a last resort.
        best_failed_index = self.select_from_failure(
            sampled_actions, sampled_results, sampled_scores
        )
        assert best_failed_index < len(sampled_results), (
            "The select_from_failure method did not return a valid result. It has to select from the failed results."
        )

        assert (
            sampled_results[best_failed_index]._generate_log is not None
        )  # Cannot be None after generation.
        sampled_results[best_failed_index]._generate_log.is_final_result = True  # type: ignore

        return SamplingResult(
            result_index=best_failed_index,
            success=False,
            sample_generations=sampled_results,
            sample_validations=sampled_scores,
            sample_actions=sampled_actions,
            sample_contexts=sample_contexts,
        )
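

# Usage sketch (illustrative only, not executed on import). The session helpers
# and the requirement factory below are assumptions about the surrounding mellea
# API, not something this module defines:
#
#     from mellea import start_session
#     from mellea.stdlib.requirement import req
#
#     m = start_session()  # assumed to be backed by an Ollama model
#     result = m.instruct(
#         "Compute 17 * 23, reasoning step by step.",
#         requirements=[req("The final answer must be 391.")],
#         strategy=BudgetForcingSamplingStrategy(
#             think_max_tokens=2048,
#             think_more_suffix="\nWait, let's think more carefully.",
#             loop_budget=3,
#         ),
#     )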