 import re
 import time
 from abc import ABC
-from collections import defaultdict
 from typing import AsyncIterable, Iterable, Literal

 import boto3
     config=config,
 )

-
-def get_inference_region_prefix():
-    if AWS_REGION.startswith("ap-"):
-        return "apac"
-    return AWS_REGION[:2]
-
-
-# https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
-cr_inference_prefix = get_inference_region_prefix()
-
 SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
     "cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
     "cohere.embed-english-v3": "Cohere Embed English",
@@ -95,6 +84,18 @@ def get_inference_region_prefix():

 ENCODER = tiktoken.get_encoding("cl100k_base")

+# Global mapping: Profile ID/ARN → Foundation Model ID
+# Handles both SYSTEM_DEFINED (cross-region) and APPLICATION profiles
+# This enables feature detection for all profile types without pattern matching
+profile_metadata = {}
+
+# Models that don't support both temperature and topP simultaneously
+# When both are provided, temperature takes precedence and topP is removed
+TEMPERATURE_TOPP_CONFLICT_MODELS = {
+    "claude-sonnet-4-5",
+    "claude-haiku-4-5",
+}
+

 def list_bedrock_models() -> dict:
     """Automatically getting a list of supported models.
@@ -106,15 +107,26 @@ def list_bedrock_models() -> dict:
106107 """
107108 model_list = {}
108109 try :
109- profile_list = []
110- # Map foundation model_id -> set of application inference profile ARNs
111- app_profiles_by_model = defaultdict (set )
112-
113110 if ENABLE_CROSS_REGION_INFERENCE :
114- # List system defined inference profile IDs
111+ # List system defined inference profile IDs and store underlying model mapping
115112 paginator = bedrock_client .get_paginator ('list_inference_profiles' )
116113 for page in paginator .paginate (maxResults = 1000 , typeEquals = "SYSTEM_DEFINED" ):
117- profile_list .extend ([p ["inferenceProfileId" ] for p in page ["inferenceProfileSummaries" ]])
114+ for profile in page ["inferenceProfileSummaries" ]:
115+ profile_id = profile .get ("inferenceProfileId" )
116+ if not profile_id :
117+ continue
118+
119+ # Extract underlying model from first model in the profile
120+ models = profile .get ("models" , [])
121+ if models :
122+ model_arn = models [0 ].get ("modelArn" , "" )
123+ if model_arn :
124+ # Extract foundation model ID from ARN
125+ model_id = model_arn .split ('/' )[- 1 ]
126+ profile_metadata [profile_id ] = {
127+ "underlying_model_id" : model_id ,
128+ "profile_type" : "SYSTEM_DEFINED" ,
129+ }
118130
119131 if ENABLE_APPLICATION_INFERENCE_PROFILES :
120132 # List application defined inference profile IDs and create mapping
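For each SYSTEM_DEFINED profile, the loop above keeps only the part of the first model's `modelArn` after the last `/`. A quick sketch of that extraction with a made-up ARN:

# Illustration only — the split('/')[-1] extraction used above.
model_arn = "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-sonnet-4-5-20250929-v1:0"
model_id = model_arn.split('/')[-1]
assert model_id == "anthropic.claude-sonnet-4-5-20250929-v1:0"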
@@ -125,15 +137,28 @@ def list_bedrock_models() -> dict:
                         profile_arn = profile.get("inferenceProfileArn")
                         if not profile_arn:
                             continue
-
+
                         # Process all models in the profile
                         models = profile.get("models", [])
-                        for model in models:
-                            model_arn = model.get("modelArn", "")
-                            if model_arn:
-                                model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
-                                if model_id:
-                                    app_profiles_by_model[model_id].add(profile_arn)
+                        if not models:
+                            logger.warning(f"Application profile {profile_arn} has no models")
+                            continue
+
+                        # Take first model - all models in array are same type (regional instances)
+                        first_model = models[0]
+                        model_arn = first_model.get("modelArn", "")
+                        if not model_arn:
+                            continue
+
+                        # Extract model ID from ARN (works for both foundation models and cross-region profiles)
+                        model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
+
+                        # Store in unified profile metadata for feature detection
+                        profile_metadata[profile_arn] = {
+                            "underlying_model_id": model_id,
+                            "profile_type": "APPLICATION",
+                            "profile_name": profile.get("inferenceProfileName", ""),
+                        }
                     except Exception as e:
                         logger.warning(f"Error processing application profile: {e}")
                         continue
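Per the comment above, an application profile's first `modelArn` may reference either a foundation model or a cross-region inference profile; the same conditional split handles both. A sketch with made-up ARNs:

# Illustration only — both ARN shapes reduce to a usable identifier.
for arn in (
    "arn:aws:bedrock:us-east-1::foundation-model/amazon.nova-pro-v1:0",
    "arn:aws:bedrock:us-east-1:111122223333:inference-profile/us.amazon.nova-pro-v1:0",
):
    print(arn.split('/')[-1] if '/' in arn else arn)
# amazon.nova-pro-v1:0
# us.amazon.nova-pro-v1:0

In the second case the stored `underlying_model_id` is itself a cross-region profile ID; the substring checks in `_supports_prompt_caching()` and related helpers still match, since the provider and model name are embedded in it.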
@@ -156,20 +181,10 @@ def list_bedrock_models() -> dict:
                 if "ON_DEMAND" in inference_types:
                     model_list[model_id] = {"modalities": input_modalities}

-                # Add cross-region inference model list.
-                profile_id = cr_inference_prefix + "." + model_id
-                if profile_id in profile_list:
-                    model_list[profile_id] = {"modalities": input_modalities}
-
-                # Add global cross-region inference profiles
-                global_profile_id = "global." + model_id
-                if global_profile_id in profile_list:
-                    model_list[global_profile_id] = {"modalities": input_modalities}
-
-                # Add application inference profiles (emit all profiles for this model)
-                if model_id in app_profiles_by_model:
-                    for profile_arn in app_profiles_by_model[model_id]:
-                        model_list[profile_arn] = {"modalities": input_modalities}
+                # Add all inference profiles (cross-region and application) for this model
+                for profile_id, metadata in profile_metadata.items():
+                    if metadata.get("underlying_model_id") == model_id:
+                        model_list[profile_id] = {"modalities": input_modalities}

     except Exception as e:
         logger.error(f"Unable to list models: {str(e)}")
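With the rewritten block, every profile whose `underlying_model_id` matches an ON_DEMAND foundation model is listed alongside it with the same modalities. Roughly (identifiers again invented for illustration):

# Illustration only — possible model_list entries for one foundation model.
model_list = {
    "anthropic.claude-sonnet-4-5-20250929-v1:0": {"modalities": ["TEXT", "IMAGE"]},
    "us.anthropic.claude-sonnet-4-5-20250929-v1:0": {"modalities": ["TEXT", "IMAGE"]},
    "arn:aws:bedrock:us-east-1:111122223333:application-inference-profile/abcd1234": {"modalities": ["TEXT", "IMAGE"]},
}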
@@ -197,17 +212,56 @@ def validate(self, chat_request: ChatRequest):
         error = ""
         # check if model is supported
         if chat_request.model not in bedrock_model_list.keys():
-            error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
+            # Provide helpful error for application profiles
+            if "application-inference-profile" in chat_request.model:
+                error = (
+                    f"Application profile {chat_request.model} not found. "
+                    f"Available profiles can be listed via GET /models API. "
+                    f"Ensure ENABLE_APPLICATION_INFERENCE_PROFILES=true and "
+                    f"the profile exists in your AWS account."
+                )
+            else:
+                error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
             logger.error("Unsupported model: %s", chat_request.model)

+        # Validate profile has resolvable underlying model
+        if not error and chat_request.model in profile_metadata:
+            resolved = self._resolve_to_foundation_model(chat_request.model)
+            if resolved == chat_request.model:
+                logger.warning(
+                    f"Could not resolve profile {chat_request.model} "
+                    f"to underlying model. Some features may not work correctly."
+                )
+
         if error:
             raise HTTPException(
                 status_code=400,
                 detail=error,
             )

-    @staticmethod
-    def _supports_prompt_caching(model_id: str) -> bool:
+    def _resolve_to_foundation_model(self, model_id: str) -> str:
+        """
+        Resolve any model identifier to foundation model ID for feature detection.
+
+        Handles:
+        - Cross-region profiles (us.*, eu.*, apac.*, global.*)
+        - Application profiles (arn:aws:bedrock:...:application-inference-profile/...)
+        - Foundation models (pass through unchanged)
+
+        No pattern matching needed - just dictionary lookup.
+        Unknown identifiers pass through unchanged (graceful fallback).
+
+        Args:
+            model_id: Can be foundation model ID, cross-region profile, or app profile ARN
+
+        Returns:
+            Foundation model ID if mapping exists, otherwise original model_id
+        """
+        if model_id in profile_metadata:
+            return profile_metadata[model_id]["underlying_model_id"]
+        return model_id
+
+    def _supports_prompt_caching(self, model_id: str) -> bool:
         """
         Check if model supports prompt caching based on model ID pattern.

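From the client's point of view an application profile ARN is passed exactly where a model ID would go; if it is absent from `profile_metadata` (and therefore from the model list), `validate()` raises the 400 above. A hypothetical OpenAI-style payload for this gateway's chat completions endpoint, with an invented ARN:

# Hypothetical request body — the ARN must appear in GET /models for validation to pass.
payload = {
    "model": "arn:aws:bedrock:us-east-1:111122223333:application-inference-profile/abcd1234",
    "messages": [{"role": "user", "content": "Hello"}],
    "temperature": 0.2,
}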
@@ -221,27 +275,28 @@ def _supports_prompt_caching(model_id: str) -> bool:
         Returns:
             bool: True if model supports prompt caching
         """
-        model_lower = model_id.lower()
+        # Resolve profile to underlying model for feature detection
+        resolved_model = self._resolve_to_foundation_model(model_id)
+        model_lower = resolved_model.lower()

         # Claude models pattern matching
-        if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
+        if "anthropic.claude" in model_lower:
             # Exclude very old models that don't support caching
             excluded_patterns = ["claude-instant", "claude-v1", "claude-v2"]
             if any(pattern in model_lower for pattern in excluded_patterns):
                 return False
             return True

         # Nova models pattern matching
-        if "amazon.nova" in model_lower or ".amazon.nova" in model_lower:
+        if "amazon.nova" in model_lower:
             return True

         # Future providers can be added here
         # Example: if "provider.model-name" in model_lower: return True

         return False

-    @staticmethod
-    def _get_max_cache_tokens(model_id: str) -> int | None:
+    def _get_max_cache_tokens(self, model_id: str) -> int | None:
         """
         Get maximum cacheable tokens limit for the model.

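Because the identifier is resolved first, the substring checks apply to the foundation model ID even when the request named a profile. A standalone restatement with a few worked inputs (model IDs here are examples, not an exhaustive list):

# Illustration only — mirrors the checks in _supports_prompt_caching().
def supports_prompt_caching(resolved_model: str) -> bool:
    m = resolved_model.lower()
    if "anthropic.claude" in m:
        return not any(p in m for p in ("claude-instant", "claude-v1", "claude-v2"))
    return "amazon.nova" in m

assert supports_prompt_caching("anthropic.claude-sonnet-4-5-20250929-v1:0") is True
assert supports_prompt_caching("anthropic.claude-instant-v1") is False
assert supports_prompt_caching("amazon.nova-pro-v1:0") is True
assert supports_prompt_caching("meta.llama3-1-70b-instruct-v1:0") is False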
@@ -252,14 +307,16 @@ def _get_max_cache_tokens(model_id: str) -> int | None:
         Returns:
             int | None: Max tokens, or None if unlimited
         """
-        model_lower = model_id.lower()
+        # Resolve profile to underlying model for feature detection
+        resolved_model = self._resolve_to_foundation_model(model_id)
+        model_lower = resolved_model.lower()

         # Nova models have 20K limit
-        if "amazon.nova" in model_lower or ".amazon.nova" in model_lower:
+        if "amazon.nova" in model_lower:
             return 20_000

         # Claude: No explicit limit
-        if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
+        if "anthropic.claude" in model_lower:
             return None

         return None
@@ -269,6 +326,14 @@ async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
         if DEBUG:
             logger.info("Raw request: " + chat_request.model_dump_json())

+        # Log profile resolution for debugging
+        if chat_request.model in profile_metadata:
+            resolved = self._resolve_to_foundation_model(chat_request.model)
+            profile_type = profile_metadata[chat_request.model].get("profile_type", "UNKNOWN")
+            logger.info(
+                f"Profile resolution: {chat_request.model} ({profile_type}) → {resolved}"
+            )
+
         # convert OpenAI chat request to Bedrock SDK request
         args = self._parse_request(chat_request)
         if DEBUG:
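Note that this `logger.info` call is not gated on DEBUG, so every request that names a profile records how it was resolved. Using the invented ARN from earlier, the message body would read roughly:

# Illustration only — approximate output of the f-string above.
# Profile resolution: arn:aws:bedrock:us-east-1:111122223333:application-inference-profile/abcd1234 (APPLICATION) → anthropic.claude-sonnet-4-5-20250929-v1:0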
@@ -667,15 +732,27 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:

         # Base inference parameters.
         inference_config = {
-            "temperature": chat_request.temperature,
             "maxTokens": chat_request.max_tokens,
-            "topP": chat_request.top_p,
         }

-        # Claude Sonnet 4.5 doesn't support both temperature and topP
-        # Remove topP for this model
-        if "claude-sonnet-4-5" in chat_request.model.lower():
-            inference_config.pop("topP", None)
+        # Only include optional parameters when specified
+        if chat_request.temperature is not None:
+            inference_config["temperature"] = chat_request.temperature
+        if chat_request.top_p is not None:
+            inference_config["topP"] = chat_request.top_p
+
+        # Some models (Claude Sonnet 4.5, Haiku 4.5) don't support both temperature and topP
+        # When both are provided, keep temperature and remove topP
+        # Resolve profile to underlying model for feature detection
+        resolved_model = self._resolve_to_foundation_model(chat_request.model)
+        model_lower = resolved_model.lower()
+
+        # Check if model is in the conflict list and both parameters are present
+        if "temperature" in inference_config and "topP" in inference_config:
+            if any(conflict_model in model_lower for conflict_model in TEMPERATURE_TOPP_CONFLICT_MODELS):
+                inference_config.pop("topP", None)
+                if DEBUG:
+                    logger.info(f"Removed topP for {chat_request.model} (conflicts with temperature)")

         if chat_request.stop is not None:
             stop = chat_request.stop
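The net effect of the new parameter handling: optional sampling parameters are only forwarded when the client set them, and on the conflicting Claude 4.5 models `temperature` wins over `topP`. A standalone sketch of that precedence (not the gateway's real request objects):

# Illustration only — mirrors the inference_config handling above.
TEMPERATURE_TOPP_CONFLICT_MODELS = {"claude-sonnet-4-5", "claude-haiku-4-5"}

def build_inference_config(resolved_model, max_tokens, temperature=None, top_p=None):
    cfg = {"maxTokens": max_tokens}
    if temperature is not None:
        cfg["temperature"] = temperature
    if top_p is not None:
        cfg["topP"] = top_p
    if "temperature" in cfg and "topP" in cfg and any(
        m in resolved_model.lower() for m in TEMPERATURE_TOPP_CONFLICT_MODELS
    ):
        cfg.pop("topP")  # temperature takes precedence on conflicting models
    return cfg

# Both set on a conflicting model: topP is dropped, temperature kept.
assert build_inference_config("anthropic.claude-sonnet-4-5-20250929-v1:0", 1024, 0.7, 0.9) == {
    "maxTokens": 1024,
    "temperature": 0.7,
}
# Other models keep both parameters.
assert build_inference_config("amazon.nova-pro-v1:0", 1024, 0.7, 0.9) == {
    "maxTokens": 1024,
    "temperature": 0.7,
    "topP": 0.9,
}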
@@ -692,9 +769,11 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
         if chat_request.reasoning_effort:
             # reasoning_effort is supported by Claude and DeepSeek v3
             # Different models use different formats
-            model_lower = chat_request.model.lower()
+            # Resolve profile to underlying model for feature detection
+            resolved_model = self._resolve_to_foundation_model(chat_request.model)
+            model_lower = resolved_model.lower()

-            if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
+            if "anthropic.claude" in model_lower:
                 # Claude format: reasoning_config = object with budget_tokens
                 max_tokens = (
                     chat_request.max_completion_tokens