Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/modelgauge/sut_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from modelgauge.suts.google_sut_factory import GoogleSUTFactory
from modelgauge.suts.huggingface_sut_factory import HuggingFaceSUTFactory
from modelgauge.suts.indirect_sut import IndirectSUTFactory
from modelgauge.suts.mistral_sut_factory import MistralSUTFactory
from modelgauge.suts.modelship_sut import ModelShipSUTFactory
from modelgauge.suts.openai_sut_factory import OpenAICompatibleSUTFactory
from modelgauge.suts.together_sut_factory import TogetherSUTFactory
Expand All @@ -36,6 +37,7 @@ class SUTType(Enum):
"hfrelay": HuggingFaceSUTFactory,
"indirect": IndirectSUTFactory,
"openai": OpenAICompatibleSUTFactory,
"mistral": MistralSUTFactory,
"modelship": ModelShipSUTFactory,
"together": TogetherSUTFactory,
}
Expand Down
10 changes: 4 additions & 6 deletions src/modelgauge/suts/mistral_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@ def description(cls) -> SecretDescription:


class MistralAIClient:
def __init__(
self,
model_name: str,
api_key: MistralAIAPIKey,
):
self.model_name = model_name
def __init__(self, api_key: MistralAIAPIKey):
self.api_key = api_key.value
self._client = None

Expand Down Expand Up @@ -64,6 +59,9 @@ def _make_request(endpoint, kwargs: dict):
except Exception as exc:
raise (exc)

def model_info(self, model):
return self._make_request(self.client.models.retrieve, {"model_id": model})

def request(self, req: dict):
if self.client.chat.sdk_configuration._hooks.before_request_hooks:
# work around bug in client
Expand Down
11 changes: 2 additions & 9 deletions src/modelgauge/suts/mistral_sut.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from mistralai.models import ChatCompletionResponse, ClassificationResponse, SDKError
from mistralai.models import ChatCompletionResponse, SDKError
from pydantic import BaseModel

from modelgauge.prompt import TextPrompt
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(
@property
def client(self):
if not self._client:
self._client = MistralAIClient(self.model_name, self._api_key)
self._client = MistralAIClient(self._api_key)
return self._client

def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> MistralAIRequest:
Expand All @@ -73,13 +73,6 @@ def translate_response(self, request: MistralAIRequest, response: MistralAIRespo
return SUTResponse(text=str(text))


class MistralAIResponseWithModerations(BaseModel):
"""Mistral's ChatCompletionResponse object + moderation scores."""

response: ChatCompletionResponse # Contains multiple completions.
moderations: dict[int, ClassificationResponse] # Keys correspond to a choice's index field.


def register_suts_for_model(model_name):
MISTRAL_API_KEY = InjectSecret(MistralAIAPIKey)
# Register standard SUT.
Expand Down
33 changes: 33 additions & 0 deletions src/modelgauge/suts/mistral_sut_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError
from modelgauge.secret_values import InjectSecret, RawSecrets
from modelgauge.sut import SUT
from modelgauge.sut_definition import SUTDefinition
from modelgauge.suts.mistral_client import MistralAIAPIKey, MistralAIClient
from modelgauge.suts.mistral_sut import MistralAISut


class MistralSUTFactory(DynamicSUTFactory):
def __init__(self, raw_secrets: RawSecrets):
super().__init__(raw_secrets)
self._client = None # Lazy load.

@property
def client(self) -> MistralAIClient:
if self._client is None:
api_key = self.injected_secrets()[0]
self._client = MistralAIClient(api_key)
return self._client

def get_secrets(self) -> list[InjectSecret]:
api_key = InjectSecret(MistralAIAPIKey)
return [api_key]

def make_sut(self, sut_definition: SUTDefinition) -> SUT:
model_name = sut_definition.to_dynamic_sut_metadata().external_model_name()

try:
self.client.model_info(model_name)
except Exception as e:
raise ModelNotSupportedError(f"Model {model_name} not found or not available on mistral: {e}")

return MistralAISut(sut_definition.dynamic_uid, model_name, *self.injected_secrets())
30 changes: 30 additions & 0 deletions tests/modelgauge_tests/sut_tests/test_mistral_sut_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest
from unittest.mock import patch

from modelgauge.dynamic_sut_factory import ModelNotSupportedError
from modelgauge.sut_definition import SUTDefinition
from modelgauge.suts.mistral_sut import MistralAISut
from modelgauge.suts.mistral_sut_factory import MistralSUTFactory


@pytest.fixture
def factory():
return MistralSUTFactory({"mistralai": {"api_key": "value"}})


def test_make_sut(factory):
with patch("modelgauge.suts.mistral_client.MistralAIClient.model_info", return_value="model exists"):
sut_definition = SUTDefinition(model="bar", maker="foo", driver="mistral")
sut = factory.make_sut(sut_definition)

assert isinstance(sut, MistralAISut)
assert sut.uid == "foo/bar:mistral"
assert sut.model_name == "foo/bar"
assert sut._api_key.value == "value"


def test_make_sut_bad_model(factory):
sut_definition = SUTDefinition(model="bogus", maker="fake", driver="mistral")
with patch("modelgauge.suts.mistral_client.MistralAIClient.model_info", side_effect=Exception()):
with pytest.raises(ModelNotSupportedError):
factory.make_sut(sut_definition)