@@ -4,6 +4,7 @@
 import datetime
 import inspect
 import json
+import re
 from collections.abc import Callable
 from enum import Enum
 from typing import TYPE_CHECKING, Any
@@ -13,7 +14,7 @@
 import requests
 from huggingface_hub import snapshot_download
 from openai.types.chat import ChatCompletion
-from openai.types.completion import Completion
+from openai.types.completion import Completion, CompletionUsage
 
 import mellea.backends.model_ids as model_ids
 from mellea.backends import BaseModelSubclass
@@ -37,7 +38,6 @@
 )
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement
-import re
 
 if TYPE_CHECKING:
     from transformers.tokenization_utils import PreTrainedTokenizer
@@ -482,7 +482,7 @@ def generate_with_budget_forcing(
         answer_suffix: str = "The final answer is:",
         answer_regex: str = "boxed",
         model_options: dict | None = None,
-    ) -> list[ModelOutputThunk]:
+    ) -> tuple[str, int]:
         """Generate with budget forcing using the completions API. This relies on raw autocompletion and assumes the model's output has the form '<think> ... </think> summary answer'.
         Budget forcing is the method proposed in the paper https://arxiv.org/abs/2501.19393.
         This implementation follows the key outline of the paper while ensuring stable, fail-safe operation.
@@ -520,7 +520,6 @@ def generate_with_budget_forcing(
             model_opts, is_chat_context=False
         )
         # Generate thinking portion
-        max_tok_thd = 1.0
         # backend_opts["echo"] = True
         # backend_opts["logprobs"] = 1
         backend_opts["n"] = 1
@@ -531,7 +530,6 @@ def generate_with_budget_forcing(
 
         # Think block: indefinite multi-step generation to satisfy the user's token budget.
         while True:
-
             if rem_toks <= 0:  # zero-think case
                 break
 
@@ -540,7 +538,7 @@ def generate_with_budget_forcing(
 
             backend_opts["max_tokens"] = rem_toks
             try:
-                completion_response: Completion = self._client.completions.create(
+                completion_response = self._client.completions.create(
                     model=self._hf_model_id, prompt=curr_prompt, **backend_opts
                 )  # type: ignore
             except openai.BadRequestError as e:
@@ -551,6 +549,8 @@ def generate_with_budget_forcing(
                 )
                 raise e
 
+            # Necessary for type checker.
+            assert isinstance(completion_response.usage, CompletionUsage)
             gen_tok_count += completion_response.usage.completion_tokens
             rem_toks = think_max_tokens - gen_tok_count
             response = completion_response.choices[0].text
@@ -586,9 +586,7 @@ def generate_with_budget_forcing(
         # The think block, but we will use a relaxed requirement of finding any answer in the model's response.
         # Consider a strict structural approach in the future.
         # e.g.
-        # ans_portion = response.split(end_think_token)[-1]
-        # if answer_regex in ans_portion:
-        #     return response, gen_tok_count
+        # answer_blk = response.split(end_think_token)[-1]
 
         # Check if answer in response
         matches = re.findall(answer_regex, response, re.DOTALL)
@@ -611,10 +609,10 @@ def generate_with_budget_forcing(
             backend_opts["max_tokens"] = answer_max_tokens
 
         else:
-            backend_opts.pop("max_tokens", None) # generate unconditionally
+            backend_opts.pop("max_tokens", None)  # generate unconditionally
 
         try:
-            completion_response: Completion = self._client.completions.create(
+            completion_response = self._client.completions.create(
                 model=self._hf_model_id, prompt=prompt, **backend_opts
             )  # type: ignore
         except openai.BadRequestError as e:
@@ -625,6 +623,8 @@ def generate_with_budget_forcing(
             )
             raise e
 
+        # Necessary for type checker.
+        assert isinstance(completion_response.usage, CompletionUsage)
         response += completion_response.choices[0].text
         gen_tok_count += completion_response.usage.completion_tokens
         return response, gen_tok_count
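
With this change, `generate_with_budget_forcing` returns the raw completion text together with the total number of completion tokens consumed, rather than a `list[ModelOutputThunk]`. A minimal usage sketch of the new contract follows; the backend construction, import path, and the positional prompt argument are assumptions for illustration, and only the keyword arguments and the `tuple[str, int]` return type are taken from the diff above.

```python
# Hypothetical usage sketch; OpenAIBackend construction details are
# assumed, not taken from this commit, and any budget-related
# parameters earlier in the signature are omitted here.
from mellea.backends.openai import OpenAIBackend

backend = OpenAIBackend()  # assumed default construction

# answer_suffix, answer_regex, and model_options appear in the diff
# with these defaults; the prompt string is purely illustrative.
response, gen_tok_count = backend.generate_with_budget_forcing(
    "What is 17 * 23? Show your reasoning.",
    answer_suffix="The final answer is:",
    answer_regex="boxed",
    model_options=None,
)

# response is the raw text ('<think> ... </think> summary answer');
# gen_tok_count sums completion tokens across every underlying API call.
print(f"generated {gen_tok_count} tokens")
print(response)
```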