Skip to content

Commit 4f7d18b

Browse files
authored
Aws sut factory (#1499)
* aws sut factory * fix matching * rename nova sut * fix ID parsing, only include active models * tests * Fix query * get region name from env
1 parent 58f6b7a commit 4f7d18b

File tree

5 files changed

+137
-6
lines changed

5 files changed

+137
-6
lines changed

src/modelgauge/sut_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from modelgauge.sut_definition import SUTDefinition
88
from modelgauge.sut_registry import SUTS
99
from modelgauge.suts.anthropic_sut_factory import AnthropicSUTFactory
10+
from modelgauge.suts.aws_bedrock_sut_factory import AWSBedrockSUTFactory
1011
from modelgauge.suts.google_sut_factory import GoogleSUTFactory
1112
from modelgauge.suts.huggingface_sut_factory import HuggingFaceSUTFactory
1213
from modelgauge.suts.indirect_sut import IndirectSUTFactory
@@ -37,6 +38,7 @@ class SUTType(Enum):
3738
# that can be used to create a dynamic sut
3839
DYNAMIC_SUT_FACTORIES: dict = {
3940
"anthropic": AnthropicSUTFactory,
41+
"aws": AWSBedrockSUTFactory,
4042
"google": GoogleSUTFactory,
4143
"hf": HuggingFaceSUTFactory,
4244
"hfrelay": HuggingFaceSUTFactory,

src/modelgauge/suts/aws_bedrock_client.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# as defined here:
22
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html
33

4+
import os
45
from typing import Any, Dict, List, Optional
56

67
import boto3
@@ -106,7 +107,7 @@ class BedrockResponseUsage(BaseModel):
106107

107108

108109
@modelgauge_sut(capabilities=[AcceptsTextPrompt])
109-
class AmazonNovaSut(PromptResponseSUT):
110+
class AmazonBedrockSut(PromptResponseSUT):
110111

111112
def __init__(self, uid: str, model_id: str, access_key_id: AwsAccessKeyId, secret_access_key: AwsSecretAccessKey):
112113
super().__init__(uid)
@@ -118,7 +119,7 @@ def __init__(self, uid: str, model_id: str, access_key_id: AwsAccessKeyId, secre
118119
def _load_client(self):
119120
return boto3.client(
120121
service_name="bedrock-runtime",
121-
region_name="us-east-1",
122+
region_name=os.getenv("AWS_REGION", "us-east-1"),
122123
aws_access_key_id=self.access_key_id,
123124
aws_secret_access_key=self.secret_access_key,
124125
)
@@ -159,7 +160,7 @@ def translate_response(self, request: BedrockRequest, response: BedrockResponse)
159160

160161
for model in BEDROCK_MODELS:
161162
SUTS.register(
162-
AmazonNovaSut,
163+
AmazonBedrockSut,
163164
f"amazon-nova-1.0-{model}",
164165
f"amazon.nova-{model}-v1:0",
165166
InjectSecret(AwsAccessKeyId),
@@ -169,7 +170,7 @@ def translate_response(self, request: BedrockRequest, response: BedrockResponse)
169170
BEDROCK_INFERENCE_PROFILES = ["premier"]
170171
for model in BEDROCK_INFERENCE_PROFILES:
171172
SUTS.register(
172-
AmazonNovaSut,
173+
AmazonBedrockSut,
173174
f"amazon-nova-1.0-{model}",
174175
f"us.amazon.nova-{model}-v1:0",
175176
InjectSecret(AwsAccessKeyId),
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import os
2+
3+
import boto3
4+
5+
from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError
6+
from modelgauge.secret_values import InjectSecret, RawSecrets
7+
from modelgauge.sut import SUT
8+
from modelgauge.sut_definition import SUTDefinition
9+
from modelgauge.suts.aws_bedrock_client import AmazonBedrockSut, AwsAccessKeyId, AwsSecretAccessKey
10+
11+
12+
class AWSBedrockSUTFactory(DynamicSUTFactory):
13+
DRIVER_NAME = "aws"
14+
15+
def __init__(self, raw_secrets: RawSecrets):
16+
super().__init__(raw_secrets)
17+
self._client = None # Lazy load.
18+
19+
@property
20+
def client(self):
21+
if self._client is None:
22+
self._client = boto3.client(
23+
service_name="bedrock",
24+
region_name=os.getenv("AWS_REGION", "us-east-1"),
25+
aws_access_key_id=self.injected_secrets()[0].value,
26+
aws_secret_access_key=self.injected_secrets()[1].value,
27+
)
28+
return self._client
29+
30+
def _convert_model_id(self, model_id: str) -> SUTDefinition:
31+
"""Convert AWS model IDs (maker.model[:version?]) to our standard format."""
32+
maker, model_name = model_id.split(".", maxsplit=1)
33+
model_name = model_name.replace(":", ".")
34+
return SUTDefinition({"maker": maker, "model": model_name, "driver": self.DRIVER_NAME})
35+
36+
def _get_available_models(self, maker: str):
37+
response = self.client.list_foundation_models()
38+
models = {}
39+
for m in response["modelSummaries"]:
40+
if m.get("modelLifecycle", {}).get("status") != "ACTIVE":
41+
continue
42+
models[m["modelId"]] = self._convert_model_id(m["modelId"])
43+
return models
44+
45+
def _get_model_id(self, sut_definition: SUTDefinition):
46+
models = self._get_available_models(sut_definition.to_dynamic_sut_metadata().maker)
47+
for model_id, model_definition in models.items():
48+
if str(model_definition.to_dynamic_sut_metadata()) == str(sut_definition.to_dynamic_sut_metadata()):
49+
return model_id
50+
supported_models = [model_def.to_dynamic_sut_metadata().external_model_name() for model_def in models.values()]
51+
raise ModelNotSupportedError(
52+
f"Model {sut_definition.external_model_name()} not found among AWS Bedrock models. AWS carries the following models from maker {sut_definition.get("maker")}: {supported_models} "
53+
)
54+
55+
def get_secrets(self) -> list[InjectSecret]:
56+
return [InjectSecret(AwsAccessKeyId), InjectSecret(AwsSecretAccessKey)]
57+
58+
def make_sut(self, sut_definition: SUTDefinition) -> SUT:
59+
model_id = self._get_model_id(sut_definition)
60+
return AmazonBedrockSut(sut_definition.dynamic_uid, model_id, *self.injected_secrets())

tests/modelgauge_tests/sut_tests/test_aws_bedrock_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from modelgauge.typed_data import is_typeable
88

99
from modelgauge.suts.aws_bedrock_client import (
10-
AmazonNovaSut,
10+
AmazonBedrockSut,
1111
AwsAccessKeyId,
1212
AwsSecretAccessKey,
1313
BedrockRequest,
@@ -19,7 +19,7 @@
1919

2020
@pytest.fixture
2121
def fake_sut():
22-
return AmazonNovaSut(
22+
return AmazonBedrockSut(
2323
"fake-sut", FAKE_MODEL_ID, AwsAccessKeyId("fake-api-key"), AwsSecretAccessKey("fake-secret-key")
2424
)
2525

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import pytest
2+
from unittest.mock import patch
3+
4+
from modelgauge.dynamic_sut_factory import ModelNotSupportedError
5+
from modelgauge.sut_definition import SUTDefinition
6+
from modelgauge.suts.aws_bedrock_client import AmazonBedrockSut
7+
from modelgauge.suts.aws_bedrock_sut_factory import AWSBedrockSUTFactory
8+
9+
10+
@pytest.fixture
11+
def factory():
12+
return AWSBedrockSUTFactory({"aws": {"access_key_id": "value", "secret_access_key": "value"}})
13+
14+
15+
@pytest.fixture
16+
def mock_list_foundation_models():
17+
models = {
18+
"modelSummaries": [
19+
{"modelId": "amazon.nova-1.0-micro-v1:0", "modelLifecycle": {"status": "ACTIVE"}},
20+
{"modelId": "old_model", "modelLifecycle": {"status": "LEGACY"}},
21+
]
22+
}
23+
24+
with patch("boto3.client") as mock_client:
25+
26+
mock_client.return_value.list_foundation_models.return_value = models
27+
28+
yield mock_client
29+
30+
31+
def test_convert_model_id(factory):
32+
definition = factory._convert_model_id("amazon.nova-v1")
33+
assert definition.get("maker") == "amazon"
34+
assert definition.get("model") == "nova-v1"
35+
assert definition.get("driver") == "aws"
36+
37+
# Sometimes they have colons
38+
definition = factory._convert_model_id("amazon.nova-v1:0")
39+
assert definition.get("maker") == "amazon"
40+
assert definition.get("model") == "nova-v1.0"
41+
assert definition.get("driver") == "aws"
42+
43+
# "." in the model name
44+
definition = factory._convert_model_id("moonshotai.kimi-k2.5")
45+
assert definition.get("maker") == "moonshotai"
46+
assert definition.get("model") == "kimi-k2.5"
47+
assert definition.get("driver") == "aws"
48+
49+
50+
def test_make_sut(factory, mock_list_foundation_models):
51+
sut_definition = SUTDefinition(model="nova-1.0-micro-v1.0", maker="amazon", driver="aws")
52+
sut = factory.make_sut(sut_definition)
53+
54+
assert isinstance(sut, AmazonBedrockSut)
55+
assert sut.uid == "amazon/nova-1.0-micro-v1.0:aws"
56+
assert sut.model_id == "amazon.nova-1.0-micro-v1:0"
57+
58+
59+
def test_make_sut_no_model(factory, mock_list_foundation_models):
60+
sut_definition = SUTDefinition(model="unknown", maker="amazon", driver="aws")
61+
with pytest.raises(ModelNotSupportedError):
62+
factory.make_sut(sut_definition)
63+
64+
65+
def test_make_sut_legacy_model(factory, mock_list_foundation_models):
66+
sut_definition = SUTDefinition(model="old_model", maker="amazon", driver="aws")
67+
with pytest.raises(ModelNotSupportedError):
68+
factory.make_sut(sut_definition)

0 commit comments

Comments
 (0)