Skip to content

Commit a80d274

Browse files
authored
Dynamic mistral SUTs (#1493)
* mistral dynamic factory * include mistral in sut factory
1 parent d093dcb commit a80d274

File tree

5 files changed

+71
-15
lines changed

5 files changed

+71
-15
lines changed

src/modelgauge/sut_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from modelgauge.suts.google_sut_factory import GoogleSUTFactory
1111
from modelgauge.suts.huggingface_sut_factory import HuggingFaceSUTFactory
1212
from modelgauge.suts.indirect_sut import IndirectSUTFactory
13+
from modelgauge.suts.mistral_sut_factory import MistralSUTFactory
1314
from modelgauge.suts.modelship_sut import ModelShipSUTFactory
1415
from modelgauge.suts.openai_sut_factory import OpenAICompatibleSUTFactory
1516
from modelgauge.suts.together_sut_factory import TogetherSUTFactory
@@ -36,6 +37,7 @@ class SUTType(Enum):
3637
"hfrelay": HuggingFaceSUTFactory,
3738
"indirect": IndirectSUTFactory,
3839
"openai": OpenAICompatibleSUTFactory,
40+
"mistral": MistralSUTFactory,
3941
"modelship": ModelShipSUTFactory,
4042
"together": TogetherSUTFactory,
4143
}

src/modelgauge/suts/mistral_client.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,7 @@ def description(cls) -> SecretDescription:
2121

2222

2323
class MistralAIClient:
24-
def __init__(
25-
self,
26-
model_name: str,
27-
api_key: MistralAIAPIKey,
28-
):
29-
self.model_name = model_name
24+
def __init__(self, api_key: MistralAIAPIKey):
3025
self.api_key = api_key.value
3126
self._client = None
3227

@@ -64,6 +59,9 @@ def _make_request(endpoint, kwargs: dict):
6459
except Exception as exc:
6560
raise (exc)
6661

62+
def model_info(self, model):
63+
return self._make_request(self.client.models.retrieve, {"model_id": model})
64+
6765
def request(self, req: dict):
6866
if self.client.chat.sdk_configuration._hooks.before_request_hooks:
6967
# work around bug in client

src/modelgauge/suts/mistral_sut.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Optional
22

3-
from mistralai.models import ChatCompletionResponse, ClassificationResponse, SDKError
3+
from mistralai.models import ChatCompletionResponse, SDKError
44
from pydantic import BaseModel
55

66
from modelgauge.prompt import TextPrompt
@@ -50,7 +50,7 @@ def __init__(
5050
@property
5151
def client(self):
5252
if not self._client:
53-
self._client = MistralAIClient(self.model_name, self._api_key)
53+
self._client = MistralAIClient(self._api_key)
5454
return self._client
5555

5656
def translate_text_prompt(self, prompt: TextPrompt, options: ModelOptions) -> MistralAIRequest:
@@ -73,13 +73,6 @@ def translate_response(self, request: MistralAIRequest, response: MistralAIRespo
7373
return SUTResponse(text=str(text))
7474

7575

76-
class MistralAIResponseWithModerations(BaseModel):
77-
"""Mistral's ChatCompletionResponse object + moderation scores."""
78-
79-
response: ChatCompletionResponse # Contains multiple completions.
80-
moderations: dict[int, ClassificationResponse] # Keys correspond to a choice's index field.
81-
82-
8376
def register_suts_for_model(model_name):
8477
MISTRAL_API_KEY = InjectSecret(MistralAIAPIKey)
8578
# Register standard SUT.
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError
2+
from modelgauge.secret_values import InjectSecret, RawSecrets
3+
from modelgauge.sut import SUT
4+
from modelgauge.sut_definition import SUTDefinition
5+
from modelgauge.suts.mistral_client import MistralAIAPIKey, MistralAIClient
6+
from modelgauge.suts.mistral_sut import MistralAISut
7+
8+
9+
class MistralSUTFactory(DynamicSUTFactory):
10+
def __init__(self, raw_secrets: RawSecrets):
11+
super().__init__(raw_secrets)
12+
self._client = None # Lazy load.
13+
14+
@property
15+
def client(self) -> MistralAIClient:
16+
if self._client is None:
17+
api_key = self.injected_secrets()[0]
18+
self._client = MistralAIClient(api_key)
19+
return self._client
20+
21+
def get_secrets(self) -> list[InjectSecret]:
22+
api_key = InjectSecret(MistralAIAPIKey)
23+
return [api_key]
24+
25+
def make_sut(self, sut_definition: SUTDefinition) -> SUT:
26+
model_name = sut_definition.to_dynamic_sut_metadata().external_model_name()
27+
28+
try:
29+
self.client.model_info(model_name)
30+
except Exception as e:
31+
raise ModelNotSupportedError(f"Model {model_name} not found or not available on mistral: {e}")
32+
33+
return MistralAISut(sut_definition.dynamic_uid, model_name, *self.injected_secrets())
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import pytest
2+
from unittest.mock import patch
3+
4+
from modelgauge.dynamic_sut_factory import ModelNotSupportedError
5+
from modelgauge.sut_definition import SUTDefinition
6+
from modelgauge.suts.mistral_sut import MistralAISut
7+
from modelgauge.suts.mistral_sut_factory import MistralSUTFactory
8+
9+
10+
@pytest.fixture
11+
def factory():
12+
return MistralSUTFactory({"mistralai": {"api_key": "value"}})
13+
14+
15+
def test_make_sut(factory):
16+
with patch("modelgauge.suts.mistral_client.MistralAIClient.model_info", return_value="model exists"):
17+
sut_definition = SUTDefinition(model="bar", maker="foo", driver="mistral")
18+
sut = factory.make_sut(sut_definition)
19+
20+
assert isinstance(sut, MistralAISut)
21+
assert sut.uid == "foo/bar:mistral"
22+
assert sut.model_name == "foo/bar"
23+
assert sut._api_key.value == "value"
24+
25+
26+
def test_make_sut_bad_model(factory):
27+
sut_definition = SUTDefinition(model="bogus", maker="fake", driver="mistral")
28+
with patch("modelgauge.suts.mistral_client.MistralAIClient.model_info", side_effect=Exception()):
29+
with pytest.raises(ModelNotSupportedError):
30+
factory.make_sut(sut_definition)

0 commit comments

Comments
 (0)