Skip to content

Commit fdfabd6

Browse files
skamenan7 and iamemilio
authored and committed
fix: AWS Bedrock inference profile ID conversion for region-specific endpoints (llamastack#3386)
Fixes llamastack#3370 AWS switched to requiring region-prefixed inference profile IDs instead of foundation model IDs for on-demand throughput. This was causing ValidationException errors. Added auto-detection based on boto3 client region to convert model IDs like meta.llama3-1-70b-instruct-v1:0 to us.meta.llama3-1-70b-instruct-v1:0 depending on the detected region. Also handles edge cases like ARNs, case insensitive regions, and None regions. Tested with this request. ```json { "model_id": "meta.llama3-1-8b-instruct-v1:0", "messages": [ { "role": "system", "content": "You are a helpful assistant." }, { "role": "user", "content": "tell me a riddle" } ], "sampling_params": { "strategy": { "type": "top_p", "temperature": 0.7, "top_p": 0.9 }, "max_tokens": 512 } } ``` <img width="1488" height="878" alt="image" src="https://github.com/user-attachments/assets/0d61beec-3869-4a31-8f37-9f554c280b88" />
1 parent abd3ce4 commit fdfabd6

File tree

2 files changed

+102
-2
lines changed

2 files changed

+102
-2
lines changed

llama_stack/providers/remote/inference/bedrock/bedrock.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,43 @@
5353

5454
from .models import MODEL_ENTRIES
5555

56+
REGION_PREFIX_MAP = {
57+
"us": "us.",
58+
"eu": "eu.",
59+
"ap": "ap.",
60+
}
61+
62+
63+
def _get_region_prefix(region: str | None) -> str:
64+
# AWS requires region prefixes for inference profiles
65+
if region is None:
66+
return "us." # default to US when we don't know
67+
68+
# Handle case insensitive region matching
69+
region_lower = region.lower()
70+
for prefix in REGION_PREFIX_MAP:
71+
if region_lower.startswith(f"{prefix}-"):
72+
return REGION_PREFIX_MAP[prefix]
73+
74+
# Fallback to US for anything we don't recognize
75+
return "us."
76+
77+
78+
def _to_inference_profile_id(model_id: str, region: str = None) -> str:
79+
# Return ARNs unchanged
80+
if model_id.startswith("arn:"):
81+
return model_id
82+
83+
# Return inference profile IDs that already have regional prefixes
84+
if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
85+
return model_id
86+
87+
# Default to US East when no region is provided
88+
if region is None:
89+
region = "us-east-1"
90+
91+
return _get_region_prefix(region) + model_id
92+
5693

5794
class BedrockInferenceAdapter(
5895
ModelRegistryHelper,
@@ -166,8 +203,13 @@ async def _get_params_for_chat_completion(self, request: ChatCompletionRequest)
166203
options["repetition_penalty"] = sampling_params.repetition_penalty
167204

168205
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
206+
207+
# Convert foundation model ID to inference profile ID
208+
region_name = self.client.meta.region_name
209+
inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
210+
169211
return {
170-
"modelId": bedrock_model,
212+
"modelId": inference_profile_id,
171213
"body": json.dumps(
172214
{
173215
"prompt": prompt,
@@ -185,6 +227,11 @@ async def embeddings(
185227
task_type: EmbeddingTaskType | None = None,
186228
) -> EmbeddingsResponse:
187229
model = await self.model_store.get_model(model_id)
230+
231+
# Convert foundation model ID to inference profile ID
232+
region_name = self.client.meta.region_name
233+
inference_profile_id = _to_inference_profile_id(model.provider_resource_id, region_name)
234+
188235
embeddings = []
189236
for content in contents:
190237
assert not content_has_media(content), "Bedrock does not support media for embeddings"
@@ -193,7 +240,7 @@ async def embeddings(
193240
body = json.dumps(input_body)
194241
response = self.client.invoke_model(
195242
body=body,
196-
modelId=model.provider_resource_id,
243+
modelId=inference_profile_id,
197244
accept="application/json",
198245
contentType="application/json",
199246
)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
from llama_stack.providers.remote.inference.bedrock.bedrock import (
8+
_get_region_prefix,
9+
_to_inference_profile_id,
10+
)
11+
12+
13+
def test_region_prefixes():
    """Known geographies map to their prefix; everything else falls back to US."""
    cases = {
        "us-east-1": "us.",
        "eu-west-1": "eu.",
        "ap-south-1": "ap.",
        # Unrecognized geography falls back to the US prefix.
        "ca-central-1": "us.",
        # Matching is case insensitive.
        "US-EAST-1": "us.",
        "EU-WEST-1": "eu.",
        "Ap-South-1": "ap.",
    }
    for region, expected_prefix in cases.items():
        assert _get_region_prefix(region) == expected_prefix

    # A missing region defaults to the US prefix.
    assert _get_region_prefix(None) == "us."
26+
27+
28+
def test_model_id_conversion():
    """Foundation model IDs gain a regional prefix; ARNs and prefixed IDs pass through."""
    # Basic conversion prepends the region's prefix.
    converted = _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "us-east-1")
    assert converted == "us.meta.llama3-1-70b-instruct-v1:0"

    # An ID that already carries a regional prefix is returned unchanged.
    already_prefixed = _to_inference_profile_id("us.meta.llama3-1-70b-instruct-v1:0", "us-east-1")
    assert already_prefixed == "us.meta.llama3-1-70b-instruct-v1:0"

    # ARNs are returned unchanged, whether or not a region is supplied.
    arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.meta.llama3-1-70b-instruct-v1:0"
    assert _to_inference_profile_id(arn, "us-east-1") == arn
    assert _to_inference_profile_id(arn) == arn

    # Omitting the region entirely defaults to the US prefix.
    defaulted = _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0")
    assert defaulted == "us.meta.llama3-1-70b-instruct-v1:0"

    # Other regions map to their own prefix.
    eu_converted = _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "eu-west-1")
    assert eu_converted == "eu.meta.llama3-1-70b-instruct-v1:0"

0 commit comments

Comments
 (0)