Commit d86e64e

refactor(bedrock): unify inference profile metadata handling and cleanup
- Add unified profile_metadata dictionary for both SYSTEM_DEFINED and APPLICATION inference profiles
- Remove unused region prefix functions and defaultdict import
- Add TEMPERATURE_TOPP_CONFLICT_MODELS set for Claude model parameter conflicts
- Improve model ARN parsing and error handling in profile enumeration
- Consolidate profile metadata storage to enable consistent feature detection
1 parent b4800c5 commit d86e64e

File tree

2 files changed: +138 -59 lines changed

src/api/models/bedrock.py

Lines changed: 136 additions & 57 deletions
@@ -4,7 +4,6 @@
 import re
 import time
 from abc import ABC
-from collections import defaultdict
 from typing import AsyncIterable, Iterable, Literal

 import boto3
@@ -74,16 +73,6 @@
     config=config,
 )

-
-def get_inference_region_prefix():
-    if AWS_REGION.startswith("ap-"):
-        return "apac"
-    return AWS_REGION[:2]
-
-
-# https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
-cr_inference_prefix = get_inference_region_prefix()
-
 SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
     "cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
     "cohere.embed-english-v3": "Cohere Embed English",
@@ -95,6 +84,18 @@ def get_inference_region_prefix():

 ENCODER = tiktoken.get_encoding("cl100k_base")

+# Global mapping: Profile ID/ARN → Foundation Model ID
+# Handles both SYSTEM_DEFINED (cross-region) and APPLICATION profiles
+# This enables feature detection for all profile types without pattern matching
+profile_metadata = {}
+
+# Models that don't support both temperature and topP simultaneously
+# When both are provided, temperature takes precedence and topP is removed
+TEMPERATURE_TOPP_CONFLICT_MODELS = {
+    "claude-sonnet-4-5",
+    "claude-haiku-4-5",
+}
+

 def list_bedrock_models() -> dict:
     """Automatically getting a list of supported models.
@@ -106,15 +107,26 @@ def list_bedrock_models() -> dict:
     """
     model_list = {}
     try:
-        profile_list = []
-        # Map foundation model_id -> set of application inference profile ARNs
-        app_profiles_by_model = defaultdict(set)
-
         if ENABLE_CROSS_REGION_INFERENCE:
-            # List system defined inference profile IDs
+            # List system defined inference profile IDs and store underlying model mapping
             paginator = bedrock_client.get_paginator('list_inference_profiles')
             for page in paginator.paginate(maxResults=1000, typeEquals="SYSTEM_DEFINED"):
-                profile_list.extend([p["inferenceProfileId"] for p in page["inferenceProfileSummaries"]])
+                for profile in page["inferenceProfileSummaries"]:
+                    profile_id = profile.get("inferenceProfileId")
+                    if not profile_id:
+                        continue
+
+                    # Extract underlying model from first model in the profile
+                    models = profile.get("models", [])
+                    if models:
+                        model_arn = models[0].get("modelArn", "")
+                        if model_arn:
+                            # Extract foundation model ID from ARN
+                            model_id = model_arn.split('/')[-1]
+                            profile_metadata[profile_id] = {
+                                "underlying_model_id": model_id,
+                                "profile_type": "SYSTEM_DEFINED",
+                            }

         if ENABLE_APPLICATION_INFERENCE_PROFILES:
             # List application defined inference profile IDs and create mapping
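
The model ID extraction above is a plain string split on the ARN. As a minimal standalone sketch, using a made-up modelArn of the shape Bedrock returns for foundation models:

# Hypothetical modelArn from an inference profile summary (illustrative only).
model_arn = "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-sonnet-20240620-v1:0"

# The foundation model ID is everything after the last '/'.
model_id = model_arn.split('/')[-1]
assert model_id == "anthropic.claude-3-5-sonnet-20240620-v1:0"
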
@@ -125,15 +137,28 @@ def list_bedrock_models() -> dict:
                         profile_arn = profile.get("inferenceProfileArn")
                         if not profile_arn:
                             continue
-
+
                         # Process all models in the profile
                         models = profile.get("models", [])
-                        for model in models:
-                            model_arn = model.get("modelArn", "")
-                            if model_arn:
-                                model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
-                                if model_id:
-                                    app_profiles_by_model[model_id].add(profile_arn)
+                        if not models:
+                            logger.warning(f"Application profile {profile_arn} has no models")
+                            continue
+
+                        # Take first model - all models in array are same type (regional instances)
+                        first_model = models[0]
+                        model_arn = first_model.get("modelArn", "")
+                        if not model_arn:
+                            continue
+
+                        # Extract model ID from ARN (works for both foundation models and cross-region profiles)
+                        model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
+
+                        # Store in unified profile metadata for feature detection
+                        profile_metadata[profile_arn] = {
+                            "underlying_model_id": model_id,
+                            "profile_type": "APPLICATION",
+                            "profile_name": profile.get("inferenceProfileName", ""),
+                        }
                     except Exception as e:
                         logger.warning(f"Error processing application profile: {e}")
                         continue
@@ -156,20 +181,10 @@ def list_bedrock_models() -> dict:
             if "ON_DEMAND" in inference_types:
                 model_list[model_id] = {"modalities": input_modalities}

-                # Add cross-region inference model list.
-                profile_id = cr_inference_prefix + "." + model_id
-                if profile_id in profile_list:
-                    model_list[profile_id] = {"modalities": input_modalities}
-
-                # Add global cross-region inference profiles
-                global_profile_id = "global." + model_id
-                if global_profile_id in profile_list:
-                    model_list[global_profile_id] = {"modalities": input_modalities}
-
-                # Add application inference profiles (emit all profiles for this model)
-                if model_id in app_profiles_by_model:
-                    for profile_arn in app_profiles_by_model[model_id]:
-                        model_list[profile_arn] = {"modalities": input_modalities}
+                # Add all inference profiles (cross-region and application) for this model
+                for profile_id, metadata in profile_metadata.items():
+                    if metadata.get("underlying_model_id") == model_id:
+                        model_list[profile_id] = {"modalities": input_modalities}

     except Exception as e:
         logger.error(f"Unable to list models: {str(e)}")
@@ -197,17 +212,56 @@ def validate(self, chat_request: ChatRequest):
         error = ""
         # check if model is supported
         if chat_request.model not in bedrock_model_list.keys():
-            error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
+            # Provide helpful error for application profiles
+            if "application-inference-profile" in chat_request.model:
+                error = (
+                    f"Application profile {chat_request.model} not found. "
+                    f"Available profiles can be listed via GET /models API. "
+                    f"Ensure ENABLE_APPLICATION_INFERENCE_PROFILES=true and "
+                    f"the profile exists in your AWS account."
+                )
+            else:
+                error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
             logger.error("Unsupported model: %s", chat_request.model)

+        # Validate profile has resolvable underlying model
+        if not error and chat_request.model in profile_metadata:
+            resolved = self._resolve_to_foundation_model(chat_request.model)
+            if resolved == chat_request.model:
+                logger.warning(
+                    f"Could not resolve profile {chat_request.model} "
+                    f"to underlying model. Some features may not work correctly."
+                )
+
         if error:
             raise HTTPException(
                 status_code=400,
                 detail=error,
             )

-    @staticmethod
-    def _supports_prompt_caching(model_id: str) -> bool:
+    def _resolve_to_foundation_model(self, model_id: str) -> str:
+        """
+        Resolve any model identifier to foundation model ID for feature detection.
+
+        Handles:
+        - Cross-region profiles (us.*, eu.*, apac.*, global.*)
+        - Application profiles (arn:aws:bedrock:...:application-inference-profile/...)
+        - Foundation models (pass through unchanged)
+
+        No pattern matching needed - just dictionary lookup.
+        Unknown identifiers pass through unchanged (graceful fallback).
+
+        Args:
+            model_id: Can be foundation model ID, cross-region profile, or app profile ARN
+
+        Returns:
+            Foundation model ID if mapping exists, otherwise original model_id
+        """
+        if model_id in profile_metadata:
+            return profile_metadata[model_id]["underlying_model_id"]
+        return model_id
+
+    def _supports_prompt_caching(self, model_id: str) -> bool:
         """
         Check if model supports prompt caching based on model ID pattern.

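
The resolver itself is a single dictionary lookup with a pass-through fallback. A minimal usage sketch, using the resolve() helper and hypothetical identifiers from the earlier sketch in place of the _resolve_to_foundation_model method:

# Cross-region profile ID -> foundation model ID
resolve("us.anthropic.claude-3-5-sonnet-20240620-v1:0")
# -> "anthropic.claude-3-5-sonnet-20240620-v1:0"

# Application profile ARN -> foundation model ID
resolve("arn:aws:bedrock:us-east-1:111122223333:application-inference-profile/abc123")
# -> "anthropic.claude-3-5-sonnet-20240620-v1:0"

# Unknown identifier passes through unchanged (graceful fallback)
resolve("anthropic.claude-3-7-sonnet-20250219-v1:0")
# -> "anthropic.claude-3-7-sonnet-20250219-v1:0"
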
@@ -221,27 +275,28 @@ def _supports_prompt_caching(model_id: str) -> bool:
         Returns:
             bool: True if model supports prompt caching
         """
-        model_lower = model_id.lower()
+        # Resolve profile to underlying model for feature detection
+        resolved_model = self._resolve_to_foundation_model(model_id)
+        model_lower = resolved_model.lower()

         # Claude models pattern matching
-        if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
+        if "anthropic.claude" in model_lower:
             # Exclude very old models that don't support caching
             excluded_patterns = ["claude-instant", "claude-v1", "claude-v2"]
             if any(pattern in model_lower for pattern in excluded_patterns):
                 return False
             return True

         # Nova models pattern matching
-        if "amazon.nova" in model_lower or ".amazon.nova" in model_lower:
+        if "amazon.nova" in model_lower:
             return True

         # Future providers can be added here
         # Example: if "provider.model-name" in model_lower: return True

         return False

-    @staticmethod
-    def _get_max_cache_tokens(model_id: str) -> int | None:
+    def _get_max_cache_tokens(self, model_id: str) -> int | None:
         """
         Get maximum cacheable tokens limit for the model.

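
Resolution is what makes the pattern check meaningful for profiles: an application profile ARN contains no "anthropic.claude" substring, so without the lookup the caching check would report False even when the underlying model supports it. A sketch, reusing the hypothetical ARN and resolve() helper from above:

arn = "arn:aws:bedrock:us-east-1:111122223333:application-inference-profile/abc123"

# Raw identifier: the provider never appears in the ARN, so the substring test fails.
"anthropic.claude" in arn.lower()           # False

# Resolved identifier: pattern matching runs against the foundation model ID.
"anthropic.claude" in resolve(arn).lower()  # True
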
@@ -252,14 +307,16 @@ def _get_max_cache_tokens(model_id: str) -> int | None:
         Returns:
             int | None: Max tokens, or None if unlimited
         """
-        model_lower = model_id.lower()
+        # Resolve profile to underlying model for feature detection
+        resolved_model = self._resolve_to_foundation_model(model_id)
+        model_lower = resolved_model.lower()

         # Nova models have 20K limit
-        if "amazon.nova" in model_lower or ".amazon.nova" in model_lower:
+        if "amazon.nova" in model_lower:
             return 20_000

         # Claude: No explicit limit
-        if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
+        if "anthropic.claude" in model_lower:
             return None

         return None
@@ -269,6 +326,14 @@ async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
         if DEBUG:
             logger.info("Raw request: " + chat_request.model_dump_json())

+        # Log profile resolution for debugging
+        if chat_request.model in profile_metadata:
+            resolved = self._resolve_to_foundation_model(chat_request.model)
+            profile_type = profile_metadata[chat_request.model].get("profile_type", "UNKNOWN")
+            logger.info(
+                f"Profile resolution: {chat_request.model} ({profile_type}) → {resolved}"
+            )
+
         # convert OpenAI chat request to Bedrock SDK request
         args = self._parse_request(chat_request)
         if DEBUG:
@@ -667,15 +732,27 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:

         # Base inference parameters.
         inference_config = {
-            "temperature": chat_request.temperature,
             "maxTokens": chat_request.max_tokens,
-            "topP": chat_request.top_p,
         }

-        # Claude Sonnet 4.5 doesn't support both temperature and topP
-        # Remove topP for this model
-        if "claude-sonnet-4-5" in chat_request.model.lower():
-            inference_config.pop("topP", None)
+        # Only include optional parameters when specified
+        if chat_request.temperature is not None:
+            inference_config["temperature"] = chat_request.temperature
+        if chat_request.top_p is not None:
+            inference_config["topP"] = chat_request.top_p
+
+        # Some models (Claude Sonnet 4.5, Haiku 4.5) don't support both temperature and topP
+        # When both are provided, keep temperature and remove topP
+        # Resolve profile to underlying model for feature detection
+        resolved_model = self._resolve_to_foundation_model(chat_request.model)
+        model_lower = resolved_model.lower()
+
+        # Check if model is in the conflict list and both parameters are present
+        if "temperature" in inference_config and "topP" in inference_config:
+            if any(conflict_model in model_lower for conflict_model in TEMPERATURE_TOPP_CONFLICT_MODELS):
+                inference_config.pop("topP", None)
+                if DEBUG:
+                    logger.info(f"Removed topP for {chat_request.model} (conflicts with temperature)")

         if chat_request.stop is not None:
             stop = chat_request.stop
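
A condensed sketch of the new parameter handling as a standalone function, reusing the resolve() helper from the earlier sketch and the TEMPERATURE_TOPP_CONFLICT_MODELS set added in this commit; the model IDs and values in the example calls are hypothetical:

def build_inference_config(model_id, max_tokens, temperature=None, top_p=None):
    # Mirrors the logic above: optional parameters only when supplied,
    # and topP dropped when it would conflict with temperature.
    config = {"maxTokens": max_tokens}
    if temperature is not None:
        config["temperature"] = temperature
    if top_p is not None:
        config["topP"] = top_p

    model_lower = resolve(model_id).lower()
    if "temperature" in config and "topP" in config:
        if any(m in model_lower for m in TEMPERATURE_TOPP_CONFLICT_MODELS):
            config.pop("topP")  # temperature takes precedence on conflicting models
    return config

build_inference_config("anthropic.claude-haiku-4-5-20251001-v1:0", 2048, 0.7, 0.9)
# -> {"maxTokens": 2048, "temperature": 0.7}          (topP removed)
build_inference_config("amazon.nova-pro-v1:0", 2048, 0.7, 0.9)
# -> {"maxTokens": 2048, "temperature": 0.7, "topP": 0.9}
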
@@ -692,9 +769,11 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
         if chat_request.reasoning_effort:
             # reasoning_effort is supported by Claude and DeepSeek v3
             # Different models use different formats
-            model_lower = chat_request.model.lower()
+            # Resolve profile to underlying model for feature detection
+            resolved_model = self._resolve_to_foundation_model(chat_request.model)
+            model_lower = resolved_model.lower()

-            if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
+            if "anthropic.claude" in model_lower:
                 # Claude format: reasoning_config = object with budget_tokens
                 max_tokens = (
                     chat_request.max_completion_tokens

src/api/schema.py

Lines changed: 2 additions & 2 deletions
@@ -97,8 +97,8 @@ class ChatRequest(BaseModel):
     presence_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used
     stream: bool | None = False
     stream_options: StreamOptions | None = None
-    temperature: float | None = Field(default=1.0, le=2.0, ge=0.0)
-    top_p: float | None = Field(default=1.0, le=1.0, ge=0.0)
+    temperature: float | None = Field(default=None, le=2.0, ge=0.0)
+    top_p: float | None = Field(default=None, le=1.0, ge=0.0)
     user: str | None = None # Not used
     max_tokens: int | None = 2048
     max_completion_tokens: int | None = None
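
With the defaults switched to None, a request that omits temperature and top_p no longer injects 1.0 for either, so _parse_request only forwards the parameters the caller actually set. A minimal sketch of the behavioural difference, using a stand-in model with just the two changed fields:

from pydantic import BaseModel, Field

class SamplingParams(BaseModel):
    # Stand-in for the two changed ChatRequest fields (names reused for illustration).
    temperature: float | None = Field(default=None, le=2.0, ge=0.0)
    top_p: float | None = Field(default=None, le=1.0, ge=0.0)

params = SamplingParams()
assert params.temperature is None and params.top_p is None
# Previously both defaulted to 1.0, so temperature/topP were always sent to Bedrock;
# now they are included in inference_config only when the client supplies them.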
