from typing import Literal

from src.api.models.bedrock import list_bedrock_models, BedrockClientInterface

def test_default_model():
    # Only inference profiles, no foundation models: expect only the default model.
    client = FakeBedrockClient(
        inference_profile("p1-id", "p1", "SYSTEM_DEFINED"),
        inference_profile("p2-id", "p2", "APPLICATION"),
        inference_profile("p3-id", "p3", "SYSTEM_DEFINED"),
    )

    models = list_bedrock_models(client)

    assert models == {
        "anthropic.claude-3-sonnet-20240229-v1:0": {
            "modalities": ["TEXT", "IMAGE"]
        }
    }

def test_one_model():
    client = FakeBedrockClient(
        model("model-id", "model-name", stream_supported=True, input_modalities=["TEXT", "IMAGE"])
    )

    models = list_bedrock_models(client)

    assert models == {
        "model-id": {
            "modalities": ["TEXT", "IMAGE"]
        }
    }

def test_two_models():
    client = FakeBedrockClient(
        model("model-id-1", "model-name-1", stream_supported=True, input_modalities=["TEXT", "IMAGE"]),
        model("model-id-2", "model-name-2", stream_supported=True, input_modalities=["IMAGE"])
    )

    models = list_bedrock_models(client)

    assert models == {
        "model-id-1": {
            "modalities": ["TEXT", "IMAGE"]
        },
        "model-id-2": {
            "modalities": ["IMAGE"]
        }
    }

def test_filter_models():
    client = FakeBedrockClient(
        model("model-id", "model-name-1", stream_supported=True, input_modalities=["TEXT"], status="LEGACY"),
        model("model-id-no-stream", "model-name-2", stream_supported=False, input_modalities=["TEXT", "IMAGE"]),
        model("model-id-not-active", "model-name-3", stream_supported=True, status="DISABLED"),
        model("model-id-not-text-output", "model-name-4", stream_supported=True, output_modalities=["IMAGE"])
    )

    models = list_bedrock_models(client)

    # Only the streaming, non-disabled, text-output model survives the filters.
    assert models == {
        "model-id": {
            "modalities": ["TEXT"]
        }
    }

def test_one_inference_profile():
    client = FakeBedrockClient(
        inference_profile("us.model-id", "p1", "SYSTEM_DEFINED"),
        model("model-id", "model-name", stream_supported=True, input_modalities=["TEXT"])
    )

    models = list_bedrock_models(client)

    assert models == {
        "model-id": {
            "modalities": ["TEXT"]
        },
        "us.model-id": {
            "modalities": ["TEXT"]
        }
    }

def test_default_model_on_throw():
    # Errors from the Bedrock API must not propagate; the default model is returned instead.
    client = ThrowingBedrockClient()

    models = list_bedrock_models(client)

    assert models == {
        "anthropic.claude-3-sonnet-20240229-v1:0": {
            "modalities": ["TEXT", "IMAGE"]
        }
    }

def inference_profile(profile_id: str, name: str, profile_type: Literal["SYSTEM_DEFINED", "APPLICATION"]) -> dict:
    # Minimal inference profile summary, shaped like a ListInferenceProfiles response entry.
    return {
        "inferenceProfileName": name,
        "inferenceProfileId": profile_id,
        "type": profile_type
    }

def model(
        model_id: str,
        model_name: str,
        input_modalities: list[str] | None = None,
        output_modalities: list[str] | None = None,
        stream_supported: bool = False,
        inference_types: list[str] | None = None,
        status: str = "ACTIVE") -> dict:
    # Minimal model summary, shaped like a ListFoundationModels response entry.
    if input_modalities is None:
        input_modalities = ["TEXT"]
    if output_modalities is None:
        output_modalities = ["TEXT"]
    if inference_types is None:
        inference_types = ["ON_DEMAND"]
    return {
        "modelArn": "arn:model:" + model_id,
        "modelId": model_id,
        "modelName": model_name,
        "providerName": "anthropic",
        "inputModalities": input_modalities,
        "outputModalities": output_modalities,
        "responseStreamingSupported": stream_supported,
        "customizationsSupported": ["FINE_TUNING"],
        "inferenceTypesSupported": inference_types,
        "modelLifecycle": {
            "status": status
        }
    }

def _filter_inference_profiles(inference_profiles: list[dict], profile_type: Literal["SYSTEM_DEFINED", "APPLICATION"], max_results: int = 100):
    # Mimic the typeEquals and maxResults filters of the ListInferenceProfiles API.
    return [p for p in inference_profiles if p.get("type") == profile_type][:max_results]

def _filter_models(
        models: list[dict],
        provider_name: str | None,
        customization_type: Literal["FINE_TUNING", "CONTINUED_PRE_TRAINING", "DISTILLATION"] | None,
        output_modality: Literal["TEXT", "IMAGE", "EMBEDDING"] | None,
        inference_type: Literal["ON_DEMAND", "PROVISIONED"] | None):
    # Mimic the server-side filters of the ListFoundationModels API.
    return [m for m in models if
            (provider_name is None or m.get("providerName") == provider_name) and
            (output_modality is None or output_modality in m.get("outputModalities", [])) and
            (customization_type is None or customization_type in m.get("customizationsSupported", [])) and
            (inference_type is None or inference_type in m.get("inferenceTypesSupported", []))
            ]

class ThrowingBedrockClient(BedrockClientInterface):
    # Stub client whose methods always raise, to exercise the error fallback path.
    def list_inference_profiles(self, **kwargs) -> dict:
        raise Exception("throwing bedrock client always throws exception")

    def list_foundation_models(self, **kwargs) -> dict:
        raise Exception("throwing bedrock client always throws exception")

class FakeBedrockClient(BedrockClientInterface):
    # In-memory fake that serves whatever profiles and models are passed to the constructor.
    def __init__(self, *args):
        self.inference_profiles = [p for p in args if p.get("inferenceProfileId", "") != ""]
        self.models = [m for m in args if m.get("modelId", "") != ""]

        unexpected = [u for u in args if u.get("modelId", "") == "" and u.get("inferenceProfileId", "") == ""]
        if len(unexpected) > 0:
            raise ValueError("expected a model or a profile")

    def list_inference_profiles(self, **kwargs) -> dict:
        return {
            "inferenceProfileSummaries": _filter_inference_profiles(
                self.inference_profiles,
                profile_type=kwargs["typeEquals"],
                max_results=kwargs.get("maxResults", 100)
            )
        }

    def list_foundation_models(self, **kwargs) -> dict:
        return {
            "modelSummaries": _filter_models(
                self.models,
                provider_name=kwargs.get("byProvider", None),
                customization_type=kwargs.get("byCustomizationType", None),
                output_modality=kwargs.get("byOutputModality", None),
                inference_type=kwargs.get("byInferenceType", None)
            )
        }
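
# For reference only: a production implementation of BedrockClientInterface
# would typically wrap boto3's Bedrock control-plane client, which exposes
# list_inference_profiles and list_foundation_models with the same kwargs the
# fake above accepts. A minimal sketch, assuming boto3 is installed and AWS
# credentials are configured (Boto3BedrockClient is a hypothetical name, not
# part of this codebase):
#
#   import boto3
#
#   class Boto3BedrockClient(BedrockClientInterface):
#       def __init__(self, region_name: str = "us-east-1"):
#           self._client = boto3.client("bedrock", region_name=region_name)
#
#       def list_inference_profiles(self, **kwargs) -> dict:
#           return self._client.list_inference_profiles(**kwargs)
#
#       def list_foundation_models(self, **kwargs) -> dict:
#           return self._client.list_foundation_models(**kwargs)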