Skip to content

Commit fda0768

Browse files
author
Yousef El-Kurdi
committed
resolved type checking errors
1 parent 9385af8 commit fda0768

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

mellea/backends/openai.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import datetime
55
import inspect
66
import json
7+
import re
78
from collections.abc import Callable
89
from enum import Enum
910
from typing import TYPE_CHECKING, Any
@@ -13,7 +14,7 @@
1314
import requests
1415
from huggingface_hub import snapshot_download
1516
from openai.types.chat import ChatCompletion
16-
from openai.types.completion import Completion
17+
from openai.types.completion import Completion, CompletionUsage
1718

1819
import mellea.backends.model_ids as model_ids
1920
from mellea.backends import BaseModelSubclass
@@ -37,7 +38,6 @@
3738
)
3839
from mellea.stdlib.chat import Message
3940
from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement
40-
import re
4141

4242
if TYPE_CHECKING:
4343
from transformers.tokenization_utils import PreTrainedTokenizer
@@ -482,7 +482,7 @@ def generate_with_budget_forcing(
482482
answer_suffix: str = "The final answer is:",
483483
answer_regex: str = "boxed",
484484
model_options: dict | None = None,
485-
) -> list[ModelOutputThunk]:
485+
) -> tuple[str, int]:
486486
"""Generate with budget forcing using the completions APIs. This relies on raw autocompletion and assumes the model's output is structured in the following form: '<think> ... </think> summary answer'
487487
The budget forcing method is proposed in the paper: https://arxiv.org/abs/2501.19393
488488
This implementation tries to follow the key outlines in the paper while ensuring stable and fail-safe operation.
@@ -520,7 +520,6 @@ def generate_with_budget_forcing(
520520
model_opts, is_chat_context=False
521521
)
522522
# Generate thinking portion
523-
max_tok_thd = 1.0
524523
# backend_opts["echo"] = True
525524
# backend_opts["logprobs"] = 1
526525
backend_opts["n"] = 1
@@ -531,7 +530,6 @@ def generate_with_budget_forcing(
531530

532531
# Think block: loop over as many generation steps as needed to satisfy the user's budget
533532
while True:
534-
535533
if rem_toks <= 0: # zero-think case
536534
break
537535

@@ -540,7 +538,7 @@ def generate_with_budget_forcing(
540538

541539
backend_opts["max_tokens"] = rem_toks
542540
try:
543-
completion_response: Completion = self._client.completions.create(
541+
completion_response = self._client.completions.create(
544542
model=self._hf_model_id, prompt=curr_prompt, **backend_opts
545543
) # type: ignore
546544
except openai.BadRequestError as e:
@@ -551,6 +549,8 @@ def generate_with_budget_forcing(
551549
)
552550
raise e
553551

552+
# Necessary for type checker.
553+
assert isinstance(completion_response.usage, CompletionUsage)
554554
gen_tok_count += completion_response.usage.completion_tokens
555555
rem_toks = think_max_tokens - gen_tok_count
556556
response = completion_response.choices[0].text
@@ -586,9 +586,7 @@ def generate_with_budget_forcing(
586586
# The think block, but we will use the relaxed requirement of finding any answer in the model's response.
587587
# Consider a strict structural approach in the future.
588588
# e.g.
589-
# ans_portion = response.split(end_think_token)[-1]
590-
# if answer_regex in ans_portion:
591-
# return response, gen_tok_count
589+
# answer_blk = response.split(end_think_token)[-1]
592590

593591
# Check if answer in response
594592
matches = re.findall(answer_regex, response, re.DOTALL)
@@ -611,10 +609,10 @@ def generate_with_budget_forcing(
611609
backend_opts["max_tokens"] = answer_max_tokens
612610

613611
else:
614-
backend_opts.pop("max_tokens", None) # generate unconditionally
612+
backend_opts.pop("max_tokens", None) # generate unconditionally
615613

616614
try:
617-
completion_response: Completion = self._client.completions.create(
615+
completion_response = self._client.completions.create(
618616
model=self._hf_model_id, prompt=prompt, **backend_opts
619617
) # type: ignore
620618
except openai.BadRequestError as e:
@@ -625,6 +623,8 @@ def generate_with_budget_forcing(
625623
)
626624
raise e
627625

626+
# Necessary for type checker.
627+
assert isinstance(completion_response.usage, CompletionUsage)
628628
response += completion_response.choices[0].text
629629
gen_tok_count += completion_response.usage.completion_tokens
630630
return response, gen_tok_count

0 commit comments

Comments
 (0)