35 changes: 7 additions & 28 deletions docs/add-a-new-sut-driver.md
Expand Up @@ -7,19 +7,8 @@ Most providers need their own driver. We provide several drivers that can be use

### Does an Existing Driver Exist?

If your SUT provider is listed as a key in the `DYNAMIC_SUT_FACTORIES` in
[sut_factory](../src/modelgauge/sut_factory.py), you don't need to write any code.

```python
DYNAMIC_SUT_FACTORIES: dict = {
"hf": HuggingFaceSUTFactory,
"hfrelay": HuggingFaceSUTFactory,
"openai": OpenAICompatibleSUTFactory,
"together": TogetherSUTFactory,
"modelship": ModelShipSUTFactory,
}
```
Please refer to [suts-how-to.md](./suts-how-to.md#existing) for details.
Search the existing `DynamicDriverSUTFactory` subclasses;
if one already exists for your provider, you don't need to write any code.

### Is Your SUT Already Pre-Defined?

Expand Down Expand Up @@ -79,13 +68,11 @@ class MySUT(PromptResponseSUT):
return MySUTResponse(**response_json)
```

2. Create a factory class that creates an instance of your SUT from its UID. Look at [TogetherSUTFactory](../src/modelgauge/suts/together_sut_factory.py) for inspiration.
2. Create a factory class that subclasses `DynamicDriverSUTFactory` and creates an instance of your SUT from its UID. Look at [TogetherSUTFactory](../src/modelgauge/suts/together_sut_factory.py) for inspiration.

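The `DRIVER_NAME` class attribute must be unique to your driver. It is the key under which your factory is registered, and it corresponds to the `driver` string in the SUT UID. Two factories with the same `DRIVER_NAME` raise a `ValueError` when the factories are loaded.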

```python
DRIVER_NAME = "my_sut"

class MySUTApiKey(RequiredSecret):
# adjust this to your specific provider
@classmethod
Expand All @@ -95,7 +82,9 @@ class MySUTApiKey(RequiredSecret):
key="api_key"
)

class MySUTFactory(DynamicSUTFactory):
class MySUTFactory(DynamicDriverSUTFactory):
    DRIVER_NAME = "my_sut"

    def __init__(self, raw_secrets: RawSecrets):
        # RawSecrets is a dict of secrets
        super().__init__(raw_secrets)
Expand All @@ -112,17 +101,7 @@ class MySUTFactory(DynamicSUTFactory):
)
```

3. Add an entry for your new factory class in the `DYNAMIC_SUT_FACTORIES` dict in [sut_factory](../src/modelgauge/sut_factory.py).

```python
DYNAMIC_SUT_FACTORIES: dict = {
...
"my_sut": MySUTFactory,
...
}
```

4. Add a scope to [config/secrets.toml](../config/secrets.toml) for your provider, using the `scope` you defined in the `Secret` class(es) for your SUT:
3. Add a scope to [config/secrets.toml](../config/secrets.toml) for your provider, using the `scope` you defined in the `Secret` class(es) for your SUT:

```toml
[my_host]
21 changes: 0 additions & 21 deletions docs/suts-how-to.md
Expand Up @@ -32,27 +32,6 @@ A lot of new SUTs will require no code if your model is hosted on one of the pro

Factory classes are used to create SUT objects for you, including their driver and model name, based on the elements in the SUT UID.

Available drivers are identified in `DYNAMIC_SUT_FACTORIES` in
[sut_factory](../src/modelgauge/sut_factory.py). The keys correspond to the `driver` string in the SUT UID.

We may add more drivers from time to time.

```python
DYNAMIC_SUT_FACTORIES: dict = {
"hf": HuggingFaceSUTFactory,
"hfrelay": HuggingFaceSUTFactory,
"openai": OpenAICompatibleSUTFactory,
"together": TogetherSUTFactory,
"modelship": ModelShipSUTFactory,
}
```

* "hf" is used for models hosted by Huggingface
* "hfrelay" is used for models hosted by one of Huggingface's inference provider partners (e.g. nebius, sambanova) and proxied by Huggingface ([more info](https://huggingface.co/docs/inference-providers/en/index))
* "openai" is a model hosted by OpenAI
* "together" is a model hosted by together.ai
* "modelship" is internal to MLCommons

#### Usage

For models using an existing driver, all you need to do is add your credentials to [config/secrets.toml](../config/secrets.toml) in a section named after the driver name string, e.g. for together.ai:
12 changes: 12 additions & 0 deletions src/modelgauge/dynamic_sut_factory.py
Expand Up @@ -43,3 +43,15 @@ def get_secrets(self) -> list[InjectSecret]:
def make_sut(self, sut_definition: SUTDefinition) -> SUT:
"""Factories that handle special SUT config parameters (e.g. moderated, reasoning) must accept them as kwargs."""
pass


class DynamicDriverSUTFactory(DynamicSUTFactory, ABC):
    """These classes will be collected as driver factories for dynamic SUTs. They may call regular DynamicSUTFactories."""

    DRIVER_NAME: str

    def __init__(self, raw_secrets: RawSecrets):
        super().__init__(raw_secrets)
        assert (
            hasattr(self, "DRIVER_NAME") and isinstance(self.DRIVER_NAME, str) and len(self.DRIVER_NAME) > 0
        ), "DynamicDriverSUTFactory subclasses must have a str DRIVER_NAME attribute"
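
Review note: the `DRIVER_NAME` check above runs when a factory is instantiated, not when its class is defined. The following is a minimal, self-contained sketch of that behavior. The base class is a simplified stand-in for `DynamicSUTFactory`, and `GoodFactory`, `NamelessFactory`, and `try_build` are hypothetical names used only for illustration.

```python
from abc import ABC, abstractmethod


class DynamicSUTFactory(ABC):
    """Simplified stand-in for modelgauge's DynamicSUTFactory (assumption)."""

    def __init__(self, raw_secrets: dict):
        self.raw_secrets = raw_secrets

    @abstractmethod
    def make_sut(self, sut_definition):
        ...


class DynamicDriverSUTFactory(DynamicSUTFactory, ABC):
    # Annotation only: no attribute is created until a subclass assigns one.
    DRIVER_NAME: str

    def __init__(self, raw_secrets: dict):
        super().__init__(raw_secrets)
        assert (
            hasattr(self, "DRIVER_NAME") and isinstance(self.DRIVER_NAME, str) and len(self.DRIVER_NAME) > 0
        ), "DynamicDriverSUTFactory subclasses must have a str DRIVER_NAME attribute"


class GoodFactory(DynamicDriverSUTFactory):
    """Hypothetical driver that declares its name."""

    DRIVER_NAME = "good"

    def make_sut(self, sut_definition):
        return f"sut for {sut_definition}"


class NamelessFactory(DynamicDriverSUTFactory):
    """Hypothetical driver that forgot DRIVER_NAME."""

    def make_sut(self, sut_definition):
        return None


def try_build(factory_class) -> str:
    """Instantiate a factory and report whether the DRIVER_NAME check passed."""
    try:
        factory_class({})
        return "ok"
    except AssertionError as error:
        return f"rejected: {error}"


print(try_build(GoodFactory))
print(try_build(NamelessFactory))
```

Because the check is an `assert`, it is skipped when Python runs with `-O`. If the check needs to hold in optimized runs, raising an explicit exception would be one option.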
46 changes: 12 additions & 34 deletions src/modelgauge/sut_factory.py
@@ -1,21 +1,13 @@
from enum import Enum

from modelgauge.config import load_secrets_from_config
from modelgauge.dynamic_sut_factory import DynamicSUTFactory, UnknownSUTMakerError
from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory, UnknownSUTMakerError
from modelgauge.general import get_concrete_subclasses
from modelgauge.load_namespaces import load_namespace
from modelgauge.secret_values import RawSecrets
from modelgauge.sut import SUT
from modelgauge.sut_definition import SUTDefinition
from modelgauge.sut_registry import SUTS
from modelgauge.suts.anthropic_sut_factory import AnthropicSUTFactory
from modelgauge.suts.aws_bedrock_sut_factory import AWSBedrockSUTFactory
from modelgauge.suts.google_sut_factory import GoogleSUTFactory
from modelgauge.suts.huggingface_sut_factory import HuggingFaceSUTFactory
from modelgauge.suts.indirect_sut import IndirectSUTFactory
from modelgauge.suts.meta_llama_factory import LlamaSUTFactory
from modelgauge.suts.mistral_sut_factory import MistralSUTFactory
from modelgauge.suts.modelship_sut import ModelShipSUTFactory
from modelgauge.suts.openai_sut_factory import OpenAICompatibleSUTFactory
from modelgauge.suts.together_sut_factory import TogetherSUTFactory


class SUTNotFoundException(Exception):
Expand All @@ -32,24 +24,6 @@ class SUTType(Enum):
UNKNOWN = "unknown"


# TODO: Auto-collect?
# Make sure the factory module includes the matching key as a constant.
# Maps a string to the module and factory function in that module
# that can be used to create a dynamic sut
DYNAMIC_SUT_FACTORIES: dict = {
"anthropic": AnthropicSUTFactory,
"aws": AWSBedrockSUTFactory,
"google": GoogleSUTFactory,
"hf": HuggingFaceSUTFactory,
"hfrelay": HuggingFaceSUTFactory,
"indirect": IndirectSUTFactory,
"llama": LlamaSUTFactory,
"openai": OpenAICompatibleSUTFactory,
"mistral": MistralSUTFactory,
"modelship": ModelShipSUTFactory,
"together": TogetherSUTFactory,
}

LEGACY_SUT_MODULE_MAP = {
# HuggingFaceChatCompletionDedicatedSUT and HuggingFaceChatCompletionServerlessSUT
"nvidia-llama-3-1-nemotron-nano-8b-v1": "huggingface_chat_completion",
Expand Down Expand Up @@ -157,11 +131,15 @@ def __init__(self, sut_registry):
self.sut_registry = sut_registry
self.dynamic_sut_factories = self._load_dynamic_sut_factories(load_secrets_from_config())

def _load_dynamic_sut_factories(self, secrets: RawSecrets) -> dict[str, DynamicSUTFactory]:
factories: dict[str, DynamicSUTFactory] = {}
for driver, factory_class in DYNAMIC_SUT_FACTORIES.items():
factories[driver] = factory_class(secrets)
return factories
    def _load_dynamic_sut_factories(self, secrets: RawSecrets) -> dict[str, DynamicDriverSUTFactory]:
        load_namespace("suts")
        dynamic_sut_factories = {}
        for cls in get_concrete_subclasses(DynamicDriverSUTFactory):  # type: ignore
            if cls.DRIVER_NAME in dynamic_sut_factories:
                raise ValueError(f"Multiple DynamicDriverSUTFactory subclasses have the same DRIVER_NAME '{cls.DRIVER_NAME}'.")
            dynamic_sut_factories[cls.DRIVER_NAME] = cls(secrets)

        return dynamic_sut_factories

def knows(self, uid: str) -> bool:
"""Check if the registry knows about a given SUT UID. Dynamic SUTs are always considered known."""
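
Review note: the auto-collection in `_load_dynamic_sut_factories` can be sketched in isolation as below. This is an illustration under assumptions, not the modelgauge implementation: `get_concrete_subclasses` here is a hypothetical stand-in built on `__subclasses__()` and `inspect.isabstract` (the real helper in `modelgauge.general` is not shown in this diff), and `DriverFactory`, `AlphaFactory`, `BetaFactory`, and `load_factories` are made-up names.

```python
import inspect
from abc import ABC, abstractmethod


def get_concrete_subclasses(cls) -> set:
    """Hypothetical stand-in: recursively collect non-abstract subclasses."""
    result = set()
    for sub in cls.__subclasses__():
        if not inspect.isabstract(sub):
            result.add(sub)
        result.update(get_concrete_subclasses(sub))
    return result


class DriverFactory(ABC):
    """Simplified stand-in for DynamicDriverSUTFactory."""

    DRIVER_NAME: str

    def __init__(self, secrets: dict):
        self.secrets = secrets

    @abstractmethod
    def make_sut(self, uid: str) -> str:
        ...


class AlphaFactory(DriverFactory):
    DRIVER_NAME = "alpha"

    def make_sut(self, uid: str) -> str:
        return f"alpha:{uid}"


class BetaFactory(DriverFactory):
    DRIVER_NAME = "beta"

    def make_sut(self, uid: str) -> str:
        return f"beta:{uid}"


def load_factories(base, secrets: dict) -> dict:
    """Build a DRIVER_NAME -> factory instance map, rejecting duplicate names."""
    factories = {}
    for cls in get_concrete_subclasses(base):
        if cls.DRIVER_NAME in factories:
            raise ValueError(f"Multiple factories have the same DRIVER_NAME '{cls.DRIVER_NAME}'.")
        factories[cls.DRIVER_NAME] = cls(secrets)
    return factories


factories = load_factories(DriverFactory, {})
print(sorted(factories))
```

Two consequences of this approach are worth keeping in mind. A class only appears in `__subclasses__()` once its module has been imported, which is presumably why `load_namespace("suts")` runs first. And a concrete subclass of an existing driver inherits its parent's `DRIVER_NAME`, so it is reported as a duplicate unless it sets its own name.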
6 changes: 4 additions & 2 deletions src/modelgauge/suts/anthropic_sut_factory.py
Expand Up @@ -4,14 +4,16 @@

from anthropic import Anthropic

from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError
from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory, ModelNotSupportedError
from modelgauge.secret_values import RawSecrets, InjectSecret
from modelgauge.sut import SUT
from modelgauge.sut_definition import SUTDefinition
from modelgauge.suts.anthropic_api import AnthropicApiKey, AnthropicSUT


class AnthropicSUTFactory(DynamicSUTFactory):
class AnthropicSUTFactory(DynamicDriverSUTFactory):
DRIVER_NAME = "anthropic"

def get_secrets(self) -> list[InjectSecret]:
api_key = InjectSecret(AnthropicApiKey)
return [api_key]
4 changes: 2 additions & 2 deletions src/modelgauge/suts/aws_bedrock_sut_factory.py
Expand Up @@ -2,14 +2,14 @@

import boto3

from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError
from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory, ModelNotSupportedError
from modelgauge.secret_values import InjectSecret, RawSecrets
from modelgauge.sut import SUT
from modelgauge.sut_definition import SUTDefinition
from modelgauge.suts.aws_bedrock_client import AmazonBedrockSut, AwsAccessKeyId, AwsSecretAccessKey


class AWSBedrockSUTFactory(DynamicSUTFactory):
class AWSBedrockSUTFactory(DynamicDriverSUTFactory):
DRIVER_NAME = "aws"

def __init__(self, raw_secrets: RawSecrets):
6 changes: 3 additions & 3 deletions src/modelgauge/suts/google_sut_factory.py
Expand Up @@ -2,16 +2,16 @@

from google import genai

from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError
from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory, ModelNotSupportedError
from modelgauge.secret_values import RawSecrets, InjectSecret
from modelgauge.sut import SUT
from modelgauge.sut_definition import SUTDefinition
from modelgauge.suts.google_genai import GoogleGenAiSUT, GoogleAiApiKey

DRIVER_NAME = "google"

class GoogleSUTFactory(DynamicDriverSUTFactory):
DRIVER_NAME = "google"

class GoogleSUTFactory(DynamicSUTFactory):
def get_secrets(self) -> list[InjectSecret]:
api_key = InjectSecret(GoogleAiApiKey)
return [api_key]
15 changes: 9 additions & 6 deletions src/modelgauge/suts/huggingface_sut_factory.py
Expand Up @@ -4,7 +4,12 @@
from airrlogger.log_config import get_logger

from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError, ProviderNotFoundError
from modelgauge.dynamic_sut_factory import (
DynamicSUTFactory,
DynamicDriverSUTFactory,
ModelNotSupportedError,
ProviderNotFoundError,
)
from modelgauge.secret_values import InjectSecret, RawSecrets
from modelgauge.sut_definition import SUTDefinition
from modelgauge.suts.huggingface_chat_completion import (
Expand All @@ -13,15 +18,15 @@
HuggingFaceChatCompletionServerlessSUT,
)

DRIVER_NAME = "hfrelay"

logger = get_logger(__name__)
# Set HF logging to ERROR because its default logger level is DEBUG.
# There are also many warnings which are not really actionable and very repetitive.
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)


class HuggingFaceSUTFactory(DynamicSUTFactory):
class HuggingFaceSUTFactory(DynamicDriverSUTFactory):
DRIVER_NAME = "hf"

def __init__(self, raw_secrets: RawSecrets):
super().__init__(raw_secrets)
self.serverless_factory = HuggingFaceChatCompletionServerlessSUTFactory(raw_secrets)
Expand All @@ -45,7 +50,6 @@ def make_sut(self, sut_definition: SUTDefinition) -> BaseHuggingFaceChatCompleti


class HuggingFaceChatCompletionServerlessSUTFactory(DynamicSUTFactory):

def get_secrets(self) -> list[InjectSecret]:
hf_token = InjectSecret(HuggingFaceInferenceToken)
return [hf_token]
Expand Down Expand Up @@ -89,7 +93,6 @@ def make_sut(self, sut_definition: SUTDefinition) -> HuggingFaceChatCompletionSe


class HuggingFaceChatCompletionDedicatedSUTFactory(DynamicSUTFactory):

def get_secrets(self) -> list[InjectSecret]:
hf_token = InjectSecret(HuggingFaceInferenceToken)
return [hf_token]
5 changes: 3 additions & 2 deletions src/modelgauge/suts/indirect_sut.py
Expand Up @@ -5,7 +5,7 @@
import uvicorn
from pydantic import BaseModel

from modelgauge.dynamic_sut_factory import DynamicSUTFactory
from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory
from modelgauge.prompt import TextPrompt
from modelgauge.ready import ReadyResponse
from modelgauge.secret_values import InjectSecret
Expand Down Expand Up @@ -137,7 +137,8 @@ def start():
thread.start()


class IndirectSUTFactory(DynamicSUTFactory):
class IndirectSUTFactory(DynamicDriverSUTFactory):
DRIVER_NAME = "indirect"

def get_secrets(self) -> list[InjectSecret]:
return []
6 changes: 4 additions & 2 deletions src/modelgauge/suts/meta_llama_factory.py
@@ -1,13 +1,15 @@
from llama_api_client import LlamaAPIClient

from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError
from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory, ModelNotSupportedError
from modelgauge.secret_values import InjectSecret, RawSecrets
from modelgauge.sut import SUT
from modelgauge.sut_definition import SUTDefinition
from modelgauge.suts.meta_llama_client import MetaLlamaApiKey, MetaLlamaModeratedSUT, MetaLlamaSUT


class LlamaSUTFactory(DynamicSUTFactory):
class LlamaSUTFactory(DynamicDriverSUTFactory):
DRIVER_NAME = "llama"

def __init__(self, raw_secrets: RawSecrets):
super().__init__(raw_secrets)
self._client = None
6 changes: 4 additions & 2 deletions src/modelgauge/suts/mistral_sut_factory.py
@@ -1,12 +1,14 @@
from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError
from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory, ModelNotSupportedError
from modelgauge.secret_values import InjectSecret, RawSecrets
from modelgauge.sut import SUT
from modelgauge.sut_definition import SUTDefinition
from modelgauge.suts.mistral_client import MistralAIAPIKey, MistralAIClient
from modelgauge.suts.mistral_sut import MistralAISut


class MistralSUTFactory(DynamicSUTFactory):
class MistralSUTFactory(DynamicDriverSUTFactory):
DRIVER_NAME = "mistral"

def __init__(self, raw_secrets: RawSecrets):
super().__init__(raw_secrets)
self._client = None # Lazy load.
6 changes: 4 additions & 2 deletions src/modelgauge/suts/modelship_sut.py
@@ -1,7 +1,7 @@
from typing import Optional, Mapping, Any

from modelgauge.auth.openai_compatible_secrets import OpenAICompatibleApiKey
from modelgauge.dynamic_sut_factory import DynamicSUTFactory
from modelgauge.dynamic_sut_factory import DynamicDriverSUTFactory
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
from modelgauge.sut_definition import SUTDefinition
from modelgauge.suts.openai_client import OpenAIChat, OpenAIChatRequest
Expand Down Expand Up @@ -34,7 +34,9 @@ def request_as_dict_for_client(self, request: OpenAIChatRequest) -> dict[str, An
return request_as_dict


class ModelShipSUTFactory(DynamicSUTFactory):
class ModelShipSUTFactory(DynamicDriverSUTFactory):
DRIVER_NAME = "modelship"

def get_secrets(self) -> list[InjectSecret]:
api_key = InjectSecret(ModelShipSecret)
return [api_key]