diff --git a/src/modelgauge/general.py b/src/modelgauge/general.py index 5400d28b..866196e1 100644 --- a/src/modelgauge/general.py +++ b/src/modelgauge/general.py @@ -5,13 +5,13 @@ import shlex import subprocess import time -from typing import List, Optional, Set, Type, TypeVar +from typing import List, Optional, TypeVar from airrlogger.log_config import get_logger from tqdm import tqdm # Type vars helpful in defining templates. -_InT = TypeVar("_InT") +_InT = TypeVar("_InT", bound=type) logger = get_logger(__name__) @@ -20,8 +20,8 @@ def current_timestamp_millis() -> int: return time.time_ns() // 1_000_000 -def get_concrete_subclasses(cls: Type[_InT]) -> Set[Type[_InT]]: - result = set() +def get_concrete_subclasses(cls: _InT) -> set[_InT]: + result: set[_InT] = set() for subclass in cls.__subclasses__(): if not inspect.isabstract(subclass): result.add(subclass) diff --git a/src/modelgauge/reasoning_handlers.py b/src/modelgauge/reasoning_handlers.py index 95ec0e41..3d2b8ee4 100644 --- a/src/modelgauge/reasoning_handlers.py +++ b/src/modelgauge/reasoning_handlers.py @@ -1,8 +1,10 @@ +from abc import ABC, abstractmethod from typing import Any from airrlogger.log_config import get_logger from pydantic import BaseModel +from modelgauge.general import get_concrete_subclasses from modelgauge.model_options import ModelOptions from modelgauge.prompt import TextPrompt from modelgauge.sut import PromptResponseSUT, SUTResponse @@ -17,21 +19,61 @@ class ReasoningRequest(BaseModel): max_total_tokens: int | None = None # Total number of tokens allowed (thinking + content). 
-class ThinkingMixin(PromptResponseSUT):
+class ReasoningHandler(ABC):
+    @staticmethod
+    def _get_concrete_reasoning_suts() -> set[type["ReasoningHandler"]]:
+        return get_concrete_subclasses(ReasoningHandler)
+
+    @staticmethod
+    def find_match(sut: PromptResponseSUT) -> type["ReasoningHandler"] | None:
+        reasoning_suts = ReasoningHandler._get_concrete_reasoning_suts()
+        for rs in reasoning_suts:
+            if rs.sut_matches(sut):
+                return rs
+        return None
+
+    @classmethod
+    def sut_matches(cls, sut) -> bool:
+        """Finds a matching reasoning handler for the given SUT. Calling this method will result in 1 SUT call."""
+        request = sut._translate_text_prompt(
+            TextPrompt(text="If I have 2 apples and give 1 to my friend, how many apples do I have left?"),
+            options=ModelOptions(max_tokens=1000),
+        )
+        raw_response = sut._evaluate(request)
+        response = sut._translate_response(request, raw_response)
+        return cls.response_contains_reasoning(response)
+
+    @classmethod
+    @abstractmethod
+    def response_contains_reasoning(cls, response: SUTResponse) -> bool:
+        pass
+
+
+class ThinkingMixin(ReasoningHandler):
     """
     A mixin for SUTs that parses out thinking text from the output.
     The output is expected to be in the form: <think>{reasoning text}</think>{content text}.
     If max_total_output_tokens is set in ModelOptions, that value will be used in the model call and the content
     text will be truncated to max_tokens. Otherwise, max_tokens is used in the model call and everything after
     is returned as content.
+
+    Reasoning should be enabled by the model by default. This mixin does not request reasoning be enabled (yet).
     """
+
+    OPEN_TAG = "<think>"  # Optional.
+    CLOSE_TAG = "</think>"  # Tag that separates reasoning from content.
+
     def __init__(self, uid, *args, **kwargs):
         super().__init__(uid, *args, **kwargs)
         self.tokenizer = GeneralTokenizer()
-        self.separator = "</think>"  # Tag that separates reasoning from content.
-    def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> ReasoningRequest:
+    @classmethod
+    def response_contains_reasoning(cls, response: SUTResponse) -> bool:
+        return cls.OPEN_TAG in response.text or cls.CLOSE_TAG in response.text
+
+    def translate_text_prompt(
+        self, sut: PromptResponseSUT, prompt: TextPrompt, options: ModelOptions
+    ) -> ReasoningRequest:
         max_total_tokens = options.max_total_output_tokens
         if max_total_tokens is None:
             max_total_tokens = options.max_tokens
