
Commit 9385af8

Author: Yousef El-Kurdi (committed)
adds zero-think case
1 parent bd503cf commit 9385af8

2 files changed: 74 additions (+), 42 deletions (−)


mellea/backends/openai.py

Lines changed: 47 additions & 34 deletions
@@ -37,6 +37,7 @@
 )
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement
+import re

 if TYPE_CHECKING:
     from transformers.tokenization_utils import PreTrainedTokenizer
@@ -471,21 +472,21 @@ def generate_with_budget_forcing(
         self,
         action: CBlock,
         *,
-        think_max_tokens: int = 3072,
+        think_max_tokens: int = 4096,
         answer_max_tokens: int | None = None,
-        start_think_token: str | None = "<think>",
-        end_think_token: str | None = "</think>",
-        begin_response_token: str | None = None,
-        end_response_token: str | None = None,
-        think_wait_suffix: str | None = None,
-        answer_suffix: str | None = "The final answer is:",
-        answer_token: str | None = "boxed",
+        start_think_token: str = "<think>",
+        end_think_token: str = "</think>",
+        begin_response_token: str = "",
+        end_response_token: str = "",
+        think_wait_suffix: str = "",
+        answer_suffix: str = "The final answer is:",
+        answer_regex: str = "boxed",
         model_options: dict | None = None,
     ) -> list[ModelOutputThunk]:
-        """Generate with budget forcing using the completions APIs. This relies on raw autocompletion and assumes the model's output is structued in the following form: '<think> ... </think> summary answer'
+        """Generate with budget forcing using the completions APIs. This relies on raw autocompletion and assumes the model's output is structured in the following form: '<think> ... </think> summary answer'
         The budget forcing method is proposed in the paper: https://arxiv.org/abs/2501.19393
         This implementation tries to follow the key outlines in the paper while ensuring stable and fail-safe operation.
-        This is performed via multi-step generation. The model will be called multiple times until requirements are met, in other words, the response will be assembeled conditionally.
+        This is performed via multi-step generation. The model will be called multiple times until requirements are met, in other words, the response will be assembled conditionally.

         Args:
             think_max_tokens: Budget in number of tokens allocated for the think block
@@ -496,7 +497,7 @@ def generate_with_budget_forcing(
             end_response_token: Used by certain models, string indicating end of response block, e.g. "</response>", default None
             think_wait_suffix: String to append to force continued thinking, e.g. "\nWait" if set to None we will not force additional thinking. Use None for upper-bound budget case
             answer_suffix: String to append to force a final answer
-            answer_token: Token that indicates an answer is generated
+            answer_regex: Answer regex which indicates an answer is generated

         Assumptions:
             - The chat template is applied on prompt, with think mode enabled
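
For orientation, a minimal usage sketch of the revised signature. The backend setup and the CBlock import path are assumptions; the argument values mirror the tests further down, and the (text, token count) return shape follows those tests rather than the list[ModelOutputThunk] annotation.

# Hypothetical usage sketch -- `backend` stands in for an already constructed OpenAIBackend.
from mellea.stdlib.base import CBlock  # assumed import path

prompt = "what is 1+1?"  # chat template with think mode is assumed to be applied already
action = CBlock(value=prompt)

response_text, gen_tok_cnt = backend.generate_with_budget_forcing(
    action=action,
    think_max_tokens=1024,                 # token budget for the <think> block
    answer_max_tokens=256,                 # token budget for the forced final answer
    start_think_token="<think>",
    end_think_token="</think>",
    think_wait_suffix="Wait",              # appended to force continued thinking (strict budget)
    answer_suffix="The final answer is:",  # appended to force a final answer
    answer_regex=r"\\boxed{.*?}",          # regex signalling that an answer was produced
)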
@@ -511,48 +512,61 @@ def generate_with_budget_forcing(

         responses = []
         prompt = self.formatter.print(action)
-        if start_think_token is not None:
+        if start_think_token:
             prompt += start_think_token
             responses.append(start_think_token)
+
         backend_opts = self._make_backend_specific_and_remove(
             model_opts, is_chat_context=False
         )
         # Generate thinking portion
-        max_tok_thd = 0.8
-        backend_opts["max_tokens"] = think_max_tokens
+        max_tok_thd = 1.0
         # backend_opts["echo"] = True
         # backend_opts["logprobs"] = 1
         backend_opts["n"] = 1
+        rem_toks = think_max_tokens
         gen_tok_count = 0
         curr_prompt = prompt
         min_step_len = 10  # minimum character length of step to be considered valid

         # think block indefinite multi-step operation to satisfy user's budget
         while True:
+
+            if rem_toks <= 0:  # zero-think case
+                break
+
+            if rem_toks <= min_step_len:  # minimum step length reached
+                break
+
+            backend_opts["max_tokens"] = rem_toks
             try:
                 completion_response: Completion = self._client.completions.create(
                     model=self._hf_model_id, prompt=curr_prompt, **backend_opts
                 )  # type: ignore
             except openai.BadRequestError as e:
                 if openai_ollama_batching_error in e.message:
                     FancyLogger.get_logger().error(
-                        "If you are trying to call `OpenAIBackend._generate_from_raw while targeting an ollama server, "
+                        "If you are trying to call `OpenAIBackend.generate_with_budget_forcing while targeting an ollama server, "
                         "your requests will fail since ollama doesn't support batching requests."
                     )
                 raise e

             gen_tok_count += completion_response.usage.completion_tokens
+            rem_toks = think_max_tokens - gen_tok_count
             response = completion_response.choices[0].text
-            if think_wait_suffix is None:
+
+            if think_wait_suffix == "":
+                # non-strict budget form
                 responses.append(response)
                 break

-            if gen_tok_count >= max_tok_thd * think_max_tokens:
+            if rem_toks <= 0:
                 responses.append(response)
                 break

             else:
-                step = response.split(end_think_token)[0]
+                if end_think_token:
+                    step = response.split(end_think_token)[0]
                 # model fails to produce thoughts, let's exit
                 if len(step.strip()) <= min_step_len:
                     responses.append(response)
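
The reworked loop replaces the fixed max_tok_thd cutoff with explicit remaining-budget accounting, which is what enables the zero-think case. A stripped-down sketch of that control flow, illustrative only (generate_step is a stand-in for the completions call, not library code):

# Stripped-down sketch of the new remaining-budget accounting (illustrative, not library code).
def think_loop_sketch(generate_step, think_max_tokens: int, min_step_len: int = 10):
    rem_toks = think_max_tokens
    gen_tok_count = 0
    chunks = []
    while True:
        if rem_toks <= 0:             # zero-think case: a budget of 0 skips thinking entirely
            break
        if rem_toks <= min_step_len:  # leftover budget too small for a useful step
            break
        # generate_step stands in for self._client.completions.create(...)
        text, used = generate_step(max_tokens=rem_toks)
        gen_tok_count += used
        rem_toks = think_max_tokens - gen_tok_count
        chunks.append(text)
        if rem_toks <= 0:             # budget exhausted after this step
            break
    return "".join(chunks), gen_tok_count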
@@ -564,11 +578,7 @@ def generate_with_budget_forcing(
                 curr_prompt += step

         response = "".join(responses)
-        ### debug obtaining final answer
-        # response = response.split(end_think_token)[0]
-        # response = response.replace(answer_token, "")
-        ###
-        if answer_token is None or answer_suffix is None:
+        if answer_regex is None or answer_suffix is None:
             return response, gen_tok_count

         # Now get a final answer if we need to
@@ -577,28 +587,31 @@ def generate_with_budget_forcing(
         # Consider a strict structural approach in the future.
         # e.g.
         # ans_portion = response.split(end_think_token)[-1]
-        # if answer_token in ans_portion:
+        # if answer_regex in ans_portion:
         #     return response, gen_tok_count

-        if answer_token in response:
+        # Check if answer in response
+        matches = re.findall(answer_regex, response, re.DOTALL)
+        if len(matches) > 0:
             return response, gen_tok_count

         # Answer is not in response, let's force an answer
-        if end_think_token not in response:
-            response = (
-                f"{response} {end_think_token}{begin_response_token} {answer_suffix}"
-            )
+        if end_think_token and end_think_token not in response:
+            response += f" {end_think_token}"

-        else:
-            response = f"{response} {begin_response_token}{answer_suffix}"
+        if begin_response_token and begin_response_token not in response:
+            response += f" {begin_response_token}"

-        # update original prompt with assembled response
+        if answer_suffix:
+            response += f" {answer_suffix}"
+
+        # update original prompt with assembled response
         prompt += response
         if answer_max_tokens is not None:
             backend_opts["max_tokens"] = answer_max_tokens

         else:
-            del backend_opts["max_tokens"]
+            backend_opts.pop("max_tokens", None)  # generate unconditionally

         try:
             completion_response: Completion = self._client.completions.create(
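
Answer detection now uses a regular expression rather than a plain substring check, and the forced-answer prefix is assembled piece by piece only when the corresponding marker is non-empty and missing. A small self-contained sketch of that behaviour; the regex value comes from the updated tests, and the sample model output is invented:

# Illustrative sketch of the new answer check and forced-answer assembly (not library code).
import re

answer_regex = r"\\boxed{.*?}"          # value used by the updated tests
end_think_token = "</think>"
begin_response_token = ""               # default: no response marker
answer_suffix = "The final answer is:"

# Invented model output: the think block was never closed and no \boxed{} answer appeared.
response = "<think> 1+1 is trivially 2, no need to double-check."

if re.findall(answer_regex, response, re.DOTALL):
    print("answer already present, return as-is")
else:
    if end_think_token and end_think_token not in response:
        response += f" {end_think_token}"
    if begin_response_token and begin_response_token not in response:
        response += f" {begin_response_token}"
    if answer_suffix:
        response += f" {answer_suffix}"
    # The assembled text is appended to the prompt and sent back for the final answer.
    print(response)  # "<think> 1+1 is trivially 2, ... </think> The final answer is:"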
@@ -607,7 +620,7 @@
         except openai.BadRequestError as e:
             if openai_ollama_batching_error in e.message:
                 FancyLogger.get_logger().error(
-                    "If you are trying to call `OpenAIBackend._generate_from_raw while targeting an ollama server, "
+                    "If you are trying to call `OpenAIBackend.generate_with_budget_forcing while targeting an ollama server, "
                     "your requests will fail since ollama doesn't support batching requests."
                 )
             raise e

test/backends/test_think_budget_forcing/test_think_budget_forcing.py

Lines changed: 27 additions & 8 deletions
@@ -37,13 +37,13 @@ def prepare_prmpt_for_math(self, query):

         return prompt

-    def test_generate_from_raw_small(self):
+    def test_think_small(self):
         prompt = "what is 1+1?"
         prompt = self.prepare_prmpt_for_math(prompt)
         action = CBlock(value=prompt)
         results = []
-        THINK_MAX_TOKENS = 64
-        ANSWER_MAX_TOKENS = 16
+        THINK_MAX_TOKENS = 1024
+        ANSWER_MAX_TOKENS = 256
         result, gen_tok_cnt = self.m.backend.generate_with_budget_forcing(
             action=action,
             think_max_tokens=THINK_MAX_TOKENS,
@@ -53,19 +53,19 @@ def test_generate_from_raw_small(self):
             think_wait_suffix="Wait",
             answer_suffix="The final answer is:",
             # answer_suffix="",
-            answer_token="boxed",
+            answer_regex= r"\\boxed{.*?}",
         )

         assert gen_tok_cnt <= 2 * THINK_MAX_TOKENS


-    def test_generate_from_raw_large(self):
+    def test_think_large(self):
         prompt = "what is 1+1?"
         prompt = self.prepare_prmpt_for_math(prompt)
         action = CBlock(value=prompt)
         results = []
-        THINK_MAX_TOKENS = 1024
-        ANSWER_MAX_TOKENS = 256
+        THINK_MAX_TOKENS = 2048
+        ANSWER_MAX_TOKENS = 512
         result, gen_tok_cnt = self.m.backend.generate_with_budget_forcing(
             action=action,
             think_max_tokens=THINK_MAX_TOKENS,
@@ -74,11 +74,30 @@ def test_generate_from_raw_large(self):
             end_think_token="</think>",
             think_wait_suffix="Wait",
             answer_suffix="The final answer is:",
-            answer_token="boxed",
+            answer_regex=r"\\boxed{.*?}",
         )

         assert gen_tok_cnt >= 0.5 * THINK_MAX_TOKENS


+    def test_zero_think(self):
+        prompt = "what is 1+1?"
+        prompt = self.prepare_prmpt_for_math(prompt)
+        action = CBlock(value=prompt)
+        results = []
+        THINK_MAX_TOKENS = 0
+        ANSWER_MAX_TOKENS = 512
+        result, gen_tok_cnt = self.m.backend.generate_with_budget_forcing(
+            action=action,
+            think_max_tokens=THINK_MAX_TOKENS,
+            answer_max_tokens=ANSWER_MAX_TOKENS,
+            start_think_token = "",
+            end_think_token="<think> Okay, I think I have finished thinking. </think>",
+            think_wait_suffix="",
+            answer_suffix="The final answer is:",
+        )
+
+        assert gen_tok_cnt >= 0.5 * THINK_MAX_TOKENS
+
 if __name__ == "__main__":
     pytest.main(["-s", __file__])
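
In the new zero-think test, think_max_tokens=0 makes the think loop exit before any completion call, so the only generation request is the forced-answer one, capped at answer_max_tokens. A hedged trace of the text this configuration presumably assembles before that call (the prompt placeholder and exact spacing are assumptions):

# Rough trace of the zero-think path (illustrative, not library code).
prompt = "<math prompt with chat template applied>"  # placeholder for the real prompt
response = ""  # start_think_token is "" and the think loop breaks immediately (budget 0)
response += " <think> Okay, I think I have finished thinking. </think>"  # end_think_token appended
response += " The final answer is:"                                      # answer_suffix appended
prompt += response  # sent to the completions API with answer_max_tokens=512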
