Commit c2d551a

Support imported models

Author: Sean Smith
Signed-off-by: Sean Smith <[email protected]>
Parent: c3a30f9

4 files changed, +30 -1 lines changed

deployment/BedrockProxy.template renamed to deployment/BedrockProxy.yaml

Lines changed: 8 additions & 0 deletions
@@ -5,6 +5,13 @@ Parameters:
     Type: String
     Default: ""
     Description: The parameter name in System Manager used to store the API Key, leave blank to use a default key
+  EnableImportedModels:
+    Type: String
+    Default: false
+    AllowedValues:
+      - true
+      - false
+    Description: If enabled, models imported into Bedrock will be available to use.
 Resources:
   VPCB9E5F0B4:
     Type: AWS::EC2::VPC
@@ -197,6 +204,7 @@ Resources:
           - DefaultValue: anthropic.claude-3-sonnet-20240229-v1:0
           DEFAULT_EMBEDDING_MODEL: cohere.embed-multilingual-v3
           ENABLE_CROSS_REGION_INFERENCE: "true"
+          ENABLE_IMPORTED_MODELS: !Ref EnableImportedModels
       MemorySize: 1024
       PackageType: Image
       Role:

deployment/BedrockProxyFargate.template renamed to deployment/BedrockProxyFargate.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@ Parameters:
55
Type: String
66
Default: ""
77
Description: The parameter name in System Manager used to store the API Key, leave blank to use a default key
8+
EnableImportedModels:
9+
Type: String
10+
Default: false
11+
AllowedValues:
12+
- true
13+
- false
14+
Description: If enabled, models imported into Bedrock will be available to use.
815
Resources:
916
VPCB9E5F0B4:
1017
Type: AWS::EC2::VPC
@@ -237,6 +244,8 @@ Resources:
               Value: cohere.embed-multilingual-v3
             - Name: ENABLE_CROSS_REGION_INFERENCE
               Value: "true"
+            - Name: ENABLE_IMPORTED_MODELS
+              Value: !Ref EnableImportedModels
           Essential: true
           Image:
             Fn::Join:
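
Both templates expose the same `EnableImportedModels` parameter and forward it to the container or function as `ENABLE_IMPORTED_MODELS`. As a rough sketch of how the override could be built when deploying a stack, the helper below produces a parameter list in the `ParameterKey`/`ParameterValue` shape that boto3's CloudFormation `create_stack` and `update_stack` accept. The helper name `imported_models_parameter` is hypothetical and not part of this commit.

```python
# Mirrors the AllowedValues declared for EnableImportedModels in the templates.
ALLOWED_VALUES = ("true", "false")


def imported_models_parameter(enabled: bool) -> list[dict]:
    """Build the stack parameter override for EnableImportedModels (hypothetical helper)."""
    value = "true" if enabled else "false"
    if value not in ALLOWED_VALUES:
        raise ValueError(f"unsupported value: {value}")
    return [{"ParameterKey": "EnableImportedModels", "ParameterValue": value}]
```

The returned list would be passed as `Parameters=` alongside the template body; leaving it out keeps the template default of `false`.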

src/api/models/bedrock.py

Lines changed: 12 additions & 1 deletion
@@ -38,7 +38,7 @@
     Embedding,

 )
-from api.setting import DEBUG, AWS_REGION, ENABLE_CROSS_REGION_INFERENCE, DEFAULT_MODEL
+from api.setting import DEBUG, AWS_REGION, ENABLE_CROSS_REGION_INFERENCE, DEFAULT_MODEL, ENABLE_IMPORTED_MODELS

 logger = logging.getLogger(__name__)

@@ -99,6 +99,17 @@ def list_bedrock_models() -> dict:
             byOutputModality='TEXT'
         )

+        # Add imported models to the list if ENABLE_IMPORTED_MODELS is true
+        if ENABLE_IMPORTED_MODELS:
+            response_imported = bedrock_client.list_imported_models()
+
+            # Add imported models to the default model list
+            for model in response_imported['modelSummaries']:
+                model_id = model.get('modelArn')
+                model_list[model_id] = {
+                    'modalities': ["TEXT"]
+                }
+
         for model in response['modelSummaries']:
             model_id = model.get('modelId', 'N/A')
             stream_supported = model.get('responseStreamingSupported', True)
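
The added block calls `list_imported_models()` once and keys each entry by its `modelArn`. A possible variant, shown only as a sketch, follows the `nextToken` returned by the Bedrock API so that accounts with many imported models are listed in full, and skips summaries that lack an ARN. The function name `collect_imported_models` is hypothetical, and the client is passed in as an argument (an assumption made here so the logic can be exercised without AWS access).

```python
def collect_imported_models(bedrock_client) -> dict:
    """Return {modelArn: {"modalities": ["TEXT"]}} for every imported model.

    Hypothetical helper: pages through list_imported_models using nextToken.
    """
    model_list = {}
    kwargs = {}
    while True:
        response = bedrock_client.list_imported_models(**kwargs)
        for model in response.get("modelSummaries", []):
            model_arn = model.get("modelArn")
            if model_arn:  # skip entries without an ARN instead of keying on None
                model_list[model_arn] = {"modalities": ["TEXT"]}
        next_token = response.get("nextToken")
        if not next_token:
            return model_list
        # Request the next page on the following iteration.
        kwargs = {"nextToken": next_token}
```

With a real client this would be called as `collect_imported_models(boto3.client("bedrock"))`, and the result merged into `model_list` before the foundation models are processed.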

src/api/setting.py

Lines changed: 1 addition & 0 deletions
@@ -20,3 +20,4 @@
     "DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3"
 )
 ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
+ENABLE_IMPORTED_MODELS = os.environ.get("ENABLE_IMPORTED_MODELS", "true").lower() != "false"
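
The new setting treats every value except `"false"` (case-insensitive) as enabled, and it defaults to `"true"` when the variable is unset. The templates default the parameter to `false`, so the two defaults differ: a deployed stack disables imported models unless overridden, while running the code without the variable enables them. The sketch below isolates that parsing rule; `env_flag` is a hypothetical helper, not a function in the repository.

```python
import os


def env_flag(name: str, default: str = "true", environ=os.environ) -> bool:
    """Parse a flag the way setting.py does: anything except "false" counts as True."""
    return environ.get(name, default).lower() != "false"


# Unset variable: falls back to "true", so the feature is on.
unset = env_flag("ENABLE_IMPORTED_MODELS", environ={})
# Explicit "False" in any letter case turns it off.
disabled = env_flag("ENABLE_IMPORTED_MODELS", environ={"ENABLE_IMPORTED_MODELS": "False"})
# Values such as "0" or "no" are not recognized as false and leave it on.
zero = env_flag("ENABLE_IMPORTED_MODELS", environ={"ENABLE_IMPORTED_MODELS": "0"})
```

Because only the literal string `false` disables the flag, restricting the template parameter to `true` and `false` through `AllowedValues` keeps deployed values within what this parsing handles.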
