Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/modelgauge/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import shlex
import subprocess
import time
from typing import List, Optional, Set, Type, TypeVar
from typing import List, Optional, TypeVar

from airrlogger.log_config import get_logger
from tqdm import tqdm

# Type vars helpful in defining templates.
_InT = TypeVar("_InT")
_InT = TypeVar("_InT", bound=type)

logger = get_logger(__name__)

Expand All @@ -20,8 +20,8 @@ def current_timestamp_millis() -> int:
return time.time_ns() // 1_000_000


def get_concrete_subclasses(cls: Type[_InT]) -> Set[Type[_InT]]:
result = set()
def get_concrete_subclasses(cls: _InT) -> set[_InT]:
result: set[_InT] = set()
for subclass in cls.__subclasses__():
if not inspect.isabstract(subclass):
result.add(subclass)
Expand Down
77 changes: 64 additions & 13 deletions src/modelgauge/reasoning_handlers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from abc import ABC, abstractmethod
from typing import Any

from airrlogger.log_config import get_logger
from pydantic import BaseModel

from modelgauge.general import get_concrete_subclasses
from modelgauge.model_options import ModelOptions
from modelgauge.prompt import TextPrompt
from modelgauge.sut import PromptResponseSUT, SUTResponse
Expand All @@ -17,54 +19,96 @@ class ReasoningRequest(BaseModel):
max_total_tokens: int | None = None # Total number of tokens allowed (thinking + content).


class ThinkingMixin(PromptResponseSUT):
class ReasoningHandler(ABC):
@staticmethod
def _get_concrete_reasoning_suts() -> set[type["ReasoningHandler"]]:
return get_concrete_subclasses(ReasoningHandler)

@staticmethod
def find_match(sut: PromptResponseSUT) -> type["ReasoningHandler"] | None:
reasoning_suts = ReasoningHandler._get_concrete_reasoning_suts()
for rs in reasoning_suts:
if rs.sut_matches(sut):
return rs
return None

@classmethod
def sut_matches(cls, sut) -> bool:
"""Finds a matching reasoning handler for the given SUT. Calling this method will result in 1 SUT call."""
request = sut.translate_text_prompt(
TextPrompt(text="If I have 2 apples and give 1 to my friend, how many apples do I have left?"),
options=ModelOptions(max_tokens=1000),
)
raw_response = sut.evaluate(request)
response = sut.translate_response(request, raw_response)
return cls.response_contains_reasoning(response)

@classmethod
@abstractmethod
def response_contains_reasoning(cls, response: SUTResponse) -> bool:
pass


class ThinkingMixin(ReasoningHandler):
"""
A mixin for SUTs that parses out thinking text from the output.

The output is expected to be in the form: {reasoning text}</think>{content text}.
If max_total_output_tokens is set in ModelOptions, that value will be used in the model call and the content text will be truncated to max_tokens.
Otherwise, max_tokens is used in the model call and everything after </think> is returned as content.

Reasoning should be enabled by the model by default. This mixin does not request reasoning be enabled (yet).
"""

OPEN_TAG = "<think>" # Optional.
CLOSE_TAG = "</think>" # Tag that separates reasoning from content.

def __init__(self, uid, *args, **kwargs):
super().__init__(uid, *args, **kwargs)
self.tokenizer = GeneralTokenizer()
self.separator = "</think>" # Tag that separates reasoning from content.

def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> ReasoningRequest:
@classmethod
def response_contains_reasoning(cls, response: SUTResponse) -> bool:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like SUTResponse's responsibility to report whether it contains reasoning.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your thinking there. But answering that question depends on the type of reasoning pattern that the model is following. And I think matching against all known reasoning patterns is a lot to ask of a SUTResponse object that is mainly intended to function as a bag of data.

Also, I think it's cleaner if all reasoning functionality (matching, parsing, request formatting) is handled by the same object.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking the same thing as Roger. But I think the matching can happen in the SUT. How about we have a subclass of SUTResponse that is a ReasoningSUTResponse which has a couple of new properties?