@@ -39,32 +81,34 @@ def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> Re
         # Replace max_tokens in raw request with the max total tokens.
         options.max_tokens = max_total_tokens

-        request = super().translate_text_prompt(prompt, options)
+        request = sut._translate_text_prompt(prompt, options)
         return ReasoningRequest(
             request=request,
             max_content_tokens=max_content_tokens,
             max_total_tokens=max_total_tokens,
         )

-    def evaluate(self, request: ReasoningRequest) -> Any:
-        return super().evaluate(request.request)  # type: ignore
+    def evaluate(self, sut: PromptResponseSUT, request: ReasoningRequest) -> Any:
+        return sut._evaluate(request.request)  # type: ignore

-    def translate_response(self, request: ReasoningRequest, response: Any) -> SUTResponse:
-        text = super().translate_response(request.request, response).text  # type: ignore
+    def translate_response(self, sut: PromptResponseSUT, request: ReasoningRequest, response: Any) -> SUTResponse:
+        text = sut._translate_response(request.request, response).text  # type: ignore

-        think_close = text.find(self.separator)
+        think_close = text.rfind(self.CLOSE_TAG)
         if think_close == -1:
             # no closing tag: everything is thinking text
-            return SUTResponse(text="")
+            return SUTResponse(text="", reasoning=self.trim_tokens(text))

-        reasoning = text[: think_close + len(self.separator)].strip()
-        content = text[think_close + len(self.separator) :].strip()
+        reasoning = text[: think_close + len(self.CLOSE_TAG)].strip()
+        content
= text[think_close + len(self.CLOSE_TAG) :].strip() self.warn_edge_cases(content, reasoning, request) + reasoning = self.trim_tokens(reasoning) + # Truncate content if request.max_content_tokens is not None: content = self.tokenizer.truncate(content, request.max_content_tokens) - return SUTResponse(text=content) + return SUTResponse(text=content, reasoning=reasoning) def warn_edge_cases(self, content, reasoning, request): if request.max_total_tokens is None: @@ -77,3 +121,10 @@ def warn_edge_cases(self, content, reasoning, request): logger.warning( f"SUT {self.uid} reasoning likely ate into the token budget of the actual output. Consider increasing max_total_output_tokens." ) + + def trim_tokens(self, text: str) -> str: + if text.startswith(self.OPEN_TAG): + text = text[len(self.OPEN_TAG) :] + if text.endswith(self.CLOSE_TAG): + text = text[: -len(self.CLOSE_TAG)] + return text diff --git a/src/modelgauge/sut.py b/src/modelgauge/sut.py index 06865a42..750c0fc0 100644 --- a/src/modelgauge/sut.py +++ b/src/modelgauge/sut.py @@ -18,6 +18,7 @@ class SUTResponse(BaseModel): """The data that came out of the SUT.""" text: str + reasoning: Optional[str] = None top_logprobs: Optional[Sequence[TopTokens]] = None """For each position, list the probabilities for each of the most likely tokens. @@ -58,14 +59,23 @@ class PromptResponseSUT(SUT, Readyable): Abstract base class that provides an interface to any SUT that is designed for handling a single-turn. 
     """

+    def __init__(self, uid: str):
+        super().__init__(uid)
+        self.reasoning_handler: Optional[Type[ReasoningHandler]] = ReasoningHandler.find_match(self)
+
     def run_readiness_check(self) -> ReadyResponse:
         raw_request = self.translate_text_prompt(_READINESS_CHECK_TEXT_PROMPT, options=_READINESS_CHECK_SUT_OPTIONS)
         raw_response = self.evaluate(raw_request)
         response = self.translate_response(raw_request, raw_response)
         return ReadyResponse(is_ready=response.text is not None, response=response)

-    @not_implemented
     def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions):
+        if self.reasoning_handler is not None:
+            return self.reasoning_handler.translate_text_prompt(self, prompt, options)
+        return self._translate_text_prompt(prompt, options)
+
+    @not_implemented
+    def _translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions):
         """Convert the prompt + SUT options into the SUT's native representation.

         This method must be implemented if the SUT accepts text prompts.
