Commit 882f68c

Author: Andrew Bernat (committed)

Add tests for function list_bedrock_models.

1 parent 2831150 commit 882f68c
File tree: 4 files changed, +213 -14 lines changed


.github/workflows/aws-genai-cicd-suite.yml

Lines changed: 9 additions & 8 deletions

@@ -25,16 +25,17 @@ jobs:
       - name: Checkout code
         uses: actions/checkout@v3
 
-      - name: Set up Node.js
-        uses: actions/setup-node@v3
+      - name: Set up Python
+        uses: actions/setup-python@v2
         with:
-          node-version: '20'
+          python-version: 3.12 # Adjust the Python version as needed
 
-      - name: Install dependencies @actions/core and @actions/github
-        run: |
-          npm install @actions/core
-          npm install @actions/github
-        shell: bash
+      - name: Install dependencies
+        run: pip install -r requirements.txt
+
+      - name: Test
+        run: python -m unittest
+        working-directory: ./tests
 
       # check if required dependencies @actions/core and @actions/github are installed
       - name: Check if required dependencies are installed
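
Note: the new Test step runs plain `python -m unittest`, which performs test discovery from the step's working directory using the default "test*.py" pattern. A rough local equivalent of that step, as a sketch (not part of the commit):

import unittest

# Run from tests/, mirroring the step's working-directory.
# Plain `python -m unittest` discovers modules matching "test*.py".
suite = unittest.defaultTestLoader.discover(start_dir=".", pattern="test*.py")
unittest.TextTestRunner(verbosity=2).run(suite)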

src/api/models/bedrock.py

Lines changed: 25 additions & 6 deletions

@@ -3,7 +3,7 @@
 import logging
 import re
 import time
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import AsyncIterable, Iterable, Literal
 
 import boto3
@@ -75,8 +75,27 @@ def get_inference_region_prefix():
 
 ENCODER = tiktoken.get_encoding("cl100k_base")
 
+class BedrockClientInterface(ABC):
+    @abstractmethod
+    def list_inference_profiles(self, **kwargs) -> dict:
+        pass
 
-def list_bedrock_models() -> dict:
+    @abstractmethod
+    def list_foundation_models(self, **kwargs) -> dict:
+        pass
+
+class BedrockClient(BedrockClientInterface):
+    def __init__(self, client):
+        self.bedrock_client = client
+
+    def list_inference_profiles(self, **kwargs) -> dict:
+        return self.bedrock_client.list_inference_profiles(**kwargs)
+
+    def list_foundation_models(self, **kwargs) -> dict:
+        return self.bedrock_client.list_foundation_models(**kwargs)
+
+
+def list_bedrock_models(client: BedrockClientInterface) -> dict:
     """Automatically getting a list of supported models.
 
     Returns a model list combines:
@@ -88,14 +107,14 @@ def list_bedrock_models() -> dict:
     profile_list = []
     if ENABLE_CROSS_REGION_INFERENCE:
         # List system defined inference profile IDs
-        response = bedrock_client.list_inference_profiles(
+        response = client.list_inference_profiles(
             maxResults=1000,
             typeEquals='SYSTEM_DEFINED'
         )
         profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']]
 
     # List foundation models, only cares about text outputs here.
-    response = bedrock_client.list_foundation_models(
+    response = client.list_foundation_models(
         byOutputModality='TEXT'
     )
 
@@ -136,15 +155,15 @@ def list_bedrock_models() -> dict:
 
 
 # Initialize the model list.
-bedrock_model_list = list_bedrock_models()
+bedrock_model_list = list_bedrock_models(BedrockClient(bedrock_client))
 
 
 class BedrockModel(BaseChatModel):
 
     def list_models(self) -> list[str]:
         """Always refresh the latest model list"""
         global bedrock_model_list
-        bedrock_model_list = list_bedrock_models()
+        bedrock_model_list = list_bedrock_models(BedrockClient(bedrock_client))
         return list(bedrock_model_list.keys())
 
     def validate(self, chat_request: ChatRequest):
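
Note: the refactor replaces list_bedrock_models's hard dependency on the module-level bedrock_client with an injected BedrockClientInterface, which is what makes the new tests possible. A minimal sketch of the production wiring this implies (the boto3.client call and region below are illustrative, not shown in the diff):

import boto3

from src.api.models.bedrock import BedrockClient, list_bedrock_models

# Wrap the low-level boto3 Bedrock client in the adapter so that
# list_bedrock_models depends only on BedrockClientInterface.
bedrock = boto3.client("bedrock", region_name="us-west-2")  # illustrative region
models = list_bedrock_models(BedrockClient(bedrock))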

tests/__init__.py

Whitespace-only changes.

tests/list_bedrock_models_test.py

Lines changed: 179 additions & 0 deletions

@@ -0,0 +1,179 @@
+from typing import Literal
+
+from src.api.models.bedrock import list_bedrock_models, BedrockClientInterface
+
+def test_default_model():
+    client = FakeBedrockClient(
+        inference_profile("p1-id", "p1", "SYSTEM_DEFINED"),
+        inference_profile("p2-id", "p2", "APPLICATION"),
+        inference_profile("p3-id", "p3", "SYSTEM_DEFINED"),
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "anthropic.claude-3-sonnet-20240229-v1:0": {
+            "modalities": ["TEXT", "IMAGE"]
+        }
+    }
+
+def test_one_model():
+    client = FakeBedrockClient(
+        model("model-id", "model-name", stream_supported=True, input_modalities=["TEXT", "IMAGE"])
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "model-id": {
+            "modalities": ["TEXT", "IMAGE"]
+        }
+    }
+
+def test_two_models():
+    client = FakeBedrockClient(
+        model("model-id-1", "model-name-1", stream_supported=True, input_modalities=["TEXT", "IMAGE"]),
+        model("model-id-2", "model-name-2", stream_supported=True, input_modalities=["IMAGE"])
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "model-id-1": {
+            "modalities": ["TEXT", "IMAGE"]
+        },
+        "model-id-2": {
+            "modalities": ["IMAGE"]
+        }
+    }
+
+def test_filter_models():
+    client = FakeBedrockClient(
+        model("model-id", "model-name-1", stream_supported=True, input_modalities=["TEXT"], status="LEGACY"),
+        model("model-id-no-stream", "model-name-2", stream_supported=False, input_modalities=["TEXT", "IMAGE"]),
+        model("model-id-not-active", "model-name-3", stream_supported=True, status="DISABLED"),
+        model("model-id-not-text-output", "model-name-4", stream_supported=True, output_modalities=["IMAGE"])
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "model-id": {
+            "modalities": ["TEXT"]
+        }
+    }
+
+def test_one_inference_profile():
+    client = FakeBedrockClient(
+        inference_profile("us.model-id", "p1", "SYSTEM_DEFINED"),
+        model("model-id", "model-name", stream_supported=True, input_modalities=["TEXT"])
+    )
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "model-id": {
+            "modalities": ["TEXT"]
+        },
+        "us.model-id": {
+            "modalities": ["TEXT"]
+        }
+    }
+
+def test_default_model_on_throw():
+    client = ThrowingBedrockClient()
+
+    models = list_bedrock_models(client)
+
+    assert models == {
+        "anthropic.claude-3-sonnet-20240229-v1:0": {
+            "modalities": ["TEXT", "IMAGE"]
+        }
+    }
+
+def inference_profile(profile_id: str, name: str, profile_type: Literal["SYSTEM_DEFINED", "APPLICATION"]):
+    return {
+        "inferenceProfileName": name,
+        "inferenceProfileId": profile_id,
+        "type": profile_type
+    }
+
+def model(
+        model_id: str,
+        model_name: str,
+        input_modalities: list[str] = None,
+        output_modalities: list[str] = None,
+        stream_supported: bool = False,
+        inference_types: list[str] = None,
+        status: str = "ACTIVE") -> dict:
+    if input_modalities is None:
+        input_modalities = ["TEXT"]
+    if output_modalities is None:
+        output_modalities = ["TEXT"]
+    if inference_types is None:
+        inference_types = ["ON_DEMAND"]
+    return {
+        "modelArn": "arn:model:" + model_id,
+        "modelId": model_id,
+        "modelName": model_name,
+        "providerName": "anthropic",
+        "inputModalities": input_modalities,
+        "outputModalities": output_modalities,
+        "responseStreamingSupported": stream_supported,
+        "customizationsSupported": ["FINE_TUNING"],
+        "inferenceTypesSupported": inference_types,
+        "modelLifecycle": {
+            "status": status
+        }
+    }
+
+def _filter_inference_profiles(inference_profiles: list[dict], profile_type: Literal["SYSTEM_DEFINED", "APPLICATION"], max_results: int = 100):
+    return [p for p in inference_profiles if p.get("type") == profile_type][:max_results]
+
+def _filter_models(
+        models: list[dict],
+        provider_name: str | None,
+        customization_type: Literal["FINE_TUNING", "CONTINUED_PRE_TRAINING", "DISTILLATION"] | None,
+        output_modality: Literal["TEXT", "IMAGE", "EMBEDDING"] | None,
+        inference_type: Literal["ON_DEMAND", "PROVISIONED"] | None):
+    return [m for m in models if
+            (provider_name is None or m.get("providerName") == provider_name) and
+            (output_modality is None or output_modality in m.get("outputModalities")) and
+            (customization_type is None or customization_type in m.get("customizationsSupported")) and
+            (inference_type is None or inference_type in m.get("inferenceTypesSupported"))
+            ]
+
+class ThrowingBedrockClient(BedrockClientInterface):
+    def list_inference_profiles(self, **kwargs) -> dict:
+        raise Exception("throwing bedrock client always throws exception")
+    def list_foundation_models(self, **kwargs) -> dict:
+        raise Exception("throwing bedrock client always throws exception")
+
+class FakeBedrockClient(BedrockClientInterface):
+    def __init__(self, *args):
+        self.inference_profiles = [p for p in args if p.get("inferenceProfileId", "") != ""]
+        self.models = [m for m in args if m.get("modelId", "") != ""]
+
+        unexpected = [u for u in args if (u.get("modelId", "") == "" and u.get("inferenceProfileId", "") == "")]
+        if len(unexpected) > 0:
+            raise Exception("expected a model or a profile")
+
+    def list_inference_profiles(self, **kwargs) -> dict:
+        return {
+            "inferenceProfileSummaries": _filter_inference_profiles(
+                self.inference_profiles,
+                profile_type=kwargs["typeEquals"],
+                max_results=kwargs.get("maxResults", 100)
+            )
+        }
+
+    def list_foundation_models(self, **kwargs) -> dict:
+        return {
+            "modelSummaries": _filter_models(
+                self.models,
+                provider_name=kwargs.get("byProvider", None),
+                customization_type=kwargs.get("byCustomizationType", None),
+                output_modality=kwargs.get("byOutputModality", None),
+                inference_type=kwargs.get("byInferenceType", None)
+            )
+        }
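
Note: FakeBedrockClient dispatches on dict shape, treating any argument carrying an inferenceProfileId as a profile and any carrying a modelId as a model. A hypothetical standalone check of the helpers above (not part of the commit):

# Hypothetical sanity check: the fake routes each dict by the key it carries.
client = FakeBedrockClient(
    inference_profile("us.my-model", "profile-name", "SYSTEM_DEFINED"),
    model("my-model", "my-model-name", stream_supported=True),
)

profiles = client.list_inference_profiles(typeEquals="SYSTEM_DEFINED")
assert profiles["inferenceProfileSummaries"][0]["inferenceProfileId"] == "us.my-model"

found = client.list_foundation_models(byOutputModality="TEXT")
assert found["modelSummaries"][0]["modelId"] == "my-model"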