Off the top of my head, I'd say SUTResponse's text should be the raw response, as that's what the SUT actually output. Then we add a couple of properties that could be reasoning specific, like reasoning_text and response_text. Or you could add a property to SUTResponse that is something like clean_text or filtered_text or user_facing_text, which in the superclass is the same thing as text and in the ReasoningSUTResponse returns the stripped output.

That preserves the generality of PromptResponseSUT for a variety of use cases, but allows benchmarks to ask for the cleaned text they need for evaluation.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can do that, but it feels overly complicated. I think people (Kurt, workstream folks) would struggle to keep track of when to look at text and when to look at response_text. I think it would also make maintenance more difficult (every time you look at the SUTResponse object, you have to recall that text only matters in one case and reasoning_text matters in another). I think this has a lot of downstream repercussions for understandability.

I don't see the advantages here.

return cls.OPEN_TAG in response.text or cls.CLOSE_TAG in response.text

def translate_text_prompt(
self, sut: PromptResponseSUT, prompt: TextPrompt, options: ModelOptions
) -> ReasoningRequest:
max_total_tokens = options.max_total_output_tokens
if max_total_tokens is None:
max_total_tokens = options.max_tokens
max_content_tokens = options.max_tokens

# Replace max_tokens in raw request with the max total tokens.
options.max_tokens = max_total_tokens
request = super().translate_text_prompt(prompt, options)
request = sut.translate_text_prompt(prompt, options)
return ReasoningRequest(
request=request,
max_content_tokens=max_content_tokens,
max_total_tokens=max_total_tokens,
)

def evaluate(self, request: ReasoningRequest) -> Any:
return super().evaluate(request.request) # type: ignore
def evaluate(self, sut: PromptResponseSUT, request: ReasoningRequest) -> Any:
return sut._evaluate(request.request) # type: ignore

def translate_response(self, request: ReasoningRequest, response: Any) -> SUTResponse:
text = super().translate_response(request.request, response).text # type: ignore
def translate_response(self, sut: PromptResponseSUT, request: ReasoningRequest, response: Any) -> SUTResponse:
text = sut._translate_response(request.request, response).text # type: ignore

think_close = text.find(self.separator)
think_close = text.rfind(self.CLOSE_TAG)
if think_close == -1:
# no closing tag: everything is thinking text
return SUTResponse(text="")
return SUTResponse(text="", reasoning=self.trim_tokens(text))

reasoning = text[: think_close + len(self.separator)].strip()
content = text[think_close + len(self.separator) :].strip()
reasoning = text[: think_close + len(self.CLOSE_TAG)].strip()
content = text[think_close + len(self.CLOSE_TAG) :].strip()
self.warn_edge_cases(content, reasoning, request)

reasoning = self.trim_tokens(reasoning)

# Truncate content
if request.max_content_tokens is not None:
content = self.tokenizer.truncate(content, request.max_content_tokens)
return SUTResponse(text=content)
return SUTResponse(text=content, reasoning=reasoning)

def warn_edge_cases(self, content, reasoning, request):
if request.max_total_tokens is None:
Expand All @@ -77,3 +121,10 @@ def warn_edge_cases(self, content, reasoning, request):
logger.warning(
f"SUT {self.uid} reasoning likely ate into the token budget of the actual output. Consider increasing max_total_output_tokens."
)

def trim_tokens(self, text: str) -> str:
if text.startswith(self.OPEN_TAG):
text = text[len(self.OPEN_TAG) :]
if text.endswith(self.CLOSE_TAG):
text = text[: -len(self.CLOSE_TAG)]
return text
28 changes: 25 additions & 3 deletions src/modelgauge/sut.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class SUTResponse(BaseModel):
"""The data that came out of the SUT."""

text: str
reasoning: Optional[str] = None
top_logprobs: Optional[Sequence[TopTokens]] = None
"""For each position, list the probabilities for each of the most likely tokens.

Expand Down Expand Up @@ -58,14 +59,23 @@ class PromptResponseSUT(SUT, Readyable):
Abstract base class that provides an interface to any SUT that is designed for handling a single-turn.
"""