@@ -80,12 +90,24 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions):
         """
         raise NotImplementedError(f"SUT {self.__class__.__name__} does not implement translate_chat_prompt.")

-    @abstractmethod
     def evaluate(self, request):
         """Evaluate this SUT on the native request."""
-        pass
+        if self.reasoning_handler is not None:
+            return self.reasoning_handler.evaluate(self, request)
+        return self._evaluate(request)

     @abstractmethod
+    def _evaluate(self, request):
+        """Evaluate this SUT on the native request."""
+        pass
+
     def translate_response(self, request, response) -> SUTResponse:
+        """Convert the native response into a form all Tests can process."""
+        if self.reasoning_handler is not None:
+            return self.reasoning_handler.translate_response(self, request, response)
+        return self._translate_response(request, response)
+
+    @abstractmethod
+    def _translate_response(self, request, response) -> SUTResponse:
         """Convert the native response into a form all Tests can
process.""" pass diff --git a/src/modelgauge/suts/huggingface_chat_completion.py b/src/modelgauge/suts/huggingface_chat_completion.py index 2f3eb0f5..08037414 100644 --- a/src/modelgauge/suts/huggingface_chat_completion.py +++ b/src/modelgauge/suts/huggingface_chat_completion.py @@ -11,7 +11,6 @@ from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens from modelgauge.prompt import TextPrompt, ChatPrompt -from modelgauge.reasoning_handlers import ThinkingMixin from modelgauge.retry_decorator import retry from modelgauge.secret_values import InjectSecret from modelgauge.sut import PromptResponseSUT, SUTResponse @@ -186,14 +185,6 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> Hu ) -@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt]) -class HuggingFaceChatCompletionDedicatedThinkingSUT(ThinkingMixin, HuggingFaceChatCompletionDedicatedSUT): - """ - A SUT that excludes the reasoning from model output. - Reasoning must be seperated from normal output with a tag (like nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) - """ - - @modelgauge_sut(capabilities=[AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities]) class HuggingFaceChatCompletionServerlessSUT(BaseHuggingFaceChatCompletionSUT): """A SUT hosted by an inference provider on huggingface.""" @@ -231,14 +222,6 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: ModelOptions) -> Hu ) -@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt]) -class HuggingFaceChatCompletionServerlessThinkingSUT(ThinkingMixin, HuggingFaceChatCompletionServerlessSUT): - """ - A SUT that excludes the reasoning from model output. 
- Reasoning must be seperated from normal output with a tag (like nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) - """ - - HF_SECRET = InjectSecret(HuggingFaceInferenceToken) SUTS.register( @@ -276,29 +259,7 @@ class HuggingFaceChatCompletionServerlessThinkingSUT(ThinkingMixin, HuggingFaceC None, HF_SECRET, ) -# Special thinking dedicated SUTs -SUTS.register( - HuggingFaceChatCompletionDedicatedThinkingSUT, - "nvidia-nemotron-3-nano-30b-a-thinking-excluded-hf", - "nvidia-nemotron-3-nano-30b-a-mia", - "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", - HF_SECRET, -) -SUTS.register( - HuggingFaceChatCompletionDedicatedThinkingSUT, - "PrimeIntellect-INTELLECT-3-thinking-excluded-hf", - "intellect-3-uqs", - "PrimeIntellect/INTELLECT-3", - HF_SECRET, -) -# Special thinking serverless SUTs -SUTS.register( - HuggingFaceChatCompletionServerlessThinkingSUT, - "moonshotai/Kimi-K2.5-together-thinking-excluded-hf", - "moonshotai/Kimi-K2.5", - "together", - HF_SECRET, -) + # Register serverless SUTs. SUTS.register( HuggingFaceChatCompletionServerlessSUT, diff --git a/src/modelgauge/suts/together_client.py b/src/modelgauge/suts/together_client.py index c1103f3e..01eee3af 100644 --- a/src/modelgauge/suts/together_client.py +++ b/src/modelgauge/suts/together_client.py @@ -11,7 +11,6 @@ from modelgauge.model_options import ModelOptions, TokenProbability, TopTokens from modelgauge.prompt import ChatPrompt, ChatRole, TextPrompt from modelgauge.prompt_formatting import format_chat -from modelgauge.reasoning_handlers import ThinkingMixin from modelgauge.secret_values import InjectSecret from modelgauge.sut import PromptResponseSUT, SUTResponse from modelgauge.sut_capabilities import AcceptsChatPrompt, AcceptsTextPrompt, ProducesPerTokenLogProbabilities @@ -271,11 +270,6 @@ def translate_response(self, request: TogetherChatRequest, response: TogetherCha return SUTResponse(text=text, top_logprobs=logprobs) -@modelgauge_sut(capabilities=[AcceptsTextPrompt, AcceptsChatPrompt]) -class 
TogetherThinkingSUT(ThinkingMixin, TogetherChatSUT): - """SUT that preforms reasoning like deepseek-r1""" - - @modelgauge_sut( capabilities=[ AcceptsTextPrompt, @@ -382,5 +376,3 @@ def evaluate(self, request: TogetherChatRequest) -> TogetherChatResponse: } for uid, model_name in DEDICATED_CHAT_MODELS.items(): SUTS.register(TogetherDedicatedChatSUT, uid, model_name, InjectSecret(TogetherApiKey)) - -SUTS.register(TogetherThinkingSUT, "deepseek-R1-thinking", "deepseek-ai/DeepSeek-R1", InjectSecret(TogetherApiKey)) diff --git a/tests/modelbench_tests/test_run.py b/tests/modelbench_tests/test_run.py index b842d55f..4901cb42 100644 --- a/tests/modelbench_tests/test_run.py +++ b/tests/modelbench_tests/test_run.py @@ -301,6 +301,7 @@ def runner(self, run_dir): def invoke(command, args=None, **kwargs): args = list(args or []) full_args = ["--run-path", run_dir] + args + print(command, full_args, kwargs) return runner.invoke(command, full_args, **kwargs) return invoke @@ -316,7 +317,7 @@ def invoke(command, args=None, **kwargs): ], # TODO add more locales as we add support for them ) - @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"]) + @pytest.mark.parametrize("sut_uid", ["fake-sut"]) def test_benchmark_basic_run_produces_json( self, monkeypatch, @@ -396,7 +397,7 @@ def test_benchmark_basic_run_produces_json( ], # TODO add more locales as we add support for them ) - @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay;mt=500;t=0.3"]) + @pytest.mark.parametrize("sut_uid", ["fake-sut"]) def test_benchmark_multiple_suts_produces_json( self, mock_run_benchmarks, runner, version, locale, prompt_set, sut_uid, run_dir, monkeypatch ): @@ -546,7 +547,7 @@ def test_calls_score_benchmark_with_correct_v1_locale(self, runner, mock_run_ben # # benchmark_arg = mock_score_benchmarks.call_args.args[0][0] # assert isinstance(benchmark_arg, GeneralPurposeAiChatBenchmark) - @pytest.mark.parametrize("sut_uid", 
["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"])
+    @pytest.mark.parametrize("sut_uid", ["fake-sut"])
     def test_v1_en_us_demo_is_default(self, runner, mock_run_benchmarks, sut_uid):
         _ = runner(cli, ["benchmark", "general", "--sut", sut_uid])
@@ -555,14 +556,14 @@
         assert benchmark_arg.locale == EN_US
         assert benchmark_arg.prompt_set == "demo"

-    @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"])
+    @pytest.mark.parametrize("sut_uid", ["fake-sut"])
     def test_nonexistent_benchmark_prompt_sets_can_not_be_called(self, runner, sut_uid):
         result = runner(cli, ["benchmark", "general", "--prompt-set", "fake", "--sut", sut_uid])
         assert result.exit_code == 2
         assert "Invalid value for '--prompt-set'" in result.output

     @pytest.mark.parametrize("prompt_set", GENERAL_PROMPT_SETS.keys())
-    @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"])
+    @pytest.mark.parametrize("sut_uid", ["fake-sut"])
     def test_calls_score_benchmark_with_correct_prompt_set(self, runner, mock_run_benchmarks, prompt_set, sut_uid):
         _ = runner(cli, ["benchmark", "general", "--prompt-set", prompt_set, "--sut", sut_uid])
diff --git a/tests/modelgauge_tests/test_reasoning_handlers.py b/tests/modelgauge_tests/test_reasoning_handlers.py
index e2d21996..95bb42ea 100644
--- a/tests/modelgauge_tests/test_reasoning_handlers.py
+++ b/tests/modelgauge_tests/test_reasoning_handlers.py
@@ -1,10 +1,11 @@
 import pytest
+from unittest.mock import patch

 from pydantic import BaseModel

 from modelgauge.model_options import ModelOptions
 from modelgauge.prompt import TextPrompt
-from modelgauge.reasoning_handlers import ReasoningRequest, ThinkingMixin
+from modelgauge.reasoning_handlers import ReasoningRequest, ReasoningHandler, ThinkingMixin
 from modelgauge.sut import SUTResponse, PromptResponseSUT
 from modelgauge.sut_capabilities import AcceptsTextPrompt
@@ -34,14 +35,67 @@ def 
translate_response(self, request: FakeSUTRequest, response: FakeSUTResponse)
         return SUTResponse(text=response.text)

+class TestReasoningSUT:
+
+    class CountMixin(ReasoningHandler, FakeBaseSUT):
+        # Inherit from FakeBaseSUT so that this is a concrete class.
+        @classmethod
+        def response_contains_reasoning(cls, response: SUTResponse) -> bool:
+            return "123" in response.text
+
+    @pytest.fixture(autouse=True)
+    def _patch_reasoning_suts(self):
+        # Only consider the CountMixin for matching.
+        with patch.object(
+            ReasoningHandler,
+            "_get_concrete_reasoning_suts",
+            return_value={self.CountMixin},
+        ):
+            yield
+
+    def test_find_thinking_mixin(self):
+        class CountSUT(FakeBaseSUT):
+            def evaluate(self, request: FakeSUTRequest) -> FakeSUTResponse:
+                return FakeSUTResponse(text="123")
+
+        sut = CountSUT("sut")
+        reasoning_cls = ReasoningHandler.find_match(sut)
+        assert reasoning_cls == self.CountMixin
+
+    def test_find_no_match(self):
+        class NoReasoningSUT(FakeBaseSUT):
+            def evaluate(self, request: FakeSUTRequest) -> FakeSUTResponse:
+                return FakeSUTResponse(text="text only")
+
+        sut = NoReasoningSUT("sut")
+        reasoning_cls = ReasoningHandler.find_match(sut)
+        assert reasoning_cls is None
+
+
 class TestThinkMixin:
+    @modelgauge_sut(capabilities=[AcceptsTextPrompt])
+    class ThinkSut(ThinkingMixin, FakeBaseSUT):
+        pass
+
     @pytest.fixture
     def sut(self):
-        @modelgauge_sut(capabilities=[AcceptsTextPrompt])
-        class ThinkSut(ThinkingMixin, FakeBaseSUT):
-            pass
+        return self.ThinkSut("sut-uid")
+
+    def test_response_contains_reasoning(self):
+        response = SUTResponse(text="<think>reasoning</think>output")
+        assert self.ThinkSut.response_contains_reasoning(response) is True
+
+        response = SUTResponse(text="reasoning</think>output")
+        assert self.ThinkSut.response_contains_reasoning(response) is True
+
+        response = SUTResponse(text="<think> only thinking")
+        assert self.ThinkSut.response_contains_reasoning(response) is True

-        return ThinkSut("sut-uid")
+        response = SUTResponse(text="content")
+        assert 
self.ThinkSut.response_contains_reasoning(response) is False + + response = SUTResponse(text="") + assert self.ThinkSut.response_contains_reasoning(response) is False def test_translate_text_prompt_sets_max_tokens(self, sut): prompt = TextPrompt(text="some-text") @@ -67,10 +121,21 @@ def test_translate_text_prompt_sets_max_tokens(self, sut): assert request.max_content_tokens == None @pytest.mark.parametrize( - "full_text, content_text", - [("hmm\n Output", "Output"), ("hmm\n Output", "Output"), ("hmmm", "")], + "full_text, content_text, reason_text", + [ + ("hmm\n Output", "Output", "hmm"), + ( + "hmm nested think> \n Output", + "Output", + "hmm nested think> ", + ), + ("hmmmore think Output", "Output", "hmmmore think"), + ("hmm\n Output", "Output", "hmm"), + ("hmmm", "", "hmmm"), + ("", "", ""), + ], ) - def test_translate_response_no_truncation(self, full_text, content_text, sut): + def test_translate_response_no_truncation(self, full_text, content_text, reason_text, sut): request = ReasoningRequest( request=FakeSUTRequest(text="", max_tokens=100), max_content_tokens=100, max_total_tokens=100 ) @@ -84,6 +149,7 @@ def test_translate_response_no_truncation(self, full_text, content_text, sut): result = sut.translate_response(request, response) assert result.text == content_text + assert result.reasoning == reason_text @pytest.mark.parametrize( "full_text, content_text", diff --git a/tests/modelgauge_tests/test_records.py b/tests/modelgauge_tests/test_records.py index 05769675..96136636 100644 --- a/tests/modelgauge_tests/test_records.py +++ b/tests/modelgauge_tests/test_records.py @@ -136,6 +136,7 @@ def test_serialize_test_record(): }, "sut_response": { "text": "sut-completion", + "reasoning": null, "top_logprobs": null }, "annotations": {