
Commit bd503cf

Author: Yousef El-Kurdi (committed)
Initial commit - think budget-forcing - tests run - WIP
1 parent d92a44f · commit bd503cf

File tree

8 files changed: +344 −0 lines changed

mellea/backends/openai.py

Lines changed: 149 additions & 0 deletions

@@ -467,6 +467,155 @@ def _generate_from_chat_context_standard(

        return parsed_result

    def generate_with_budget_forcing(
        self,
        action: CBlock,
        *,
        think_max_tokens: int = 3072,
        answer_max_tokens: int | None = None,
        start_think_token: str | None = "<think>",
        end_think_token: str | None = "</think>",
        begin_response_token: str | None = None,
        end_response_token: str | None = None,
        think_wait_suffix: str | None = None,
        answer_suffix: str | None = "The final answer is:",
        answer_token: str | None = "boxed",
        model_options: dict | None = None,
    ) -> tuple[str, int]:
        """Generate with budget forcing using the completions API.

        This relies on raw autocompletion and assumes the model's output is structured in the
        following form: '<think> ... </think> summary answer'.
        The budget-forcing method is proposed in the paper: https://arxiv.org/abs/2501.19393
        This implementation tries to follow the key outlines of the paper while ensuring stable
        and fail-safe operation. It is performed via multi-step generation: the model is called
        multiple times until the requirements are met, i.e. the response is assembled conditionally.

        Args:
            think_max_tokens: Budget in number of tokens allocated for the think block.
            answer_max_tokens: Budget in number of tokens allocated for the summary and answer
                block; None indicates generating until EoS.
            start_think_token: String indicating the start of the think block, default "<think>".
            end_think_token: String indicating the end of the think block, default "</think>".
            begin_response_token: Used by certain models, string indicating the start of the
                response block, e.g. "<response>", default None.
            end_response_token: Used by certain models, string indicating the end of the
                response block, e.g. "</response>", default None.
            think_wait_suffix: String appended to force continued thinking, e.g. "\nWait".
                If set to None, additional thinking is not forced (upper-bound budget case).
            answer_suffix: String appended to force a final answer.
            answer_token: Token that indicates an answer has been generated.

        Returns:
            A tuple of the assembled response text and the total number of completion tokens generated.

        Assumptions:
            - The chat template is applied to the prompt, with think mode enabled.
            - The model has think mode activated.
            - Enabling prefix caching improves performance.

        Limitations:
            - Does not support batching.
        """
        model_opts = self._simplify_and_merge(model_options, is_chat_context=False)

        responses = []
        prompt = self.formatter.print(action)
        if start_think_token is not None:
            prompt += start_think_token
            responses.append(start_think_token)
        backend_opts = self._make_backend_specific_and_remove(
            model_opts, is_chat_context=False
        )
        # Generate the thinking portion.
        max_tok_thd = 0.8
        backend_opts["max_tokens"] = think_max_tokens
        # backend_opts["echo"] = True
        # backend_opts["logprobs"] = 1
        backend_opts["n"] = 1
        gen_tok_count = 0
        curr_prompt = prompt
        min_step_len = 10  # minimum character length of a step to be considered valid

        # Think block: indefinite multi-step operation to satisfy the user's budget.
        while True:
            try:
                completion_response: Completion = self._client.completions.create(
                    model=self._hf_model_id, prompt=curr_prompt, **backend_opts
                )  # type: ignore
            except openai.BadRequestError as e:
                if openai_ollama_batching_error in e.message:
                    FancyLogger.get_logger().error(
                        "If you are trying to call `OpenAIBackend._generate_from_raw` while targeting an ollama server, "
                        "your requests will fail since ollama doesn't support batching requests."
                    )
                raise e

            gen_tok_count += completion_response.usage.completion_tokens
            response = completion_response.choices[0].text
            if think_wait_suffix is None:
                responses.append(response)
                break

            if gen_tok_count >= max_tok_thd * think_max_tokens:
                responses.append(response)
                break
            else:
                step = response.split(end_think_token)[0]
                # The model failed to produce thoughts; exit.
                if len(step.strip()) <= min_step_len:
                    responses.append(response)
                    break

                # Request more steps.
                step = f"{step} {think_wait_suffix}"
                responses.append(step)
                curr_prompt += step

        response = "".join(responses)
        ### debug obtaining final answer
        # response = response.split(end_think_token)[0]
        # response = response.replace(answer_token, "")
        ###
        if answer_token is None or answer_suffix is None:
            return response, gen_tok_count

        # Now get a final answer if we need to.
        # TODO: Here we check whether a final answer exists; technically we should check for an
        # answer outside the think block, but we use the relaxed requirement of finding any answer
        # in the model's response. Consider a strict structural approach in the future, e.g.:
        # ans_portion = response.split(end_think_token)[-1]
        # if answer_token in ans_portion:
        #     return response, gen_tok_count

        if answer_token in response:
            return response, gen_tok_count

        # The answer is not in the response, so force an answer.
        # Guard against begin_response_token=None leaking the literal string "None" into the prompt.
        brt = begin_response_token if begin_response_token is not None else ""
        if end_think_token not in response:
            response = f"{response} {end_think_token}{brt} {answer_suffix}"
        else:
            response = f"{response} {brt}{answer_suffix}"

        # Update the original prompt with the assembled response.
        prompt += response
        if answer_max_tokens is not None:
            backend_opts["max_tokens"] = answer_max_tokens
        else:
            del backend_opts["max_tokens"]

        try:
            completion_response: Completion = self._client.completions.create(
                model=self._hf_model_id, prompt=prompt, **backend_opts
            )  # type: ignore
        except openai.BadRequestError as e:
            if openai_ollama_batching_error in e.message:
                FancyLogger.get_logger().error(
                    "If you are trying to call `OpenAIBackend._generate_from_raw` while targeting an ollama server, "
                    "your requests will fail since ollama doesn't support batching requests."
                )
            raise e

        response += completion_response.choices[0].text
        gen_tok_count += completion_response.usage.completion_tokens
        return response, gen_tok_count

    def _generate_from_raw(
        self,
        actions: list[Component | CBlock],
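
For reference, a minimal usage sketch of the new method (not part of this diff). It assumes a vLLM-served OpenAI-compatible endpoint at http://0.0.0.0:8000/v1, the Granite model used by the tests below, and a chat template that accepts a `thinking` flag; the budgets shown are illustrative.

``` python
# Hypothetical usage sketch for generate_with_budget_forcing (not part of the diff).
# The server URL, model id, budgets, and chat-template kwargs are assumptions.
from transformers import AutoTokenizer

from mellea.backends.openai import OpenAIBackend
from mellea.stdlib.base import CBlock

model_id = "ibm-granite/granite-4.0-tiny-preview"
backend = OpenAIBackend(model_id=model_id, base_url="http://0.0.0.0:8000/v1", api_key="EMPTY")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Apply the chat template with thinking enabled, as the method's docstring assumes.
messages = [{"role": "user", "content": "What is 1+1? Put your final answer within \\boxed{}."}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, thinking=True, add_generation_prompt=True
)

response, gen_tok_count = backend.generate_with_budget_forcing(
    action=CBlock(value=prompt),
    think_max_tokens=512,      # budget for the <think> block
    answer_max_tokens=128,     # budget for the forced final answer
    think_wait_suffix="Wait",  # appended to keep the model thinking until ~80% of the budget is used
    answer_suffix="The final answer is:",
    answer_token="boxed",
)
print(gen_tok_count, response)
```

The method returns the assembled text and the total number of completion tokens consumed, which is what the tests at the end of this diff assert on.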
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
vllm.err
vllm.log
Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
# Test for OpenAI API served by vLLM

## Requirements

One of anaconda / miniconda / miniforge.

Make sure to run the test with multiple cores available (e.g. in a cloud instance / cluster job).
Although one core may seem sufficient, vLLM can get stuck in a deadlock if only one is available.

## Installation

Needs to be done only once.
It creates a new conda environment named "mellea_tbf" solely for the purposes of testing or contributing to the think budget-forcing feature.

Run `./install.sh`

## Testing

``` shell
./run_test.sh
```
Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
name: mellea_tbf
channels:
  - conda-forge
dependencies:
  - python=3.12  # note: at the time of writing, xformers (< vllm) has a broken wheel for 3.13. https://github.com/facebookresearch/xformers/issues/740#issuecomment-2753869337
  - uv
Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
#!/bin/bash -xe

ENV_NAME=mellea_tbf
conda env remove -y -n $ENV_NAME || true
conda env create -f $(readlink -ef $(dirname $0))/environment.yml

in-conda (){
    conda run -n $ENV_NAME "$@"
}

cd ../../../
in-conda uv pip install -e .
cd -
in-conda uv pip install pre-commit
in-conda uv pip install pytest
in-conda uv pip install vllm==0.10.0
in-conda uv pip install outlines
Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
#!/bin/bash

ENV_NAME=mellea_tbf
eval "$(conda shell.bash hook)"
conda activate $ENV_NAME

dir=$(readlink -ef $(dirname $0))
rm -f $dir/vllm.log $dir/vllm.err

bash $dir/serve.sh &
vllm_pid=$!

trap "kill -SIGINT $vllm_pid ; wait" EXIT

# Wait until the vLLM server reports that it is ready.
while sleep 1 ; do
    if grep -q "Application startup complete." $dir/vllm.err
    then
        break
    fi
done

python test_think_budget_forcing.py
Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
#!/bin/bash

# @Masa note:
# The following is a bash snippet Kristian gave me
# showing how to run vllm with a LoRA adapter loaded.

# HF_GRANITE_ALORA_SNAPSHOT=${HF_HOME:-$HOME/.cache/huggingface}
# HF_GRANITE_ALORA_SNAPSHOT+=/hub/
# HF_GRANITE_ALORA_SNAPSHOT+=models--ibm-granite--granite-3.2-8b-alora-requirement-check/
# HF_GRANITE_ALORA_SNAPSHOT+=snapshots/d55a7a7f5796609bc938c5c151a864cfcc6ab54e

# vllm serve ibm-granite/granite-3.2-8b-instruct \
#     --enable-lora \
#     --lora-modules "{\"name\": \"ibm-granite/granite-3.2-8b-alora-requirement-check\", \"path\": \"${HF_GRANITE_ALORA_SNAPSHOT}\", \"base_model_name\": \"ibm-granite/granite-3.2-8b-instruct\"}" \
#     --dtype bfloat16 \
#     --max-lora-rank 64 \
#     --enable-prefix-caching

# However, in our test, we do not load the aLoRA when we serve.
# In this test, we use the dynamic loading interface from
# https://docs.vllm.ai/en/stable/features/lora.html#dynamically-serving-lora-adapters

# Using this feature requires the following environment variable.
# If you use conda/miniforge,
# this variable must already have been set when you set up the environment.
# See environment.yml.
export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True

echo "launching a vllm server. Logs are found in $(readlink -ef $(dirname $0))/vllm.log"
# At the time of writing, Granite 4.0 vLLM serving did not support prefix caching.
# --enable-prefix-caching \
vllm serve ibm-granite/granite-4.0-tiny-preview \
    --dtype bfloat16 \
    > $(readlink -ef $(dirname $0))/vllm.log \
    2> $(readlink -ef $(dirname $0))/vllm.err
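
As a companion to the comments above, a hedged sketch of the dynamic-loading route (not part of this diff): the adapter name and path are placeholders, the endpoint shapes follow the vLLM docs linked in the script, and the server must be running with VLLM_ALLOW_RUNTIME_LORA_UPDATING=True and LoRA support enabled.

``` python
# Hedged sketch: dynamically load/unload a LoRA adapter on a running vLLM server.
# "my_adapter" and the path are placeholders; endpoints follow the vLLM docs linked above.
import requests

BASE_URL = "http://0.0.0.0:8000"

# Register an adapter at runtime; it then becomes addressable via the "model" field of requests.
resp = requests.post(
    f"{BASE_URL}/v1/load_lora_adapter",
    json={"lora_name": "my_adapter", "lora_path": "/path/to/local/lora/adapter"},
)
resp.raise_for_status()

# Remove it again when done.
requests.post(f"{BASE_URL}/v1/unload_lora_adapter", json={"lora_name": "my_adapter"})
```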
Lines changed: 84 additions & 0 deletions

@@ -0,0 +1,84 @@
from mellea import MelleaSession
from mellea.stdlib.base import CBlock, SimpleContext
from mellea.backends.openai import OpenAIBackend
from transformers import AutoTokenizer
import pytest
import os


class TestOpenAIBackend:
    model_id = "ibm-granite/granite-4.0-tiny-preview"
    backend = OpenAIBackend(
        model_id=model_id,
        base_url="http://0.0.0.0:8000/v1",
        api_key="EMPTY",
    )
    m = MelleaSession(backend, ctx=SimpleContext())
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    def prepare_prompt_for_math(self, query):
        # Prepare a prompt for math reasoning tasks.
        system_prompt = None  # Use the default from the chat template.
        prompt_suffix = "\nPlease reason step by step, use \n\n to end each step, and put your final answer within \\boxed{}."

        if prompt_suffix:
            query += prompt_suffix

        msg = []
        if system_prompt is not None:
            msg.append({"role": "system", "content": system_prompt})

        msg.append({"role": "user", "content": query})
        prompt = self.tokenizer.apply_chat_template(
            msg,
            tokenize=False,
            thinking=True,
            add_generation_prompt=True,
        )

        return prompt

    def test_generate_from_raw_small(self):
        prompt = "what is 1+1?"
        prompt = self.prepare_prompt_for_math(prompt)
        action = CBlock(value=prompt)
        THINK_MAX_TOKENS = 64
        ANSWER_MAX_TOKENS = 16
        result, gen_tok_cnt = self.m.backend.generate_with_budget_forcing(
            action=action,
            think_max_tokens=THINK_MAX_TOKENS,
            answer_max_tokens=ANSWER_MAX_TOKENS,
            start_think_token="<think>",
            end_think_token="</think>",
            think_wait_suffix="Wait",
            answer_suffix="The final answer is:",
            # answer_suffix="",
            answer_token="boxed",
        )

        assert gen_tok_cnt <= 2 * THINK_MAX_TOKENS

    def test_generate_from_raw_large(self):
        prompt = "what is 1+1?"
        prompt = self.prepare_prompt_for_math(prompt)
        action = CBlock(value=prompt)
        THINK_MAX_TOKENS = 1024
        ANSWER_MAX_TOKENS = 256
        result, gen_tok_cnt = self.m.backend.generate_with_budget_forcing(
            action=action,
            think_max_tokens=THINK_MAX_TOKENS,
            answer_max_tokens=ANSWER_MAX_TOKENS,
            start_think_token="<think>",
            end_think_token="</think>",
            think_wait_suffix="Wait",
            answer_suffix="The final answer is:",
            answer_token="boxed",
        )

        assert gen_tok_cnt >= 0.5 * THINK_MAX_TOKENS


if __name__ == "__main__":
    pytest.main(["-s", __file__])