def __init__(self, uid: str):
super().__init__(uid)
self.reasoning_handler: Optional[Type[ReasoningHandler]] = ReasoningHandler.sut_matches(self)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have two qualms with this:

  1. It will be null most of the time and only used in specialized versions of a SUT class, so it should probably be a subclass like a ReasoningPromptResponseSUT.
  2. It calls a third-party service on object creation, which has a side effect the caller may not be aware of, and which is only necessary if we don't know whether a SUT is reasoning. I think we know this for sure in a lot of cases, without having to test the SUT. But that could be wrong.

So I'd suggest doing the sut_matches check in the factory that instantiates the SUT, and either passing in the handler to this constructor or a flag for the constructor to create the handler, without the constructor doing the probing.

Alternatively, make the reasoning handler something more abstract that can be common to all SUTs. In reasoning SUTs, it would do the response parsing; in other SUTs, it would be a no-op. Then every SUT would behave the same. That may overcomplicate things, and so a ReasoningPromptResponseSUT subclass with a handler may be cleaner.

Alternatively, for SUTs we know are reasoning SUTs, there could be an option to bypass the check, like

def __init__(self, uid: str, is_reasoning: bool | None = None):
    if is_reasoning is None:
        # do the check
    elif is_reasoning is False:
        self.reasoning_handler = None
    else:
        self.reasoning_handler = ReasoningHandler()

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I'd suggest doing the sut_matches check in the factory that instantiates the SUT, and either passing in the handler to this constructor or a flag for the constructor to create the handler, without the constructor doing the probing.

I believe this is more so similar to my original PR (which I actually prefer). I could implement this approach, but I know @wpietri had qualms about the factory instantiating a SUT, doing a reasoning check, and then possibly modifying the SUT.

Alternatively, make the reasoning handler something more abstract that can be common to all SUTs. In reasoning SUTs, it would do the response parsing; in other SUTs, it would be a no-op.

I think this would be difficult if there are multiple reasoning handlers.

Alternatively, for SUTs we know are reasoning SUTs, there could be an option to bypass the check, like

The scope of this first task I'm tackling only includes SUTs that we do not have prior information about. But that's a good idea for later!

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My qualms depend a lot on how the modification happens. If it's just setting a flag or something, I think that's ok. But I have a bias toward immutability, so I'd rather we instantiate a whole and complete thing. E.g., the thing that creates the SUT could create a naive version that assumes no reasoning, check to see if there's reasoning in the output, and then create a new SUT instance that handles reasoning.

One thing that strikes me is that SUTs are not going to change their reasoning-ness on the fly, so SUT instantiation isn't the only time to collect the information of what's a reasoning SUT and what isn't. If we make a big lookup table of what has reasoning and what doesn't, which we have to do anyway to properly categorize our benchmark results, we would only have to do run-time checks for novel-to-us SUTs. If the look-it-up approach is our default case and the run-time check becomes the unusual one, could that orientation make the code clearer?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the thing that creates the SUT could create a naive version that assumes no reasoning, check to see if there's reasoning in the output, and then create a new SUT instance that handles reasoning.

Isn't that what I did in the first PR? Does that mean your only issue was with the dynamic class mixing?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if I do the check reasoning + instantiate a new SUT with a reasoning handler in the factory... is there any way to do it that avoids both dynamic class mixing from the first PR and the circular dependencies in this PR? @rogthefrog @wpietri

Copy link
Copy Markdown
Contributor

@rogthefrog rogthefrog Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed in DM, I think the circular dependency thing is for a broader discussion and not in scope.

And fwiw I don't think the initial implementation is too problematic.


def run_readiness_check(self) -> ReadyResponse:
raw_request = self.translate_text_prompt(_READINESS_CHECK_TEXT_PROMPT, options=_READINESS_CHECK_SUT_OPTIONS)
raw_response = self.evaluate(raw_request)
response = self.translate_response(raw_request, raw_response)
return ReadyResponse(is_ready=response.text is not None, response=response)

