Commits (35)
- `bd503cf` Initial commit - think budget-forcing - tests run - WIP (Aug 28, 2025)
- `9385af8` adds zero-think case (Aug 28, 2025)
- `fda0768` resolved type checking errors (Aug 29, 2025)
- `ff03b6f` fixes typo and some scripts (Aug 29, 2025)
- `d73c1ac` Merge branch 'main' into think_bf (yelkurdi, Aug 29, 2025)
- `1df37d5` Merge branch 'main' into think_bf (yelkurdi, Sep 3, 2025)
- `e013f92` backend interface using _raw_generate (Sep 5, 2025)
- `556634b` Bump version number from 0.0.2 to 0.0.3 (#117) (nrfulton, Sep 3, 2025)
- `6b6599d` ci: Rename .mergify.yml to mergify.yml (#119) (avinash2692, Sep 3, 2025)
- `396bf7a` docs: fix typo on README (#116) (mdevino, Sep 4, 2025)
- `cad893f` refactor: Full refactor of the Decompose CLI Tool & introduction of p… (tuliocoppola, Sep 4, 2025)
- `75e3d0e` moved the budget forcing function into mellea/stdlib/sampling_algos/b… (Sep 7, 2025)
- `fd7a3b3` adds budget forcing fn (Sep 7, 2025)
- `8f1a820` Merge branch 'main' into think_bf (yelkurdi, Sep 7, 2025)
- `599eac1` feat: adds think budget forcing - relocated test dir (Sep 7, 2025)
- `8098128` Update budget_forcing.py (yelkurdi, Sep 15, 2025)
- `56a828a` Merge branch 'main' into think_bf (nrfulton, Sep 19, 2025)
- `3535b65` Merge branch 'main' into think_bf (yelkurdi, Oct 6, 2025)
- `ad076c5` merging main in-progress (yelkurdi, Oct 14, 2025)
- `66ae952` Merge branch 'main' into think_bf (yelkurdi, Oct 14, 2025)
- `05c8185` main branch updates (yelkurdi, Oct 16, 2025)
- `80e8485` updates to think_budget_forcing function to match sampling strategy i… (yelkurdi, Oct 16, 2025)
- `7f2c8f1` adds sampling strategy for budget forcing (yelkurdi, Oct 16, 2025)
- `2493ca1` minor fixes (yelkurdi, Oct 17, 2025)
- `dbadd21` feat: ollama generate_from_raw uses existing event loop (jakelorocco, Oct 17, 2025)
- `4396f81` Merge branch 'main' into think_bf (yelkurdi, Oct 17, 2025)
- `f4dc004` fix: add blocking prevention mech (jakelorocco, Oct 20, 2025)
- `c143ce4` Merge branch 'main' into jal/ollama-generate-from-raw (jakelorocco, Oct 20, 2025)
- `99b3156` Merge branch 'jal/ollama-generate-from-raw' into think_bf (yelkurdi, Oct 20, 2025)
- `8d91627` fixes of async inconsistencies and incorporating Jacob's branch (yelkurdi, Oct 20, 2025)
- `d0c9e41` Merge branch 'main' into think_bf (yelkurdi, Nov 4, 2025)
- `8796661` updates interface significantly after prompting `_generate_from_raw` … (yelkurdi, Nov 6, 2025)
- `5664a8d` minor fix to test case (yelkurdi, Nov 6, 2025)
- `d83fb84` minor updates (yelkurdi, Nov 6, 2025)
- `1a999b9` Merge branch 'main' into think_bf (yelkurdi, Nov 6, 2025)
148 changes: 148 additions & 0 deletions mellea/stdlib/sampling_algos/budget_forcing.py
@@ -0,0 +1,148 @@
import re

from mellea.stdlib.base import (
    CBlock,
    Component,
    GenerateLog,
    ModelOutputThunk,
)
from mellea.stdlib.session import MelleaSession


def think_budget_forcing(
    session: MelleaSession,
    action: CBlock | Component,
    *,
    think_max_tokens: int = 4096,
    answer_max_tokens: int | None = None,
    start_think_token: str = "<think>",
    end_think_token: str = "</think>",
    begin_response_token: str = "",
    end_response_token: str = "",
    think_wait_suffix: str = "",
    answer_suffix: str = "The final answer is:",
    answer_regex: str = r"\\boxed{.*?}",
    model_options: dict | None = None,
    generate_logs: list[GenerateLog] | None = None,
) -> tuple[str, int]:
"""Generate with budget forcing using the completions APIs. This relies on raw autocompletion and assumes the model's output is structured in the following form: '<think> ... </think> summary answer'
The budget forcing method is proposed in the paper: https://arxiv.org/abs/2501.19393
This implementation tries to follow the key outlines in the paper while ensuring stable and fail-safe operation.
This is performed via multi-step generation. The model will be called multiple times until requirements are met, in other words, the response will be assembled conditionally.

Args:
think_max_tokens: Budget in number of tokens allocated for the think block
answer_max_tokens: Budget in number of tokens allocated for the summary and answer block, None indicates generating till EoS
start_think_token: String indicating start of think block, default <think>
end_think_token: String indicating end of think block, default </think>
begin_response_token: Used by certain models, string indicating start of response block, e.g. "<response>", default None
end_response_token: Used by certain models, string indicating end of response block, e.g. "</response>", default None
think_wait_suffix: String to append to force continued thinking, e.g. "\nWait" if set to None we will not force additional thinking. Use None for upper-bound budget case
answer_suffix: String to append to force a final answer
answer_regex: Answer regex which indicates an answer is generated

Assumptions:
- The chat template is applied on prompt, with think mode enabled
- Model is think mode activated
- enabling prefix-caching improves performance

Limitations:
- Does not support batching
"""

    backend = session.backend
    model_options = backend._simplify_and_merge(model_options, is_chat_context=False)

    responses = []
    prompt = backend.formatter.print(action)
    if start_think_token:
        prompt += start_think_token
        responses.append(start_think_token)

    # Generate thinking portion
    # model_options["echo"] = True
    # model_options["logprobs"] = 1
    model_options["n"] = 1
    rem_toks = think_max_tokens
    gen_tok_count = 0
    curr_prompt = prompt
    min_step_len = 10  # minimum character length of a step to be considered valid

    # Think block: indefinite multi-step operation to satisfy the user's budget
    while True:
        if rem_toks <= 0:  # zero-think case
            break

        if rem_toks <= min_step_len:  # minimum step length reached
            break

        model_options["max_tokens"] = rem_toks
        # TODO: workaround to obtain generated token counts.
        # The token count should be relayed by OpenAI's CompletionUsage.
        model_options["logprobs"] = 1  # To get the number of generated tokens
        result = backend._generate_from_raw(
            [curr_prompt], model_options=model_options, generate_logs=generate_logs
        )
        gen_tok_count += len(result[0]._meta['oai_completion_response']['logprobs']['token_logprobs'])
        rem_toks = think_max_tokens - gen_tok_count
        response = result[0].value

        if think_wait_suffix == "":
            # non-strict budget form
            responses.append(response)
            break

        if rem_toks <= 0:
            responses.append(response)
            break

        else:
            # Keep only the content before the end-of-think marker (if one was emitted)
            step = response.split(end_think_token)[0] if end_think_token else response
            # model fails to produce thoughts, let's exit
            if len(step.strip()) <= min_step_len:
                responses.append(response)
                break

            # request more steps
            step = f"{step} {think_wait_suffix}"
            responses.append(step)
            curr_prompt += step

response = "".join(responses)
if answer_regex is None or answer_suffix is None:
return response, gen_tok_count

# Now get a final answer if we need to
# TODO: Here we check if a final answer exists, technically we should check for an answer outside
# The think block, but we will use relaxed requirement of finding any answer in the model's response.
# Consider a strict structural approach in the future.
# e.g.
# answer_blk = response.split(end_think_token)[-1]

# Check if answer in response
matches = re.findall(answer_regex, response, re.DOTALL)
if len(matches) > 0:
return response, gen_tok_count

# Answer is not in response, let's force an answer
if end_think_token and end_think_token not in response:
response += f" {end_think_token}"

if begin_response_token and begin_response_token not in response:
response += f" {begin_response_token}"

if answer_suffix:
response += f" {answer_suffix}"

# update original prompt with assembled response
prompt += response
if answer_max_tokens is not None:
model_options["max_tokens"] = answer_max_tokens

else:
model_options.pop("max_tokens", None) # generate unconditionally

model_options["logprobs"] = 1 # To get number of generated tokens
result = backend._generate_from_raw([prompt], model_options=model_options, generate_logs=generate_logs)
response += result[0].value
gen_tok_count += len(result[0]._meta['oai_completion_response']['logprobs']['token_logprobs'])
return response, gen_tok_count
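For context, here is a minimal usage sketch of the new `think_budget_forcing` helper (not part of this diff). It assumes a Mellea session whose backend supports raw completions, such as the vLLM-served OpenAI-compatible backend used by the tests below; the `mellea.start_session()` entry point, the model configuration, and the prompt text are illustrative assumptions.

```python
# Usage sketch only, not part of this PR. Assumes a session whose backend
# implements _generate_from_raw (e.g. an OpenAI-compatible completions backend
# served by vLLM, as in the tests below).
import mellea
from mellea.stdlib.base import CBlock, GenerateLog
from mellea.stdlib.sampling_algos.budget_forcing import think_budget_forcing

m = mellea.start_session()  # assumed entry point; pick a completions-capable backend
logs: list[GenerateLog] = []

# Cap the think block at ~1024 tokens. With the default (empty) think_wait_suffix,
# thinking is not extended; if no \boxed{...} answer is found, the think block is
# closed and a final answer is forced with answer_suffix.
response, n_tokens = think_budget_forcing(
    m,
    CBlock("What is 17 * 23? Put the final answer in \\boxed{}."),
    think_max_tokens=1024,
    answer_max_tokens=256,
    generate_logs=logs,
)
print(n_tokens, response)
```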

2 changes: 2 additions & 0 deletions test/stdlib_basics/test_think_budget_forcing/.gitignore
@@ -0,0 +1,2 @@
vllm.err
vllm.log
23 changes: 23 additions & 0 deletions test/stdlib_basics/test_think_budget_forcing/README.md
@@ -0,0 +1,23 @@

# Test for the OpenAI API served by vLLM

## Requirements

anaconda / miniconda / miniforge.

Make sure to run the test with multiple cores available (e.g., on a cloud instance or in a cluster job). Even if one core seems sufficient, vLLM can deadlock when only a single core is available.

## Installation

Install by running `./install.sh`; this only needs to be done once. The script creates a new conda environment named "mellea_tbf" dedicated to testing and contributing to the think budget-forcing feature.

## Testing

``` shell
./run_test.sh
```
7 changes: 7 additions & 0 deletions test/stdlib_basics/test_think_budget_forcing/environment.yml
@@ -0,0 +1,7 @@

name: mellea_tbf
channels:
  - conda-forge
dependencies:
  - python=3.12  # note: at the time of writing, xformers (a vllm dependency) has a broken wheel for Python 3.13. https://github.com/facebookresearch/xformers/issues/740#issuecomment-2753869337
  - uv
10 changes: 10 additions & 0 deletions test/stdlib_basics/test_think_budget_forcing/exec_sampling_test.sh
@@ -0,0 +1,10 @@
#!/bin/bash

source set_variables.sh

eval "$(conda shell.bash hook)"
conda activate $ENV_NAME

export LOCAL_TEST_MODEL

python test_think_budget_forcing.py
22 changes: 22 additions & 0 deletions test/stdlib_basics/test_think_budget_forcing/install.sh
@@ -0,0 +1,22 @@
#!/bin/bash -xe

source set_variables.sh

conda env remove -y -n $ENV_NAME || true
conda env create -f $(readlink -f $(dirname $0))/environment.yml

in-conda () {
    conda run -n $ENV_NAME "$@"
}


cd ../../../
in-conda uv pip install -e .
cd -
in-conda uv pip install pre-commit
in-conda uv pip install pytest
in-conda uv pip install vllm==0.10.0
in-conda uv pip install outlines
# in-conda uv pip install unsloth
in-conda uv pip install ipdb

24 changes: 24 additions & 0 deletions test/stdlib_basics/test_think_budget_forcing/run_test.sh
@@ -0,0 +1,24 @@
#!/bin/bash

source set_variables.sh

eval "$(conda shell.bash hook)"
conda activate $ENV_NAME

rm -f $VLLM_LOG $VLLM_ERR

bash ./serve.sh &
VLLM_PID=$!

trap "kill -SIGINT $VLLM_PID ; wait" EXIT

while sleep 1 ; do
if grep -q "Application startup complete." $VLLM_ERR
then
break
fi
done

bash exec_sampling_test.sh


16 changes: 16 additions & 0 deletions test/stdlib_basics/test_think_budget_forcing/serve.sh
@@ -0,0 +1,16 @@
#!/bin/bash

source set_variables.sh
eval "$(conda shell.bash hook)"
conda activate $ENV_NAME
export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True

echo "launching a vllm server. Logs are found in $(readlink -ef $(dirname $0))/vllm.log"
# At the time of writing this code, Granite 4.4 vLLM serving did not support prefix-caching
# --enable-prefix-caching \
vllm serve $LOCAL_TEST_MODEL \
--dtype bfloat16 \
> $VLLM_LOG \
2> $VLLM_ERR


8 changes: 8 additions & 0 deletions test/stdlib_basics/test_think_budget_forcing/set_variables.sh
@@ -0,0 +1,8 @@
#!/bin/bash

PYTHONBREAKPOINT="ipdb.set_trace"
LOCAL_TEST_MODEL="ibm-granite/granite-4.0-tiny-preview"
ENV_NAME=mellea_tbf
DIR=$(readlink -ef $(dirname $0))
VLLM_LOG=$DIR/vllm.log
VLLM_ERR=$DIR/vllm.err