diff --git a/docs/add-a-new-sut-driver.md b/docs/add-a-new-sut-driver.md index 853a0740..f7699c75 100644 --- a/docs/add-a-new-sut-driver.md +++ b/docs/add-a-new-sut-driver.md @@ -7,19 +7,8 @@ Most providers need their own driver. We provide several drivers that can be use ### Does an Existing Driver Exist? -If your SUT provider is listed as a key in the `DYNAMIC_SUT_FACTORIES` in -[sut_factory](../src/modelgauge/sut_factory.py), you don't need to write any code. - -```python -DYNAMIC_SUT_FACTORIES: dict = { - "hf": HuggingFaceSUTFactory, - "hfrelay": HuggingFaceSUTFactory, - "openai": OpenAICompatibleSUTFactory, - "together": TogetherSUTFactory, - "modelship": ModelShipSUTFactory, -} -``` -Please refer to [suts-how-to.md](./suts-how-to.md#existing) for details. +Search the existing `DynamicDriverSUTFactory` subclasses; +if one already exists for your provider, you don't need to write any code. ### Is Your SUT Already Pre-Defined? @@ -79,13 +68,11 @@ class MySUT(PromptResponseSUT): return MySUTResponse(**response_json) ``` -2. Create a factory class that creates an instance of your SUT from its UID. Look at [TogetherSUTFactory](../src/modelgauge/suts/together_sut_factory.py) for inspiration. +2. Create a factory class that creates an instance of your SUT from its UID. Look at [TogetherSUTFactory](../src/modelgauge/suts/together_sut_factory.py) for inspiration. The `DRIVER_NAME` constant must be unique to your driver. It will be a key in a dict. ```python -DRIVER_NAME = "my_sut" - class MySUTApiKey(RequiredSecret): # adjust this to your specific provider @classmethod @@ -95,7 +82,9 @@ class MySUTApiKey(RequiredSecret): key="api_key" ) -class MySUTFactory(DynamicSUTFactory): +class MySUTFactory(DynamicDriverSUTFactory): + DRIVER_NAME = "my_sut" + def __init__(self, raw_secrets: RawSecrets): # RawSecrets is a dict of secrets super().__init__(raw_secrets) @@ -112,17 +101,7 @@ class MySUTFactory(DynamicSUTFactory): ) ``` -3. 
Add an entry for your new factory class in the `DYNAMIC_SUT_FACTORIES` dict in [sut_factory](../src/modelgauge/sut_factory.py). - -```python -DYNAMIC_SUT_FACTORIES: dict = { - ... - "my_sut": MySUTFactory, - ... -} -``` - -4. Add a scope to [config/secrets.toml](../config/secrets.toml) for your provider, using the `scope` you defined in the `Secret` class(es) for your SUT: +3. Add a scope to [config/secrets.toml](../config/secrets.toml) for your provider, using the `scope` you defined in the `Secret` class(es) for your SUT: ```toml [my_host] diff --git a/docs/suts-how-to.md b/docs/suts-how-to.md index b0bfd0b4..02df7ff4 100644 --- a/docs/suts-how-to.md +++ b/docs/suts-how-to.md @@ -32,27 +32,6 @@ A lot of new SUTs will require no code if your model is hosted on one of the pro Factory classes are used to create SUT objects for you, including their driver and model name, based the elements in the SUT UID. -Available drivers are identified in `DYNAMIC_SUT_FACTORIES` in -[sut_factory](../src/modelgauge/sut_factory.py). The keys correspond to the `driver` string in the SUT UID. - -We may add more drivers from time to time. - -```python -DYNAMIC_SUT_FACTORIES: dict = { - "hf": HuggingFaceSUTFactory, - "hfrelay": HuggingFaceSUTFactory, - "openai": OpenAICompatibleSUTFactory, - "together": TogetherSUTFactory, - "modelship": ModelShipSUTFactory, -} -``` - -* "hf" is used for models hosted by Huggingface -* "hfrelay" is used for models hosted by one of Huggingface's inference provider partners (e.g. nebius, sambanova) and proxied by Huggingface ([more info](https://huggingface.co/docs/inference-providers/en/index)) -* "openai" is a model hosted by OpenAI -* "together" is a model hosted by together.ai -* "modelship" is internal to MLCommons - #### Usage For models using one of those drivers, all you need is to add your credentials to [config/secrets.toml](../config/secrets.toml) in a section named after the driver name string, e.g. 
for together.ai: diff --git a/src/modelgauge/dynamic_sut_factory.py b/src/modelgauge/dynamic_sut_factory.py index 7c300e4b..f5cabbdc 100644 --- a/src/modelgauge/dynamic_sut_factory.py +++ b/src/modelgauge/dynamic_sut_factory.py @@ -43,3 +43,15 @@ def get_secrets(self) -> list[InjectSecret]: def make_sut(self, sut_definition: SUTDefinition) -> SUT: """Factories that handle special SUT config parameters (e.g. moderated, reasoning) must accept them as kwargs.""" pass + + +class DynamicDriverSUTFactory(DynamicSUTFactory, ABC): + """These classes will be collected as driver factories for dynamic SUTs. They may call regular DynamicSUTFactories.""" + + DRIVER_NAME: str + + def __init__(self, raw_secrets: RawSecrets): + super().__init__(raw_secrets) + assert ( + hasattr(self, "DRIVER_NAME") and isinstance(self.DRIVER_NAME, str) and len(self.DRIVER_NAME) > 0 + ), "DynamicDriverSUTFactory subclasses must have a str DRIVER_NAME attribute" diff --git a/src/modelgauge/sut_factory.py b/src/modelgauge/sut_factory.py index 870d5a63..0ab09679 100644 --- a/src/modelgauge/sut_factory.py +++ b/src/modelgauge/sut_factory.py @@ -1,21 +1,13 @@ from enum import Enum from modelgauge.config import load_secrets_from_config -from modelgauge.dynamic_sut_factory import DynamicSUTFactory, UnknownSUTMakerError +from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory, UnknownSUTMakerError +from modelgauge.general import get_concrete_subclasses +from modelgauge.load_namespaces import load_namespace from modelgauge.secret_values import RawSecrets from modelgauge.sut import SUT from modelgauge.sut_definition import SUTDefinition from modelgauge.sut_registry import SUTS -from modelgauge.suts.anthropic_sut_factory import AnthropicSUTFactory -from modelgauge.suts.aws_bedrock_sut_factory import AWSBedrockSUTFactory -from modelgauge.suts.google_sut_factory import GoogleSUTFactory -from modelgauge.suts.huggingface_sut_factory import HuggingFaceSUTFactory -from modelgauge.suts.indirect_sut import 
IndirectSUTFactory -from modelgauge.suts.meta_llama_factory import LlamaSUTFactory -from modelgauge.suts.mistral_sut_factory import MistralSUTFactory -from modelgauge.suts.modelship_sut import ModelShipSUTFactory -from modelgauge.suts.openai_sut_factory import OpenAICompatibleSUTFactory -from modelgauge.suts.together_sut_factory import TogetherSUTFactory class SUTNotFoundException(Exception): @@ -32,24 +24,6 @@ class SUTType(Enum): UNKNOWN = "unknown" -# TODO: Auto-collect? -# Make sure the factory module includes the matching key as a constant. -# Maps a string to the module and factory function in that module -# that can be used to create a dynamic sut -DYNAMIC_SUT_FACTORIES: dict = { - "anthropic": AnthropicSUTFactory, - "aws": AWSBedrockSUTFactory, - "google": GoogleSUTFactory, - "hf": HuggingFaceSUTFactory, - "hfrelay": HuggingFaceSUTFactory, - "indirect": IndirectSUTFactory, - "llama": LlamaSUTFactory, - "openai": OpenAICompatibleSUTFactory, - "mistral": MistralSUTFactory, - "modelship": ModelShipSUTFactory, - "together": TogetherSUTFactory, -} - LEGACY_SUT_MODULE_MAP = { # HuggingFaceChatCompletionDedicatedSUT and HuggingFaceChatCompletionServerlessSUT "nvidia-llama-3-1-nemotron-nano-8b-v1": "huggingface_chat_completion", @@ -157,11 +131,15 @@ def __init__(self, sut_registry): self.sut_registry = sut_registry self.dynamic_sut_factories = self._load_dynamic_sut_factories(load_secrets_from_config()) - def _load_dynamic_sut_factories(self, secrets: RawSecrets) -> dict[str, DynamicSUTFactory]: - factories: dict[str, DynamicSUTFactory] = {} - for driver, factory_class in DYNAMIC_SUT_FACTORIES.items(): - factories[driver] = factory_class(secrets) - return factories + def _load_dynamic_sut_factories(self, secrets: RawSecrets) -> dict[str, DynamicDriverSUTFactory]: + load_namespace("suts") + dynamic_sut_factories = {} + for cls in get_concrete_subclasses(DynamicDriverSUTFactory): # type: ignore + if cls.DRIVER_NAME in dynamic_sut_factories: + raise 
ValueError(f"Multiple DynamicSUTFactoryDrivers have the same DRIVER_NAME '{cls.DRIVER_NAME}'.") + dynamic_sut_factories[cls.DRIVER_NAME] = cls(secrets) + + return dynamic_sut_factories def knows(self, uid: str) -> bool: """Check if the registry knows about a given SUT UID. Dynamic SUTs are always considered known.""" diff --git a/src/modelgauge/suts/anthropic_sut_factory.py b/src/modelgauge/suts/anthropic_sut_factory.py index 582b438b..948ffc5e 100644 --- a/src/modelgauge/suts/anthropic_sut_factory.py +++ b/src/modelgauge/suts/anthropic_sut_factory.py @@ -4,14 +4,16 @@ from anthropic import Anthropic -from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError +from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory, ModelNotSupportedError from modelgauge.secret_values import RawSecrets, InjectSecret from modelgauge.sut import SUT from modelgauge.sut_definition import SUTDefinition from modelgauge.suts.anthropic_api import AnthropicApiKey, AnthropicSUT -class AnthropicSUTFactory(DynamicSUTFactory): +class AnthropicSUTFactory(DynamicDriverSUTFactory): + DRIVER_NAME = "anthropic" + def get_secrets(self) -> list[InjectSecret]: api_key = InjectSecret(AnthropicApiKey) return [api_key] diff --git a/src/modelgauge/suts/aws_bedrock_sut_factory.py b/src/modelgauge/suts/aws_bedrock_sut_factory.py index 774c90d7..3cbc0491 100644 --- a/src/modelgauge/suts/aws_bedrock_sut_factory.py +++ b/src/modelgauge/suts/aws_bedrock_sut_factory.py @@ -2,14 +2,14 @@ import boto3 -from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError +from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory, ModelNotSupportedError from modelgauge.secret_values import InjectSecret, RawSecrets from modelgauge.sut import SUT from modelgauge.sut_definition import SUTDefinition from modelgauge.suts.aws_bedrock_client import AmazonBedrockSut, AwsAccessKeyId, AwsSecretAccessKey -class AWSBedrockSUTFactory(DynamicSUTFactory): +class 
AWSBedrockSUTFactory(DynamicDriverSUTFactory): DRIVER_NAME = "aws" def __init__(self, raw_secrets: RawSecrets): diff --git a/src/modelgauge/suts/google_sut_factory.py b/src/modelgauge/suts/google_sut_factory.py index 588105c4..a30b7075 100644 --- a/src/modelgauge/suts/google_sut_factory.py +++ b/src/modelgauge/suts/google_sut_factory.py @@ -2,16 +2,16 @@ from google import genai -from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError +from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory, ModelNotSupportedError from modelgauge.secret_values import RawSecrets, InjectSecret from modelgauge.sut import SUT from modelgauge.sut_definition import SUTDefinition from modelgauge.suts.google_genai import GoogleGenAiSUT, GoogleAiApiKey -DRIVER_NAME = "google" +class GoogleSUTFactory(DynamicDriverSUTFactory): + DRIVER_NAME = "google" -class GoogleSUTFactory(DynamicSUTFactory): def get_secrets(self) -> list[InjectSecret]: api_key = InjectSecret(GoogleAiApiKey) return [api_key] diff --git a/src/modelgauge/suts/huggingface_sut_factory.py b/src/modelgauge/suts/huggingface_sut_factory.py index e84f9577..d6774493 100644 --- a/src/modelgauge/suts/huggingface_sut_factory.py +++ b/src/modelgauge/suts/huggingface_sut_factory.py @@ -4,7 +4,12 @@ from airrlogger.log_config import get_logger from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken -from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError, ProviderNotFoundError +from modelgauge.dynamic_sut_factory import ( + DynamicSUTFactory, + DynamicDriverSUTFactory, + ModelNotSupportedError, + ProviderNotFoundError, +) from modelgauge.secret_values import InjectSecret, RawSecrets from modelgauge.sut_definition import SUTDefinition from modelgauge.suts.huggingface_chat_completion import ( @@ -13,15 +18,15 @@ HuggingFaceChatCompletionServerlessSUT, ) -DRIVER_NAME = "hfrelay" - logger = get_logger(__name__) # Set HF logging to ERROR because its 
default logger level is DEBUG. # There are also many warnings which are not really actionable and very repetitive. logging.getLogger("huggingface_hub").setLevel(logging.ERROR) -class HuggingFaceSUTFactory(DynamicSUTFactory): +class HuggingFaceSUTFactory(DynamicDriverSUTFactory): + DRIVER_NAME = "hf" + def __init__(self, raw_secrets: RawSecrets): super().__init__(raw_secrets) self.serverless_factory = HuggingFaceChatCompletionServerlessSUTFactory(raw_secrets) @@ -45,7 +50,6 @@ def make_sut(self, sut_definition: SUTDefinition) -> BaseHuggingFaceChatCompleti class HuggingFaceChatCompletionServerlessSUTFactory(DynamicSUTFactory): - def get_secrets(self) -> list[InjectSecret]: hf_token = InjectSecret(HuggingFaceInferenceToken) return [hf_token] @@ -89,7 +93,6 @@ def make_sut(self, sut_definition: SUTDefinition) -> HuggingFaceChatCompletionSe class HuggingFaceChatCompletionDedicatedSUTFactory(DynamicSUTFactory): - def get_secrets(self) -> list[InjectSecret]: hf_token = InjectSecret(HuggingFaceInferenceToken) return [hf_token] diff --git a/src/modelgauge/suts/indirect_sut.py b/src/modelgauge/suts/indirect_sut.py index e27408ff..2ced6c5f 100644 --- a/src/modelgauge/suts/indirect_sut.py +++ b/src/modelgauge/suts/indirect_sut.py @@ -5,7 +5,7 @@ import uvicorn from pydantic import BaseModel -from modelgauge.dynamic_sut_factory import DynamicSUTFactory +from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory from modelgauge.prompt import TextPrompt from modelgauge.ready import ReadyResponse from modelgauge.secret_values import InjectSecret @@ -137,7 +137,8 @@ def start(): thread.start() -class IndirectSUTFactory(DynamicSUTFactory): +class IndirectSUTFactory(DynamicDriverSUTFactory): + DRIVER_NAME = "indirect" def get_secrets(self) -> list[InjectSecret]: return [] diff --git a/src/modelgauge/suts/meta_llama_factory.py b/src/modelgauge/suts/meta_llama_factory.py index be6e7dbe..623ab16d 100644 --- a/src/modelgauge/suts/meta_llama_factory.py +++ 
b/src/modelgauge/suts/meta_llama_factory.py @@ -1,13 +1,15 @@ from llama_api_client import LlamaAPIClient -from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError +from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory, ModelNotSupportedError from modelgauge.secret_values import InjectSecret, RawSecrets from modelgauge.sut import SUT from modelgauge.sut_definition import SUTDefinition from modelgauge.suts.meta_llama_client import MetaLlamaApiKey, MetaLlamaModeratedSUT, MetaLlamaSUT -class LlamaSUTFactory(DynamicSUTFactory): +class LlamaSUTFactory(DynamicDriverSUTFactory): + DRIVER_NAME = "llama" + def __init__(self, raw_secrets: RawSecrets): super().__init__(raw_secrets) self._client = None diff --git a/src/modelgauge/suts/mistral_sut_factory.py b/src/modelgauge/suts/mistral_sut_factory.py index e79cd8c2..bbc7c21c 100644 --- a/src/modelgauge/suts/mistral_sut_factory.py +++ b/src/modelgauge/suts/mistral_sut_factory.py @@ -1,4 +1,4 @@ -from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError +from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory, ModelNotSupportedError from modelgauge.secret_values import InjectSecret, RawSecrets from modelgauge.sut import SUT from modelgauge.sut_definition import SUTDefinition @@ -6,7 +6,9 @@ from modelgauge.suts.mistral_sut import MistralAISut -class MistralSUTFactory(DynamicSUTFactory): +class MistralSUTFactory(DynamicDriverSUTFactory): + DRIVER_NAME = "mistral" + def __init__(self, raw_secrets: RawSecrets): super().__init__(raw_secrets) self._client = None # Lazy load. 
diff --git a/src/modelgauge/suts/modelship_sut.py b/src/modelgauge/suts/modelship_sut.py index 3d10551e..4b5ae471 100644 --- a/src/modelgauge/suts/modelship_sut.py +++ b/src/modelgauge/suts/modelship_sut.py @@ -1,7 +1,7 @@ from typing import Optional, Mapping, Any from modelgauge.auth.openai_compatible_secrets import OpenAICompatibleApiKey -from modelgauge.dynamic_sut_factory import DynamicSUTFactory +from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription from modelgauge.sut_definition import SUTDefinition from modelgauge.suts.openai_client import OpenAIChat, OpenAIChatRequest @@ -34,7 +34,9 @@ def request_as_dict_for_client(self, request: OpenAIChatRequest) -> dict[str, An return request_as_dict -class ModelShipSUTFactory(DynamicSUTFactory): +class ModelShipSUTFactory(DynamicDriverSUTFactory): + DRIVER_NAME = "modelship" + def get_secrets(self) -> list[InjectSecret]: api_key = InjectSecret(ModelShipSecret) return [api_key] diff --git a/src/modelgauge/suts/openai_sut_factory.py b/src/modelgauge/suts/openai_sut_factory.py index cf24045b..e9979277 100644 --- a/src/modelgauge/suts/openai_sut_factory.py +++ b/src/modelgauge/suts/openai_sut_factory.py @@ -1,17 +1,20 @@ from openai import OpenAI, NotFoundError from modelgauge.auth.openai_compatible_secrets import OpenAICompatibleApiKey -from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError, ProviderNotFoundError +from modelgauge.dynamic_sut_factory import ( + DynamicSUTFactory, + DynamicDriverSUTFactory, + ModelNotSupportedError, + ProviderNotFoundError, +) from modelgauge.secret_values import InjectSecret, RawSecrets from modelgauge.sut_definition import SUTDefinition from modelgauge.suts.openai_client import OpenAIChat -DRIVER_NAME = "openai" NUM_RETRIES = 7 -class OpenAICompatibleSUTFactory(DynamicSUTFactory): - +class BaseOpenAISUTFactory(DynamicSUTFactory): def __init__(self, raw_secrets: 
RawSecrets): super().__init__(raw_secrets) self.provider = None # must be set in child classes and match name of section (scope) in secrets.toml @@ -32,6 +35,10 @@ def _make_client(self) -> OpenAI: _client = OpenAI(api_key=api_key.value, max_retries=NUM_RETRIES) return _client + +class OpenAICompatibleSUTFactory(BaseOpenAISUTFactory, DynamicDriverSUTFactory): + DRIVER_NAME = "openai" + def make_sut(self, sut_definition: SUTDefinition) -> OpenAIChat: factory = factory_class = None self.provider = sut_definition.get("provider") # type: ignore @@ -54,7 +61,7 @@ def make_sut(self, sut_definition: SUTDefinition) -> OpenAIChat: return factory.make_sut(sut_definition) -class OpenAISUTFactory(OpenAICompatibleSUTFactory): +class OpenAISUTFactory(BaseOpenAISUTFactory): """OpenAI SUT hosted by OpenAI""" def __init__(self, raw_secrets: RawSecrets): @@ -76,7 +83,7 @@ def make_sut(self, sut_definition: SUTDefinition) -> OpenAIChat: return OpenAIChat(sut_definition.uid, sut_definition.get("model"), client=self.client) # type: ignore -class OpenAIGenericSUTFactory(OpenAICompatibleSUTFactory): +class OpenAIGenericSUTFactory(BaseOpenAISUTFactory): """A SUT that uses the OpenAI client, not hosted by OpenAI""" def __init__(self, raw_secrets: RawSecrets, base_url: str | None = None): diff --git a/src/modelgauge/suts/together_cli.py b/src/modelgauge/suts/together_cli.py deleted file mode 100644 index 73335cfd..00000000 --- a/src/modelgauge/suts/together_cli.py +++ /dev/null @@ -1,30 +0,0 @@ -import together # type: ignore -from collections import defaultdict -from modelgauge.command_line import display_header, display_list_item, cli -from modelgauge.config import load_secrets_from_config -from modelgauge.suts.together_client import TogetherApiKey - - -@cli.command() -def list_together(): - """List all models available in together.ai.""" - - secrets = load_secrets_from_config() - together.api_key = TogetherApiKey.make(secrets).value - model_list = together.Models.list() - - # Group by 
display_type, which seems to be the model's style. - by_display_type = defaultdict(list) - for model in model_list: - try: - display_type = model["display_type"] - except KeyError: - display_type = "unknown" - display_name = model["display_name"] - by_display_type[display_type].append(f"{display_name}: {model['name']}") - - for display_name, models in by_display_type.items(): - display_header(f"{display_name}: {len(models)}") - for model in sorted(models): - display_list_item(model) - display_header(f"Total: {len(model_list)}") diff --git a/src/modelgauge/suts/together_sut_factory.py b/src/modelgauge/suts/together_sut_factory.py index 0e229287..951b5540 100644 --- a/src/modelgauge/suts/together_sut_factory.py +++ b/src/modelgauge/suts/together_sut_factory.py @@ -1,16 +1,16 @@ from together import Together # type: ignore from modelgauge.auth.together_key import TogetherApiKey -from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError +from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory, ModelNotSupportedError from modelgauge.dynamic_sut_metadata import DynamicSUTMetadata from modelgauge.secret_values import InjectSecret, RawSecrets from modelgauge.sut_definition import SUTDefinition from modelgauge.suts.together_client import TogetherChatSUT -DRIVER_NAME = "together" +class TogetherSUTFactory(DynamicDriverSUTFactory): + DRIVER_NAME = "together" -class TogetherSUTFactory(DynamicSUTFactory): def __init__(self, raw_secrets: RawSecrets): super().__init__(raw_secrets) self._client = None # Lazy load. @@ -50,7 +50,7 @@ def make_sut(self, sut_definition: SUTDefinition) -> TogetherChatSUT: f"Model {sut_metadata.external_model_name()} not found or not available on together." 
) - assert sut_metadata.driver == DRIVER_NAME + assert sut_metadata.driver == self.DRIVER_NAME return TogetherChatSUT( sut_definition.dynamic_uid, sut_metadata.external_model_name(), diff --git a/tests/modelbench_tests/test_run.py b/tests/modelbench_tests/test_run.py index b842d55f..09ffaa68 100644 --- a/tests/modelbench_tests/test_run.py +++ b/tests/modelbench_tests/test_run.py @@ -316,7 +316,7 @@ def invoke(command, args=None, **kwargs): ], # TODO add more locales as we add support for them ) - @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"]) + @pytest.mark.parametrize("sut_uid", ["fake-sut"]) def test_benchmark_basic_run_produces_json( self, monkeypatch, @@ -396,7 +396,7 @@ def test_benchmark_basic_run_produces_json( ], # TODO add more locales as we add support for them ) - @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay;mt=500;t=0.3"]) + @pytest.mark.parametrize("sut_uid", ["fake-sut"]) def test_benchmark_multiple_suts_produces_json( self, mock_run_benchmarks, runner, version, locale, prompt_set, sut_uid, run_dir, monkeypatch ): @@ -488,44 +488,6 @@ def test_benchmark_bad_sut_errors_out(self, runner): catch_exceptions=False, ) - with patch( - "modelgauge.suts.huggingface_sut_factory.HuggingFaceChatCompletionServerlessSUTFactory._find", - side_effect=ProviderNotFoundError("bad provider"), - ): - with pytest.raises(ModelNotSupportedError): - _ = runner( - cli, - [ - "benchmark", - "general", - "-m", - "1", - "--sut", - "meta/llama:notreal:hfrelay", - *benchmark_options, - ], - catch_exceptions=False, - ) - - with patch( - "modelgauge.suts.huggingface_sut_factory.hfh.model_info", - side_effect=ModelNotSupportedError("bad model"), - ): - with pytest.raises(ModelNotSupportedError): - _ = runner( - cli, - [ - "benchmark", - "general", - "-m", - "1", - "--sut", - "google/bogus:cohere:hfrelay", - *benchmark_options, - ], - catch_exceptions=False, - ) - 
@pytest.mark.parametrize("version", ["0.0", "0.5"]) def test_invalid_benchmark_versions_can_not_be_called(self, version, runner): result = runner(cli, ["benchmark", "general", "--version", "0.0"]) @@ -546,7 +508,7 @@ def test_calls_score_benchmark_with_correct_v1_locale(self, runner, mock_run_ben # # benchmark_arg = mock_score_benchmarks.call_args.args[0][0] # assert isinstance(benchmark_arg, GeneralPurposeAiChatBenchmark) - @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"]) + @pytest.mark.parametrize("sut_uid", ["fake-sut"]) def test_v1_en_us_demo_is_default(self, runner, mock_run_benchmarks, sut_uid): _ = runner(cli, ["benchmark", "general", "--sut", sut_uid]) @@ -555,14 +517,14 @@ def test_v1_en_us_demo_is_default(self, runner, mock_run_benchmarks, sut_uid): assert benchmark_arg.locale == EN_US assert benchmark_arg.prompt_set == "demo" - @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"]) + @pytest.mark.parametrize("sut_uid", ["fake-sut"]) def test_nonexistent_benchmark_prompt_sets_can_not_be_called(self, runner, sut_uid): result = runner(cli, ["benchmark", "general", "--prompt-set", "fake", "--sut", sut_uid]) assert result.exit_code == 2 assert "Invalid value for '--prompt-set'" in result.output @pytest.mark.parametrize("prompt_set", GENERAL_PROMPT_SETS.keys()) - @pytest.mark.parametrize("sut_uid", ["fake-sut", "google/gemma-3-27b-it:scaleway:hfrelay"]) + @pytest.mark.parametrize("sut_uid", ["fake-sut"]) def test_calls_score_benchmark_with_correct_prompt_set(self, runner, mock_run_benchmarks, prompt_set, sut_uid): _ = runner(cli, ["benchmark", "general", "--prompt-set", prompt_set, "--sut", sut_uid]) diff --git a/tests/modelgauge_tests/sut_tests/test_openai_sut_factory.py b/tests/modelgauge_tests/sut_tests/test_openai_sut_factory.py index 5c144958..487c4b5a 100644 --- a/tests/modelgauge_tests/sut_tests/test_openai_sut_factory.py +++ 
b/tests/modelgauge_tests/sut_tests/test_openai_sut_factory.py @@ -120,7 +120,7 @@ def test_factory_tries_to_make_a_generic_sut(factory, sut_definition): def test_factory_makes_the_right_openai_sut(factory): - with patch("modelgauge.suts.openai_sut_factory.OpenAICompatibleSUTFactory._make_client"): + with patch("modelgauge.suts.openai_sut_factory.BaseOpenAISUTFactory._make_client"): sut_definition = SUTDefinition(model="gpt-5", maker="openai", driver="openai") sut = factory.make_sut(sut_definition) assert sut.uid == "openai/gpt-5:openai" diff --git a/tests/modelgauge_tests/test_dynamic_sut_factory.py b/tests/modelgauge_tests/test_dynamic_sut_factory.py index 3a1bf1bc..06f553b4 100644 --- a/tests/modelgauge_tests/test_dynamic_sut_factory.py +++ b/tests/modelgauge_tests/test_dynamic_sut_factory.py @@ -1,6 +1,6 @@ import pytest -from modelgauge.dynamic_sut_factory import DynamicSUTFactory +from modelgauge.dynamic_sut_factory import DynamicSUTFactory, DynamicDriverSUTFactory from modelgauge.sut_definition import SUTDefinition from modelgauge.secret_values import InjectSecret from modelgauge_tests.fake_sut import FakeSUT @@ -46,3 +46,34 @@ def test_injected_secrets_missing_required(): factory = FakeDynamicFactory({"optional-scope": {"optional-key": "optional-value"}}) with pytest.raises(MissingSecretValues): factory.injected_secrets() + + +def test_dynamic_sut_factory_driver_instantiation(): + class MyDriverFactory(FakeDynamicFactory, DynamicDriverSUTFactory): + pass + + with pytest.raises(AssertionError): + MyDriverFactory({}) + + class MyDriverFactory(FakeDynamicFactory, DynamicDriverSUTFactory): + DRIVER_NAME: str + + with pytest.raises(AssertionError): + MyDriverFactory({}) + + class MyDriverFactory(FakeDynamicFactory, DynamicDriverSUTFactory): + DRIVER_NAME = None + + with pytest.raises(AssertionError): + MyDriverFactory({}) + + class MyDriverFactory(FakeDynamicFactory, DynamicDriverSUTFactory): + DRIVER_NAME = "" + + with pytest.raises(AssertionError): + 
MyDriverFactory({}) + + class MyDriverFactory(FakeDynamicFactory, DynamicDriverSUTFactory): + DRIVER_NAME = "driver" + + factory = MyDriverFactory({}) diff --git a/tests/modelgauge_tests/test_sut_factory.py b/tests/modelgauge_tests/test_sut_factory.py index e0a6f541..88bf4356 100644 --- a/tests/modelgauge_tests/test_sut_factory.py +++ b/tests/modelgauge_tests/test_sut_factory.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from modelgauge.dynamic_sut_factory import UnknownSUTMakerError +from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory, UnknownSUTMakerError from modelgauge.instance_factory import InstanceFactory from modelgauge.sut import SUT from modelgauge.sut_factory import IncompatibleSUTParamsError, SUTFactory, SUTNotFoundException, SUTType @@ -51,6 +51,22 @@ def test_knows(sut_factory): assert sut_factory.knows(UNKNOWN_UID) is False +def test_load_dynamic_sut_factories(): + class MyDriverFactory(FakeDynamicFactory, DynamicDriverSUTFactory): + DRIVER_NAME = "driver-1" + + class OtherDriverFactory(FakeDynamicFactory, DynamicDriverSUTFactory): + DRIVER_NAME = "driver-2" + + with patch("modelgauge.sut_factory.get_concrete_subclasses", return_value=[MyDriverFactory, OtherDriverFactory]): + sut_factory = SUTFactory({}) + + assert sut_factory.dynamic_sut_factories is not None + assert len(sut_factory.dynamic_sut_factories) == 2 + assert isinstance(sut_factory.dynamic_sut_factories["driver-1"], MyDriverFactory) + assert isinstance(sut_factory.dynamic_sut_factories["driver-2"], OtherDriverFactory) + + def test_get_missing_dependencies_dynamic(sut_factory): assert sut_factory.get_missing_dependencies(DYNAMIC_UID, secrets={}) == []