@not_implemented
def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions):
if self.reasoning_handler is not None:
return self.reasoning_handler.translate_text_prompt(self, prompt, options)
return self._translate_text_prompt(prompt, options)

@not_implemented
def _translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions):
"""Convert the prompt + SUT options into the SUT's native representation.

This method must be implemented if the SUT accepts text prompts.
Expand All @@ -80,12 +90,24 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions):
"""
raise NotImplementedError(f"SUT {self.__class__.__name__} does not implement translate_chat_prompt.")

@abstractmethod
def evaluate(self, request):
"""Evaluate this SUT on the native request."""
pass
if self.reasoning_handler is not None:
return self.reasoning_handler.evaluate(self, request)
return self._evaluate(request)

@abstractmethod
def _evaluate(self, request):
"""Evaluate this SUT on the native request."""
pass

def translate_response(self, request, response) -> SUTResponse:
"""Convert the native response into a form all Tests can process."""
if self.reasoning_handler is not None:
return self.reasoning_handler._translate_response(self, request, response)
return self._translate_response(request, response)

@abstractmethod
def _translate_response(self, request, response) -> SUTResponse:
"""Convert the native response into a form all Tests can process."""
pass
41 changes: 1 addition & 40 deletions src/modelgauge/suts/huggingface_chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens
from modelgauge.prompt import TextPrompt, ChatPrompt
from modelgauge.reasoning_handlers import ThinkingMixin
from modelgauge.retry_decorator import retry
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTResponse
Expand Down Expand Up @@ -186,14 +185,6 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> Hu
)


@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt])
class HuggingFaceChatCompletionDedicatedThinkingSUT(ThinkingMixin, HuggingFaceChatCompletionDedicatedSUT):
"""
A SUT that excludes the reasoning from model output.
Reasoning must be seperated from normal output with a </think> tag (like nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16)
"""


@modelgauge_sut(capabilities=[AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities])
class HuggingFaceChatCompletionServerlessSUT(BaseHuggingFaceChatCompletionSUT):
"""A SUT hosted by an inference provider on huggingface."""
Expand Down Expand Up @@ -231,14 +222,6 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> Hu
)


@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt])
class HuggingFaceChatCompletionServerlessThinkingSUT(ThinkingMixin, HuggingFaceChatCompletionServerlessSUT):
"""
A SUT that excludes the reasoning from model output.
Reasoning must be seperated from normal output with a </think> tag (like nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16)
"""


HF_SECRET = InjectSecret(HuggingFaceInferenceToken)

