Commits
35 commits
bd503cf
Initial commit - think budget-forcing - tests run - WIP
Aug 28, 2025
9385af8
adds zero-think case
Aug 28, 2025
fda0768
resolved type checking errors
Aug 29, 2025
ff03b6f
fixes typo and some scripts
Aug 29, 2025
d73c1ac
Merge branch 'main' into think_bf
yelkurdi Aug 29, 2025
1df37d5
Merge branch 'main' into think_bf
yelkurdi Sep 3, 2025
e013f92
backend interface using _raw_generate
Sep 5, 2025
556634b
Bump version number from 0.0.2 to 0.0.3 (#117)
nrfulton Sep 3, 2025
6b6599d
ci: Rename .mergify.yml to mergify.yml (#119)
avinash2692 Sep 3, 2025
396bf7a
docs: fix typo on README (#116)
mdevino Sep 4, 2025
cad893f
refactor: Full refactor of the Decompose CLI Tool & introduction of p…
tuliocoppola Sep 4, 2025
75e3d0e
moved the budget forcing function into mellea/stdlib/sampling_algos/b…
Sep 7, 2025
fd7a3b3
adds budget forcing fn
Sep 7, 2025
8f1a820
Merge branch 'main' into think_bf
yelkurdi Sep 7, 2025
599eac1
feat: adds think budget forcing - relocated test dir
Sep 7, 2025
8098128
Update budget_forcing.py
yelkurdi Sep 15, 2025
56a828a
Merge branch 'main' into think_bf
nrfulton Sep 19, 2025
3535b65
Merge branch 'main' into think_bf
yelkurdi Oct 6, 2025
ad076c5
merging main in-progress
yelkurdi Oct 14, 2025
66ae952
Merge branch 'main' into think_bf
yelkurdi Oct 14, 2025
05c8185
main branch updates
yelkurdi Oct 16, 2025
80e8485
updates to think_budget_forcing function to match sampling strategy i…
yelkurdi Oct 16, 2025
7f2c8f1
adds sampling strategy for budget forcing
yelkurdi Oct 16, 2025
2493ca1
minor fixes
yelkurdi Oct 17, 2025
dbadd21
feat: ollama generate_from_raw uses existing event loop
jakelorocco Oct 17, 2025
4396f81
Merge branch 'main' into think_bf
yelkurdi Oct 17, 2025
f4dc004
fix: add blocking prevention mech
jakelorocco Oct 20, 2025
c143ce4
Merge branch 'main' into jal/ollama-generate-from-raw
jakelorocco Oct 20, 2025
99b3156
Merge branch 'jal/ollama-generate-from-raw' into think_bf
yelkurdi Oct 20, 2025
8d91627
fixes of async inconsistencies and incorporating Jacob's branch
yelkurdi Oct 20, 2025
d0c9e41
Merge branch 'main' into think_bf
yelkurdi Nov 4, 2025
8796661
updates interface significantly after prompting `_generate_from_raw` …
yelkurdi Nov 6, 2025
5664a8d
minor fix to test case
yelkurdi Nov 6, 2025
d83fb84
minor updates
yelkurdi Nov 6, 2025
1a999b9
Merge branch 'main' into think_bf
yelkurdi Nov 6, 2025
164 changes: 163 additions & 1 deletion mellea/backends/openai.py
@@ -4,6 +4,7 @@
import datetime
import inspect
import json
import re
from collections.abc import Callable
from enum import Enum
from typing import TYPE_CHECKING, Any
@@ -13,7 +14,7 @@
import requests
from huggingface_hub import snapshot_download
from openai.types.chat import ChatCompletion
from openai.types.completion import Completion
from openai.types.completion import Completion, CompletionUsage

import mellea.backends.model_ids as model_ids
from mellea.backends import BaseModelSubclass
@@ -467,6 +468,167 @@ def _generate_from_chat_context_standard(

    return parsed_result

def generate_with_budget_forcing(
    self,
    action: CBlock,
    *,
    think_max_tokens: int = 4096,
    answer_max_tokens: int | None = None,
    start_think_token: str = "<think>",
    end_think_token: str = "</think>",
    begin_response_token: str = "",
    end_response_token: str = "",
    think_wait_suffix: str = "",
    answer_suffix: str = "The final answer is:",
    answer_regex: str = "boxed",
    model_options: dict | None = None,
) -> tuple[str, int]:
    """Generate with budget forcing using the completions API.

    This relies on raw autocompletion and assumes the model's output is structured as
    '<think> ... </think> summary answer'. The budget forcing method is proposed in
    https://arxiv.org/abs/2501.19393; this implementation follows the key outline of the
    paper while ensuring stable, fail-safe operation. Generation is performed in multiple
    steps: the model is called repeatedly until the budget requirements are met, and the
    response is assembled from the partial generations.

    Args:
        action: Block holding the prompt, with the chat template already applied.
        think_max_tokens: Token budget allocated to the think block.
        answer_max_tokens: Token budget allocated to the summary and answer block; None generates until EOS.
        start_think_token: String marking the start of the think block, default "<think>".
        end_think_token: String marking the end of the think block, default "</think>".
        begin_response_token: Used by certain models, string marking the start of the response block, e.g. "<response>"; default "".
        end_response_token: Used by certain models, string marking the end of the response block, e.g. "</response>"; default "".
        think_wait_suffix: String appended to force continued thinking, e.g. "\nWait". If empty, additional thinking is not forced (upper-bound budget case).
        answer_suffix: String appended to force a final answer.
        answer_regex: Regex whose match indicates that an answer has been generated.
        model_options: Model options to merge with the backend defaults for this call.

    Returns:
        The assembled response text and the number of generated tokens.

    Assumptions:
        - The chat template has been applied to the prompt, with think mode enabled.
        - The model has think mode activated.
        - Enabling prefix caching improves performance.

    Limitations:
        - Does not support batching.
    """

    model_opts = self._simplify_and_merge(model_options, is_chat_context=False)

    responses = []
    prompt = self.formatter.print(action)
    if start_think_token:
        prompt += start_think_token
        responses.append(start_think_token)

    backend_opts = self._make_backend_specific_and_remove(
        model_opts, is_chat_context=False
    )
    # Generate thinking portion
    # backend_opts["echo"] = True
    # backend_opts["logprobs"] = 1
    backend_opts["n"] = 1
    rem_toks = think_max_tokens
    gen_tok_count = 0
    curr_prompt = prompt
    min_step_len = 10  # minimum character length of a step to be considered valid

    # Think block: indefinite multi-step generation to satisfy the user's budget.
    while True:
        if rem_toks <= 0:  # zero-think case
            break

        if rem_toks <= min_step_len:  # minimum step length reached
            break

        backend_opts["max_tokens"] = rem_toks
        try:
            completion_response = self._client.completions.create(
                model=self._hf_model_id, prompt=curr_prompt, **backend_opts
            )  # type: ignore
        except openai.BadRequestError as e:
            if openai_ollama_batching_error in e.message:
                FancyLogger.get_logger().error(
                    "If you are trying to call `OpenAIBackend.generate_with_budget_forcing` while targeting an ollama server, "
                    "your requests will fail since ollama doesn't support batching requests."
                )
            raise e

        # Necessary for type checker.
        assert isinstance(completion_response.usage, CompletionUsage)
        gen_tok_count += completion_response.usage.completion_tokens
        rem_toks = think_max_tokens - gen_tok_count
        response = completion_response.choices[0].text

        if think_wait_suffix == "":
            # non-strict budget form
            responses.append(response)
            break

        if rem_toks <= 0:
            responses.append(response)
            break

        else:
            if end_think_token:
                step = response.split(end_think_token)[0]
                # model fails to produce thoughts, let's exit
                if len(step.strip()) <= min_step_len:
                    responses.append(response)
                    break

                # request more steps
                step = f"{step} {think_wait_suffix}"
                responses.append(step)
                curr_prompt += step

    response = "".join(responses)
    if answer_regex is None or answer_suffix is None:
        return response, gen_tok_count

    # Now get a final answer if we need to.
    # TODO: Here we check if a final answer exists; technically we should check for an answer outside
    # the think block, but we use the relaxed requirement of finding any answer in the model's response.
    # Consider a strict structural approach in the future, e.g.:
    # answer_blk = response.split(end_think_token)[-1]

    # Check if an answer is already in the response
    matches = re.findall(answer_regex, response, re.DOTALL)
    if len(matches) > 0:
        return response, gen_tok_count

    # Answer is not in the response, let's force an answer
    if end_think_token and end_think_token not in response:
        response += f" {end_think_token}"

    if begin_response_token and begin_response_token not in response:
        response += f" {begin_response_token}"

    if answer_suffix:
        response += f" {answer_suffix}"

    # update the original prompt with the assembled response
    prompt += response
    if answer_max_tokens is not None:
        backend_opts["max_tokens"] = answer_max_tokens

    else:
        backend_opts.pop("max_tokens", None)  # generate unconditionally

    try:
        completion_response = self._client.completions.create(
            model=self._hf_model_id, prompt=prompt, **backend_opts
        )  # type: ignore
    except openai.BadRequestError as e:
        if openai_ollama_batching_error in e.message:
            FancyLogger.get_logger().error(
                "If you are trying to call `OpenAIBackend.generate_with_budget_forcing` while targeting an ollama server, "
                "your requests will fail since ollama doesn't support batching requests."
            )
        raise e

    # Necessary for type checker.
    assert isinstance(completion_response.usage, CompletionUsage)
    response += completion_response.choices[0].text
    gen_tok_count += completion_response.usage.completion_tokens
    return response, gen_tok_count

def _generate_from_raw(
    self,
    actions: list[Component | CBlock],
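A minimal usage sketch for the new `generate_with_budget_forcing` method. The import paths, the `OpenAIBackend` constructor arguments, and the chat-template markers in the prompt are illustrative assumptions, not taken from this diff; only the keyword arguments mirror the signature above.

```python
# Hypothetical usage sketch -- import paths, constructor arguments, and the
# prompt template below are assumptions, not part of this PR.
from mellea.backends.openai import OpenAIBackend
from mellea.stdlib.base import CBlock

# Assumes a vLLM server exposing the OpenAI completions API, e.g. the one
# launched by serve.sh in this PR.
backend = OpenAIBackend(
    model_id="ibm-granite/granite-4.0-tiny-preview",
    base_url="http://localhost:8000/v1",
    api_key="EMPTY",
)

# The prompt must already have the chat template applied with think mode
# enabled; the method completes raw text rather than building a chat request.
prompt = "<|user|>\nWhat is 17 * 23?\n<|assistant|>\n"

response, gen_tok_count = backend.generate_with_budget_forcing(
    CBlock(prompt),
    think_max_tokens=512,        # budget for the <think> ... </think> block
    answer_max_tokens=256,       # budget for the summary/answer block
    think_wait_suffix="\nWait",  # non-empty string forces extra thinking steps
    answer_suffix="The final answer is:",
    answer_regex="boxed",
)
print(f"generated {gen_tok_count} tokens")
print(response)
```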
2 changes: 2 additions & 0 deletions test/backends/test_think_budget_forcing/.gitignore
@@ -0,0 +1,2 @@
vllm.err
vllm.log
23 changes: 23 additions & 0 deletions test/backends/test_think_budget_forcing/README.md
@@ -0,0 +1,23 @@

# Test for OpenAI API served by vLLM

## Requirements

anaconda / miniconda / miniforge.

Make sure to run the test with multiple CPU cores available (e.g. in a cloud instance or a cluster job).
With only one core, vLLM can get stuck in a deadlock.

## Installation

This needs to be done only once. It creates a new conda environment named "mellea_tbf" used only for testing or contributing to the think budget-forcing feature.

Run `./install.sh`

## Testing

``` shell
./run_test.sh
```
7 changes: 7 additions & 0 deletions test/backends/test_think_budget_forcing/environment.yml
@@ -0,0 +1,7 @@

name: mellea_tbf
channels:
- conda-forge
dependencies:
- python=3.12 # note: at the time of writing, xformers (a vllm dependency) has a broken wheel for 3.13. https://github.com/facebookresearch/xformers/issues/740#issuecomment-2753869337
- uv
18 changes: 18 additions & 0 deletions test/backends/test_think_budget_forcing/install.sh
@@ -0,0 +1,18 @@
#!/bin/bash -xe

ENV_NAME=mellea_tbf
conda env remove -y -n $ENV_NAME || true
conda env create -f $(readlink -ef $(dirname $0))/environment.yml

in-conda (){
conda run -n $ENV_NAME "$@"
}


cd ../../../
in-conda uv pip install -e .
cd -
in-conda uv pip install pre-commit
in-conda uv pip install pytest
in-conda uv pip install vllm==0.10.0
in-conda uv pip install outlines
24 changes: 24 additions & 0 deletions test/backends/test_think_budget_forcing/run_test.sh
@@ -0,0 +1,24 @@
#!/bin/bash

ENV_NAME=mellea_tbf
eval "$(conda shell.bash hook)"
conda activate $ENV_NAME

dir=$(readlink -ef $(dirname $0))
rm -f $dir/vllm.log $dir/vllm.err

bash $dir/serve.sh &
vllm_pid=$!

trap "kill -SIGINT $vllm_pid ; wait" EXIT

while sleep 1 ; do
if grep -q "Application startup complete." $dir/vllm.err
then
break
fi
done

python test_think_budget_forcing.py


37 changes: 37 additions & 0 deletions test/backends/test_think_budget_forcing/serve.sh
@@ -0,0 +1,37 @@
#!/bin/bash

# @Masa note:
# the following code is a bash snippet Kristian gave me
# for how to run vllm with lora adapter loaded.

# HF_GRANITE_ALORA_SNAPSHOT=${HF_HOME:-$HOME/.cache/huggingface}
# HF_GRANITE_ALORA_SNAPSHOT+=/hub/
# HF_GRANITE_ALORA_SNAPSHOT+=models--ibm-granite--granite-3.2-8b-alora-requirement-check/
# HF_GRANITE_ALORA_SNAPSHOT+=snapshots/d55a7a7f5796609bc938c5c151a864cfcc6ab54e

# vllm serve ibm-granite/granite-3.2-8b-instruct \
# --enable-lora \
# --lora-modules "{\"name\": \"ibm-granite/granite-3.2-8b-alora-requirement-check\", \"path\": \"${HF_GRANITE_ALORA_SNAPSHOT}\", \"base_model_name\": \"ibm-granite/granite-3.2-8b-instruct\"}" \
# --dtype bfloat16 \
# --max-lora-rank 64 \
# --enable-prefix-caching

# However, in our test, we do not load the alora when we serve.
# In this test, we use the dynamic loading interface from
# https://docs.vllm.ai/en/stable/features/lora.html#dynamically-serving-lora-adapters

# Using this feature requires the following environment variable.
# If you use conda/miniforge,
# this variable must have been set already when you set up the environment.
# see environment.yml.
export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True

echo "launching a vllm server. Logs are found in $(readlink -ef $(dirname $0))/vllm.log"
# At the time of writing this code, Granite 4.0 vLLM serving did not support prefix-caching
# --enable-prefix-caching \
vllm serve ibm-granite/granite-4.0-tiny-preview \
--dtype bfloat16 \
> $(readlink -ef $(dirname $0))/vllm.log \
2> $(readlink -ef $(dirname $0))/vllm.err

