Skip to content

Commit a2e29e6

Browse files
yelkurdi (Yousef El-Kurdi), nrfulton, avinash2692, mdevino
authored
feat: Adds think budget-forcing (#107)
* Initial commit - think budget-forcing - tests run - WIP * adds zero-think case * resolved type checking errors * fixes typo and some scripts * backend interface using _raw_generate * Bump version number from 0.0.2 to 0.0.3 (#117) * ci: Rename .mergify.yml to mergify.yml (#119) * docs: fix typo on README (#116) Signed-off-by: Mateus Devino <[email protected]> * refactor: Full refactor of the Decompose CLI Tool & introduction of prompt_modules (#105) * Implements "prompt_modules" and complete refactor of the "decompose" feature * typo: missing period * minor fix: changed the "NotRequired" import * fix: minor fixes * moves prompt_modules to utils * moves decompose modules to appropriate path * refactor: moves prompt_modules to cli scope Signed-off-by: Tulio Coppola <[email protected]> * adds README.md to write later Signed-off-by: Tulio Coppola <[email protected]> --------- Signed-off-by: Tulio Coppola <[email protected]> Co-authored-by: Tulio Coppola <[email protected]> Co-authored-by: Nathan Fulton <[email protected]> * moved the budget forcing function into mellea/stdlib/sampling_algos/budget_forcing.py * adds budget forcing fn * feat: adds think budget forcing - relocated test dir * Update budget_forcing.py corrected default argument * merging main in-progress * main branch updates * updates to think_budget_forcing function to match sampling strategy interface * adds sampling strategy for budget forcing * minor fixes * feat: ollama generate_from_raw uses existing event loop * fix: add blocking prevention mech * fixes of async inconsistencies and incorporating Jacob's branch * updates interface significantly after prompting `_generate_from_raw` to public `generate_from_raw` * minor fix to test case * minor updates --------- Signed-off-by: Mateus Devino <[email protected]> Signed-off-by: Tulio Coppola <[email protected]> Co-authored-by: Yousef El-Kurdi <[email protected]> Co-authored-by: Nathan Fulton <[email protected]> Co-authored-by: Avinash Balakrishnan 
<[email protected]> Co-authored-by: Mateus Devino <[email protected]> Co-authored-by: Tulio Coppola <[email protected]> Co-authored-by: Tulio Coppola <[email protected]> Co-authored-by: jakelorocco <[email protected]> Co-authored-by: jakelorocco <[email protected]>
1 parent 7fa0891 commit a2e29e6

File tree

3 files changed

+504
-0
lines changed

3 files changed

+504
-0
lines changed
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
"""Sampling Strategies for budget forcing generation."""
2+
3+
from copy import deepcopy
4+
5+
import tqdm
6+
7+
from mellea.backends import Backend, BaseModelSubclass
8+
from mellea.backends.ollama import OllamaModelBackend
9+
from mellea.helpers.fancy_logger import FancyLogger
10+
from mellea.stdlib import funcs as mfuncs
11+
from mellea.stdlib.base import ModelOutputThunk
12+
from mellea.stdlib.requirement import Requirement, ValidationResult
13+
from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult
14+
from mellea.stdlib.sampling.base import Component, Context
15+
from mellea.stdlib.sampling_algos.budget_forcing_alg import think_budget_forcing
16+
17+
18+
class BudgetForcingSamplingStrategy(RejectionSamplingStrategy):
    """Rejection-sampling strategy whose candidates are generated with think budget forcing.

    Each loop iteration delegates generation to ``think_budget_forcing`` (which caps,
    or force-extends, the model's "think" block before the answer is produced) and then
    validates the result against the requirements, retrying up to ``loop_budget`` times.
    """

    # Token budget for the think block (None = unlimited).
    think_max_tokens: int | None
    # Token budget for the answer portion (None = unlimited).
    answer_max_tokens: int | None
    # Special tokens delimiting the think block, e.g. "<think>" / "</think>".
    start_think_token: str | None
    end_think_token: str | None
    # Special tokens delimiting the response block (may be empty strings).
    # NOTE(review): stored but currently never forwarded to think_budget_forcing — confirm intended.
    begin_response_token: str | None
    end_response_token: str
    # Suffix appended to force continued thinking ("" disables forcing).
    think_more_suffix: str | None
    # Suffix appended to elicit the final answer.
    answer_suffix: str | None

    def __init__(
        self,
        *,
        think_max_tokens: int | None = 4096,
        answer_max_tokens: int | None = None,
        start_think_token: str | None = "<think>",
        end_think_token: str | None = "</think>",
        begin_response_token: str | None = "",
        end_response_token: str = "",
        think_more_suffix: str | None = "",
        answer_suffix: str | None = "",
        loop_budget: int = 1,
        requirements: list[Requirement] | None,
    ):
        r"""Initialize class.

        Inherits from RejectionSamplingStrategy.

        Args:
            think_max_tokens: Number of tokens for think block.
            answer_max_tokens: Number of tokens allocated for answer portion; if set to None answer tokens will be unlimited.
            start_think_token: Special start of think block token, defaults to '<think>'.
            end_think_token: Special end of think block token, defaults to '</think>'.
            begin_response_token: Special begin of response block token e.g. '<response>', defaults to "".
            end_response_token: Special end of response block token e.g. '</response>', defaults to "".
            think_more_suffix: Suffix for continue thinking e.g. "\nWait let's think more carefully" to force the model to think more, defaults to "". If set to "", no force thinking will be applied and the token budget becomes an upper bound.
            answer_suffix: Suffix to obtain the final answer e.g. "\nThe final answer is:", defaults to "".
            loop_budget: Number of times to iterate through the process. Must be greater than 0.
            requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.

        Raises:
            AssertionError: If loop_budget is not greater than 0 (checked by the parent class).
        """
        super().__init__(loop_budget=loop_budget, requirements=requirements)
        self.think_max_tokens = think_max_tokens
        self.answer_max_tokens = answer_max_tokens
        self.start_think_token = start_think_token
        self.end_think_token = end_think_token
        self.begin_response_token = begin_response_token
        self.end_response_token = end_response_token
        self.think_more_suffix = think_more_suffix
        self.answer_suffix = answer_suffix

    async def sample(
        self,
        action: Component,
        context: Context,
        backend: Backend,
        requirements: list[Requirement] | None,
        *,
        validation_ctx: Context | None = None,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        tool_calls: bool = False,
        show_progress: bool = True,
    ) -> SamplingResult:
        """This method performs a sampling operation based on the given instruction.

        Args:
            action : The action object to be sampled.
            context: The context to be passed to the sampling strategy.
            backend: The backend used for generating samples.
            requirements: List of requirements to test against. Note: requirements set
                on the strategy instance supersede this argument (they are not merged).
            validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx.
            format: output format for structured outputs.
            model_options: model options to pass to the backend during generation / validation.
            tool_calls: Must be False; tool calls are not supported by budget forcing.
            show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog.

        Returns:
            SamplingResult: A result object indicating the success or failure of the sampling process.

        Raises:
            ValueError: If tool_calls is True (not supported with budget forcing).
            TypeError: If the backend is not an OllamaModelBackend.
        """
        validation_ctx = validation_ctx if validation_ctx is not None else context

        flog = FancyLogger.get_logger()

        sampled_results: list[ModelOutputThunk] = []
        sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = []
        sampled_actions: list[Component] = []
        sample_contexts: list[Context] = []

        # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress
        # flag to determine whether we should show the pbar.
        show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO

        reqs = []
        # global requirements supersede local requirements (global requirements can be defined by user)
        # Todo: re-evaluate if this makes sense
        if self.requirements is not None:
            reqs += self.requirements
        elif requirements is not None:
            reqs += requirements
        # De-duplicate while preserving order (list(set(...)) would be nondeterministic).
        reqs = list(dict.fromkeys(reqs))

        # Validate unsupported configurations up front with real exceptions
        # (asserts are stripped under `python -O`).
        # TODO: add tool-call support for budget forcing.
        if tool_calls:
            raise ValueError("tool_calls is not supported with budget forcing")
        # TODO: support additional backends.
        if not isinstance(backend, OllamaModelBackend):
            raise TypeError("Only ollama backend supported with budget forcing")

        loop_count = 0
        loop_budget_range_iterator = (
            tqdm.tqdm(range(self.loop_budget))  # type: ignore
            if show_progress
            else range(self.loop_budget)  # type: ignore
        )

        next_action = deepcopy(action)
        next_context = context
        for _ in loop_budget_range_iterator:  # type: ignore
            loop_count += 1
            if not show_progress:
                flog.info(f"Running loop {loop_count} of {self.loop_budget}")

            # run a generation pass with budget forcing
            # NOTE(review): generation always uses the original `context`, not the
            # `next_context` produced by repair() — confirm this is intended.
            result = think_budget_forcing(
                backend,
                next_action,
                ctx=context,
                format=format,
                tool_calls=tool_calls,
                think_max_tokens=self.think_max_tokens,
                answer_max_tokens=self.answer_max_tokens,
                start_think_token=self.start_think_token,
                end_think_token=self.end_think_token,
                think_more_suffix=self.think_more_suffix,
                answer_suffix=self.answer_suffix,
                model_options=model_options,
            )
            result_ctx = next_context

            # validation pass
            val_scores_co = mfuncs.avalidate(
                reqs=reqs,
                context=result_ctx,
                backend=backend,
                output=result,
                format=format,
                model_options=model_options,
                # tool_calls=tool_calls # Don't support using tool calls in validation strategies.
            )
            val_scores = await val_scores_co

            # match up reqs with scores
            constraint_scores = list(zip(reqs, val_scores))

            # collect all data
            sampled_results.append(result)
            sampled_scores.append(constraint_scores)
            sampled_actions.append(next_action)
            sample_contexts.append(result_ctx)

            # if all vals are true -- break and return success
            if all(bool(s[1]) for s in constraint_scores):
                flog.info("SUCCESS")
                assert (
                    result._generate_log is not None
                )  # Cannot be None after generation.
                result._generate_log.is_final_result = True

                # SUCCESS !!!!
                return SamplingResult(
                    result_index=len(sampled_results) - 1,
                    success=True,
                    sample_generations=sampled_results,
                    sample_validations=sampled_scores,
                    sample_contexts=sample_contexts,
                    sample_actions=sampled_actions,
                )

            else:
                # log partial success and continue
                count_valid = len([s for s in constraint_scores if bool(s[1])])
                flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}")

            # If we did not pass all constraints, update the instruction and try again.
            next_action, next_context = self.repair(
                next_context,
                result_ctx,
                sampled_actions,
                sampled_results,
                sampled_scores,
            )

        flog.info(
            f"Invoking select_from_failure after {len(sampled_results)} failed attempts."
        )

        # if no valid result could be determined, find a last resort.
        best_failed_index = self.select_from_failure(
            sampled_actions, sampled_results, sampled_scores
        )
        assert best_failed_index < len(sampled_results), (
            "The select_from_failure method did not return a valid result. It has to selected from failed_results."
        )

        assert (
            sampled_results[best_failed_index]._generate_log is not None
        )  # Cannot be None after generation.
        sampled_results[best_failed_index]._generate_log.is_final_result = True  # type: ignore

        return SamplingResult(
            result_index=best_failed_index,
            success=False,
            sample_generations=sampled_results,
            sample_validations=sampled_scores,
            sample_actions=sampled_actions,
            sample_contexts=sample_contexts,
        )

0 commit comments

Comments
 (0)