SUTS.register(
Expand Down Expand Up @@ -276,29 +259,7 @@ class HuggingFaceChatCompletionServerlessThinkingSUT(ThinkingMixin, HuggingFaceC
None,
HF_SECRET,
)
# Special thinking dedicated SUTs
SUTS.register(
HuggingFaceChatCompletionDedicatedThinkingSUT,
"nvidia-nemotron-3-nano-30b-a-thinking-excluded-hf",
"nvidia-nemotron-3-nano-30b-a-mia",
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
HF_SECRET,
)
SUTS.register(
HuggingFaceChatCompletionDedicatedThinkingSUT,
"PrimeIntellect-INTELLECT-3-thinking-excluded-hf",
"intellect-3-uqs",
"PrimeIntellect/INTELLECT-3",
HF_SECRET,
)
# Special thinking serverless SUTs
SUTS.register(
HuggingFaceChatCompletionServerlessThinkingSUT,
"moonshotai/Kimi-K2.5-together-thinking-excluded-hf",
"moonshotai/Kimi-K2.5",
"together",
HF_SECRET,
)

# Register serverless SUTs.
SUTS.register(
HuggingFaceChatCompletionServerlessSUT,
Expand Down
8 changes: 0 additions & 8 deletions src/modelgauge/suts/together_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens
from modelgauge.prompt import ChatPrompt, ChatRole, TextPrompt
from modelgauge.prompt_formatting import format_chat
from modelgauge.reasoning_handlers import ThinkingMixin
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTResponse
from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities
Expand Down Expand Up @@ -271,11 +270,6 @@ def translate_response(self, request: TogetherChatRequest, response: TogetherCha
return SUTResponse(text=text, top_logprobs=logprobs)


@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt])
class TogetherThinkingSUT(ThinkingMixin, TogetherChatSUT):
"""SUT that preforms reasoning like deepseek-r1"""


@modelgauge_sut(
capabilities=[
AcceptsTextPrompt,
Expand Down Expand Up @@ -382,5 +376,3 @@ def evaluate(self, request: TogetherChatRequest) -> TogetherChatResponse:
}
for uid, model_name in DEDICATED_CHAT_MODELS.items():
SUTS.register(TogetherDedicatedChatSUT, uid, model_name, InjectSecret(TogetherApiKey))

SUTS.register(TogetherThinkingSUT, "deepseek-R1-thinking", "deepseek-ai/DeepSeek-R1", InjectSecret(TogetherApiKey))
11 changes: 6 additions & 5 deletions tests/modelbench_tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def runner(self, run_dir):
def invoke(command, args=None, **kwargs):
args = list(args or [])
full_args = ["--run-path", run_dir] + args
print(command, full_args, kwargs)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean to leave this in?

return runner.invoke(command, full_args, **kwargs)

return invoke
Expand All @@ -316,7 +317,7 @@ def invoke(command, args=None, **kwargs):
],
# TODO add more locales as we add support for them
)
@pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"])
@pytest.mark.parametrize("sut_uid", ["fake-sut"])
def test_benchmark_basic_run_produces_json(
self,
monkeypatch,
Expand Down Expand Up @@ -396,7 +397,7 @@ def test_benchmark_basic_run_produces_json(
],
# TODO add more locales as we add support for them
)
@pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay;mt=500;t=0.3"])
@pytest.mark.parametrize("sut_uid", ["fake-sut"])
def test_benchmark_multiple_suts_produces_json(
self, mock_run_benchmarks, runner, version, locale, prompt_set, sut_uid, run_dir, monkeypatch
):
Expand Down Expand Up @@ -546,7 +547,7 @@ def test_calls_score_benchmark_with_correct_v1_locale(self, runner, mock_run_ben
#
# benchmark_arg = mock_score_benchmarks.call_args.args[0][0]
# assert isinstance(benchmark_arg, GeneralPurposeAiChatBenchmark)
@pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"])
@pytest.mark.parametrize("sut_uid", ["fake-sut"])
def test_v1_en_us_demo_is_default(self, runner, mock_run_benchmarks, sut_uid):
_ = runner(cli, ["benchmark", "general", "--sut", sut_uid])

Expand All @@ -555,14 +556,14 @@ def test_v1_en_us_demo_is_default(self, runner, mock_run_benchmarks, sut_uid):
assert benchmark_arg.locale == EN_US
assert benchmark_arg.prompt_set == "demo"

@pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"])
@pytest.mark.parametrize("sut_uid", ["fake-sut"])
def test_nonexistent_benchmark_prompt_sets_can_not_be_called(self, runner, sut_uid):
result = runner(cli, ["benchmark", "general", "--prompt-set", "fake", "--sut", sut_uid])
assert result.exit_code == 2
assert "Invalid value for '--prompt-set'" in result.output

@pytest.mark.parametrize("prompt_set", GENERAL_PROMPT_SETS.keys())
@pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"])
@pytest.mark.parametrize("sut_uid", ["fake-sut"])
def test_calls_score_benchmark_with_correct_prompt_set(self, runner, mock_run_benchmarks, prompt_set, sut_uid):
_ = runner(cli, ["benchmark", "general", "--prompt-set", prompt_set, "--sut", sut_uid])

Expand Down
Loading
Loading