diff --git a/src/modelgauge/general.py b/src/modelgauge/general.py
index 5400d28b..866196e1 100644
--- a/src/modelgauge/general.py
+++ b/src/modelgauge/general.py
@@ -5,13 +5,13 @@
import shlex
import subprocess
import time
-from typing import List, Optional, Set, Type, TypeVar
+from typing import List, Optional, TypeVar
from airrlogger.log_config import get_logger
from tqdm import tqdm
# Type vars helpful in defining templates.
-_InT = TypeVar("_InT")
+_InT = TypeVar("_InT", bound=type)
logger = get_logger(__name__)
@@ -20,8 +20,8 @@ def current_timestamp_millis() -> int:
return time.time_ns() // 1_000_000
-def get_concrete_subclasses(cls: Type[_InT]) -> Set[Type[_InT]]:
- result = set()
+def get_concrete_subclasses(cls: _InT) -> set[_InT]:
+ result: set[_InT] = set()
for subclass in cls.__subclasses__():
if not inspect.isabstract(subclass):
result.add(subclass)
diff --git a/src/modelgauge/reasoning_handlers.py b/src/modelgauge/reasoning_handlers.py
index 95ec0e41..3d2b8ee4 100644
--- a/src/modelgauge/reasoning_handlers.py
+++ b/src/modelgauge/reasoning_handlers.py
@@ -1,8 +1,10 @@
+from abc import ABC, abstractmethod
from typing import Any
from airrlogger.log_config import get_logger
from pydantic import BaseModel
+from modelgauge.general import get_concrete_subclasses
from modelgauge.model_options import ModelOptions
from modelgauge.prompt import TextPrompt
from modelgauge.sut import PromptResponseSUT, SUTResponse
@@ -17,21 +19,61 @@ class ReasoningRequest(BaseModel):
max_total_tokens: int | None = None # Total number of tokens allowed (thinking + content).
-class ThinkingMixin(PromptResponseSUT):
+class ReasoningHandler(ABC):
+ @staticmethod
+ def _get_concrete_reasoning_suts() -> set[type["ReasoningHandler"]]:
+ return get_concrete_subclasses(ReasoningHandler)
+
+ @staticmethod
+ def find_match(sut: PromptResponseSUT) -> type["ReasoningHandler"] | None:
+ reasoning_suts = ReasoningHandler._get_concrete_reasoning_suts()
+ for rs in reasoning_suts:
+ if rs.sut_matches(sut):
+ return rs
+ return None
+
+ @classmethod
+ def sut_matches(cls, sut) -> bool:
+ """Finds a matching reasoning handler for the given SUT. Calling this method will result in 1 SUT call."""
+ request = sut.translate_text_prompt(
+ TextPrompt(text="If I have 2 apples and give 1 to my friend, how many apples do I have left?"),
+ options=ModelOptions(max_tokens=1000),
+ )
+ raw_response = sut.evaluate(request)
+ response = sut.translate_response(request, raw_response)
+ return cls.response_contains_reasoning(response)
+
+ @classmethod
+ @abstractmethod
+ def response_contains_reasoning(cls, response: SUTResponse) -> bool:
+ pass
+
+
+class ThinkingMixin(ReasoningHandler):
"""
A mixin for SUTs that parses out thinking text from the output.
The output is expected to be in the form: {reasoning text}{content text}.
If max_total_output_tokens is set in ModelOptions, that value will be used in the model call and the content text will be truncated to max_tokens.
Otherwise, max_tokens is used in the model call and everything after is returned as content.
+
+ Reasoning should be enabled by the model by default. This mixin does not request reasoning be enabled (yet).
"""
+ OPEN_TAG = "" # Optional.
+ CLOSE_TAG = "" # Tag that separates reasoning from content.
+
def __init__(self, uid, *args, **kwargs):
super().__init__(uid, *args, **kwargs)
self.tokenizer = GeneralTokenizer()
- self.separator = "" # Tag that separates reasoning from content.
- def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> ReasoningRequest:
+ @classmethod
+ def response_contains_reasoning(cls, response: SUTResponse) -> bool:
+        return any(tag and tag in response.text for tag in (cls.OPEN_TAG, cls.CLOSE_TAG))
+
+ def translate_text_prompt(
+ self, sut: PromptResponseSUT, prompt: TextPrompt, options: ModelOptions
+ ) -> ReasoningRequest:
max_total_tokens = options.max_total_output_tokens
if max_total_tokens is None:
max_total_tokens = options.max_tokens
@@ -39,32 +81,34 @@ def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> Re
# Replace max_tokens in raw request with the max total tokens.
options.max_tokens = max_total_tokens
- request = super().translate_text_prompt(prompt, options)
+        request = sut._translate_text_prompt(prompt, options)
return ReasoningRequest(
request=request,
max_content_tokens=max_content_tokens,
max_total_tokens=max_total_tokens,
)
- def evaluate(self, request: ReasoningRequest) -> Any:
- return super().evaluate(request.request) # type: ignore
+ def evaluate(self, sut: PromptResponseSUT, request: ReasoningRequest) -> Any:
+ return sut._evaluate(request.request) # type: ignore
- def translate_response(self, request: ReasoningRequest, response: Any) -> SUTResponse:
- text = super().translate_response(request.request, response).text # type: ignore
+ def translate_response(self, sut: PromptResponseSUT, request: ReasoningRequest, response: Any) -> SUTResponse:
+ text = sut._translate_response(request.request, response).text # type: ignore
- think_close = text.find(self.separator)
+ think_close = text.rfind(self.CLOSE_TAG)
if think_close == -1:
# no closing tag: everything is thinking text
- return SUTResponse(text="")
+ return SUTResponse(text="", reasoning=self.trim_tokens(text))
- reasoning = text[: think_close + len(self.separator)].strip()
- content = text[think_close + len(self.separator) :].strip()
+ reasoning = text[: think_close + len(self.CLOSE_TAG)].strip()
+ content = text[think_close + len(self.CLOSE_TAG) :].strip()
self.warn_edge_cases(content, reasoning, request)
+ reasoning = self.trim_tokens(reasoning)
+
# Truncate content
if request.max_content_tokens is not None:
content = self.tokenizer.truncate(content, request.max_content_tokens)
- return SUTResponse(text=content)
+ return SUTResponse(text=content, reasoning=reasoning)
def warn_edge_cases(self, content, reasoning, request):
if request.max_total_tokens is None:
@@ -77,3 +121,10 @@ def warn_edge_cases(self, content, reasoning, request):
logger.warning(
f"SUT {self.uid} reasoning likely ate into the token budget of the actual output. Consider increasing max_total_output_tokens."
)
+
+ def trim_tokens(self, text: str) -> str:
+ if text.startswith(self.OPEN_TAG):
+ text = text[len(self.OPEN_TAG) :]
+ if text.endswith(self.CLOSE_TAG):
+ text = text[: -len(self.CLOSE_TAG)]
+ return text
diff --git a/src/modelgauge/sut.py b/src/modelgauge/sut.py
index 06865a42..750c0fc0 100644
--- a/src/modelgauge/sut.py
+++ b/src/modelgauge/sut.py
@@ -18,6 +18,7 @@ class SUTResponse(BaseModel):
"""The data that came out of the SUT."""
text: str
+ reasoning: Optional[str] = None
top_logprobs: Optional[Sequence[TopTokens]] = None
"""For each position, list the probabilities for each of the most likely tokens.
@@ -58,14 +59,23 @@ class PromptResponseSUT(SUT, Readyable):
Abstract base class that provides an interface to any SUT that is designed for handling a single-turn.
"""
+ def __init__(self, uid: str):
+ super().__init__(uid)
+        self.reasoning_handler: Optional[Type[ReasoningHandler]] = ReasoningHandler.find_match(self)
+
def run_readiness_check(self) -> ReadyResponse:
raw_request = self.translate_text_prompt(_READINESS_CHECK_TEXT_PROMPT, options=_READINESS_CHECK_SUT_OPTIONS)
raw_response = self.evaluate(raw_request)
response = self.translate_response(raw_request, raw_response)
return ReadyResponse(is_ready=response.text is not None, response=response)
- @not_implemented
def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions):
+ if self.reasoning_handler is not None:
+ return self.reasoning_handler.translate_text_prompt(self, prompt, options)
+ return self._translate_text_prompt(prompt, options)
+
+ @not_implemented
+ def _translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions):
"""Convert the prompt + SUT options into the SUT's native representation.
This method must be implemented if the SUT accepts text prompts.
@@ -80,12 +90,24 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions):
"""
raise NotImplementedError(f"SUT {self.__class__.__name__} does not implement translate_chat_prompt.")
- @abstractmethod
def evaluate(self, request):
"""Evaluate this SUT on the native request."""
- pass
+ if self.reasoning_handler is not None:
+ return self.reasoning_handler.evaluate(self, request)
+ return self._evaluate(request)
@abstractmethod
+ def _evaluate(self, request):
+ """Evaluate this SUT on the native request."""
+ pass
+
def translate_response(self, request, response) -> SUTResponse:
+ """Convert the native response into a form all Tests can process."""
+ if self.reasoning_handler is not None:
+            return self.reasoning_handler.translate_response(self, request, response)
+ return self._translate_response(request, response)
+
+ @abstractmethod
+ def _translate_response(self, request, response) -> SUTResponse:
"""Convert the native response into a form all Tests can process."""
pass
diff --git a/src/modelgauge/suts/huggingface_chat_completion.py b/src/modelgauge/suts/huggingface_chat_completion.py
index 2f3eb0f5..08037414 100644
--- a/src/modelgauge/suts/huggingface_chat_completion.py
+++ b/src/modelgauge/suts/huggingface_chat_completion.py
@@ -11,7 +11,6 @@
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens
from modelgauge.prompt import TextPrompt, ChatPrompt
-from modelgauge.reasoning_handlers import ThinkingMixin
from modelgauge.retry_decorator import retry
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTResponse
@@ -186,14 +185,6 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> Hu
)
-@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt])
-class HuggingFaceChatCompletionDedicatedThinkingSUT(ThinkingMixin, HuggingFaceChatCompletionDedicatedSUT):
- """
- A SUT that excludes the reasoning from model output.
- Reasoning must be seperated from normal output with a tag (like nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16)
- """
-
-
@modelgauge_sut(capabilities=[AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities])
class HuggingFaceChatCompletionServerlessSUT(BaseHuggingFaceChatCompletionSUT):
"""A SUT hosted by an inference provider on huggingface."""
@@ -231,14 +222,6 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> Hu
)
-@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt])
-class HuggingFaceChatCompletionServerlessThinkingSUT(ThinkingMixin, HuggingFaceChatCompletionServerlessSUT):
- """
- A SUT that excludes the reasoning from model output.
- Reasoning must be seperated from normal output with a tag (like nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16)
- """
-
-
HF_SECRET = InjectSecret(HuggingFaceInferenceToken)
SUTS.register(
@@ -276,29 +259,7 @@ class HuggingFaceChatCompletionServerlessThinkingSUT(ThinkingMixin, HuggingFaceC
None,
HF_SECRET,
)
-# Special thinking dedicated SUTs
-SUTS.register(
- HuggingFaceChatCompletionDedicatedThinkingSUT,
- "nvidia-nemotron-3-nano-30b-a-thinking-excluded-hf",
- "nvidia-nemotron-3-nano-30b-a-mia",
- "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
- HF_SECRET,
-)
-SUTS.register(
- HuggingFaceChatCompletionDedicatedThinkingSUT,
- "PrimeIntellect-INTELLECT-3-thinking-excluded-hf",
- "intellect-3-uqs",
- "PrimeIntellect/INTELLECT-3",
- HF_SECRET,
-)
-# Special thinking serverless SUTs
-SUTS.register(
- HuggingFaceChatCompletionServerlessThinkingSUT,
- "moonshotai/Kimi-K2.5-together-thinking-excluded-hf",
- "moonshotai/Kimi-K2.5",
- "together",
- HF_SECRET,
-)
+
# Register serverless SUTs.
SUTS.register(
HuggingFaceChatCompletionServerlessSUT,
diff --git a/src/modelgauge/suts/together_client.py b/src/modelgauge/suts/together_client.py
index c1103f3e..01eee3af 100644
--- a/src/modelgauge/suts/together_client.py
+++ b/src/modelgauge/suts/together_client.py
@@ -11,7 +11,6 @@
from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens
from modelgauge.prompt import ChatPrompt, ChatRole, TextPrompt
from modelgauge.prompt_formatting import format_chat
-from modelgauge.reasoning_handlers import ThinkingMixin
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTResponse
from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities
@@ -271,11 +270,6 @@ def translate_response(self, request: TogetherChatRequest, response: TogetherCha
return SUTResponse(text=text, top_logprobs=logprobs)
-@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt])
-class TogetherThinkingSUT(ThinkingMixin, TogetherChatSUT):
- """SUT that preforms reasoning like deepseek-r1"""
-
-
@modelgauge_sut(
capabilities=[
AcceptsTextPrompt,
@@ -382,5 +376,3 @@ def evaluate(self, request: TogetherChatRequest) -> TogetherChatResponse:
}
for uid, model_name in DEDICATED_CHAT_MODELS.items():
SUTS.register(TogetherDedicatedChatSUT, uid, model_name, InjectSecret(TogetherApiKey))
-
-SUTS.register(TogetherThinkingSUT, "deepseek-R1-thinking", "deepseek-ai/DeepSeek-R1", InjectSecret(TogetherApiKey))
diff --git a/tests/modelbench_tests/test_run.py b/tests/modelbench_tests/test_run.py
index b842d55f..4901cb42 100644
--- a/tests/modelbench_tests/test_run.py
+++ b/tests/modelbench_tests/test_run.py
@@ -301,6 +301,7 @@ def runner(self, run_dir):
def invoke(command, args=None, **kwargs):
args = list(args or [])
full_args = ["--run-path", run_dir] + args
+ print(command, full_args, kwargs)
return runner.invoke(command, full_args, **kwargs)
return invoke
@@ -316,7 +317,7 @@ def invoke(command, args=None, **kwargs):
],
# TODO add more locales as we add support for them
)
- @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"])
+ @pytest.mark.parametrize("sut_uid", ["fake-sut"])
def test_benchmark_basic_run_produces_json(
self,
monkeypatch,
@@ -396,7 +397,7 @@ def test_benchmark_basic_run_produces_json(
],
# TODO add more locales as we add support for them
)
- @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay;mt=500;t=0.3"])
+ @pytest.mark.parametrize("sut_uid", ["fake-sut"])
def test_benchmark_multiple_suts_produces_json(
self, mock_run_benchmarks, runner, version, locale, prompt_set, sut_uid, run_dir, monkeypatch
):
@@ -546,7 +547,7 @@ def test_calls_score_benchmark_with_correct_v1_locale(self, runner, mock_run_ben
#
# benchmark_arg = mock_score_benchmarks.call_args.args[0][0]
# assert isinstance(benchmark_arg, GeneralPurposeAiChatBenchmark)
- @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"])
+ @pytest.mark.parametrize("sut_uid", ["fake-sut"])
def test_v1_en_us_demo_is_default(self, runner, mock_run_benchmarks, sut_uid):
_ = runner(cli, ["benchmark", "general", "--sut", sut_uid])
@@ -555,14 +556,14 @@ def test_v1_en_us_demo_is_default(self, runner, mock_run_benchmarks, sut_uid):
assert benchmark_arg.locale == EN_US
assert benchmark_arg.prompt_set == "demo"
- @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"])
+ @pytest.mark.parametrize("sut_uid", ["fake-sut"])
def test_nonexistent_benchmark_prompt_sets_can_not_be_called(self, runner, sut_uid):
result = runner(cli, ["benchmark", "general", "--prompt-set", "fake", "--sut", sut_uid])
assert result.exit_code == 2
assert "Invalid value for '--prompt-set'" in result.output
@pytest.mark.parametrize("prompt_set", GENERAL_PROMPT_SETS.keys())
- @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"])
+ @pytest.mark.parametrize("sut_uid", ["fake-sut"])
def test_calls_score_benchmark_with_correct_prompt_set(self, runner, mock_run_benchmarks, prompt_set, sut_uid):
_ = runner(cli, ["benchmark", "general", "--prompt-set", prompt_set, "--sut", sut_uid])
diff --git a/tests/modelgauge_tests/test_reasoning_handlers.py b/tests/modelgauge_tests/test_reasoning_handlers.py
index e2d21996..95bb42ea 100644
--- a/tests/modelgauge_tests/test_reasoning_handlers.py
+++ b/tests/modelgauge_tests/test_reasoning_handlers.py
@@ -1,10 +1,11 @@
import pytest
+from unittest.mock import patch
from pydantic import BaseModel
from modelgauge.model_options import ModelOptions
from modelgauge.prompt import TextPrompt
-from modelgauge.reasoning_handlers import ReasoningRequest, ThinkingMixin
+from modelgauge.reasoning_handlers import ReasoningRequest, ReasoningHandler, ThinkingMixin
from modelgauge.sut import SUTResponse, PromptResponseSUT
from modelgauge.sut_capabilities import AcceptsTextPrompt
@@ -34,14 +35,67 @@ def translate_response(self, request: FakeSUTRequest, response: FakeSUTResponse)
return SUTResponse(text=response.text)
+class TestReasoningHandler:
+
+    class CountMixin(ReasoningHandler, FakeBaseSUT):
+ # Inherit from FakeBaseSUT so that this is a concrete class.
+ @classmethod
+ def response_contains_reasoning(cls, response: SUTResponse) -> bool:
+ return "123" in response.text
+
+ @pytest.fixture(autouse=True)
+ def _patch_reasoning_suts(self):
+ # Only consider the CountMixin for matching.
+ with patch.object(
+            ReasoningHandler,
+ "_get_concrete_reasoning_suts",
+ return_value={self.CountMixin},
+ ):
+ yield
+
+ def test_find_thinking_mixin(self):
+ class CountSUT(FakeBaseSUT):
+ def evaluate(self, request: FakeSUTRequest) -> FakeSUTResponse:
+ return FakeSUTResponse(text="123")
+
+ sut = CountSUT("sut")
+        reasoning_cls = ReasoningHandler.find_match(sut)
+ assert reasoning_cls == self.CountMixin
+
+ def test_find_no_match(self):
+ class NoReasoningSUT(FakeBaseSUT):
+ def evaluate(self, request: FakeSUTRequest) -> FakeSUTResponse:
+ return FakeSUTResponse(text="text only")
+
+ sut = NoReasoningSUT("sut")
+        reasoning_cls = ReasoningHandler.find_match(sut)
+ assert reasoning_cls is None
+
+
class TestThinkMixin:
+ @modelgauge_sut(capabilities=[AcceptsTextPrompt])
+ class ThinkSut(ThinkingMixin, FakeBaseSUT):
+ pass
+
@pytest.fixture
def sut(self):
- @modelgauge_sut(capabilities=[AcceptsTextPrompt])
- class ThinkSut(ThinkingMixin, FakeBaseSUT):
- pass
+ return self.ThinkSut("sut-uid")
+
+ def test_response_contains_reasoning(self):
+ response = SUTResponse(text="reasoningoutput")
+ assert self.ThinkSut.response_contains_reasoning(response) is True
+
+ response = SUTResponse(text="reasoningoutput")
+ assert self.ThinkSut.response_contains_reasoning(response) is True
+
+ response = SUTResponse(text=" only thinking")
+ assert self.ThinkSut.response_contains_reasoning(response) is True
- return ThinkSut("sut-uid")
+ response = SUTResponse(text="content")
+ assert self.ThinkSut.response_contains_reasoning(response) is False
+
+ response = SUTResponse(text="")
+ assert self.ThinkSut.response_contains_reasoning(response) is False
def test_translate_text_prompt_sets_max_tokens(self, sut):
prompt = TextPrompt(text="some-text")
@@ -67,10 +121,21 @@ def test_translate_text_prompt_sets_max_tokens(self, sut):
assert request.max_content_tokens == None
@pytest.mark.parametrize(
- "full_text, content_text",
- [("hmm\n Output", "Output"), ("hmm\n Output", "Output"), ("hmmm", "")],
+ "full_text, content_text, reason_text",
+ [
+ ("hmm\n Output", "Output", "hmm"),
+ (
+ "hmm nested think> \n Output",
+ "Output",
+ "hmm nested think> ",
+ ),
+ ("hmmmore think Output", "Output", "hmmmore think"),
+ ("hmm\n Output", "Output", "hmm"),
+ ("hmmm", "", "hmmm"),
+ ("", "", ""),
+ ],
)
- def test_translate_response_no_truncation(self, full_text, content_text, sut):
+ def test_translate_response_no_truncation(self, full_text, content_text, reason_text, sut):
request = ReasoningRequest(
request=FakeSUTRequest(text="", max_tokens=100), max_content_tokens=100, max_total_tokens=100
)
@@ -84,6 +149,7 @@ def test_translate_response_no_truncation(self, full_text, content_text, sut):
result = sut.translate_response(request, response)
assert result.text == content_text
+ assert result.reasoning == reason_text
@pytest.mark.parametrize(
"full_text, content_text",
diff --git a/tests/modelgauge_tests/test_records.py b/tests/modelgauge_tests/test_records.py
index 05769675..96136636 100644
--- a/tests/modelgauge_tests/test_records.py
+++ b/tests/modelgauge_tests/test_records.py
@@ -136,6 +136,7 @@ def test_serialize_test_record():
},
"sut_response": {
"text": "sut-completion",
+ "reasoning": null,
"top_logprobs": null
},
"annotations": {