Skip to content

Commit 911dfe2

Browse files
authored
models: fix Application Inference Profiles mapping (#175)
* models: fix Application Inference Profiles mapping to include all profiles per model_id; switch to defaultdict(set) and emit all AIPs * Fix rebase issue --------- Co-authored-by: Jeremy Brockett <[email protected]>
1 parent a2110ff commit 911dfe2

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

src/api/models/bedrock.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import re
55
import time
66
from abc import ABC
7+
from collections import defaultdict
78
from typing import AsyncIterable, Iterable, Literal
89

910
import boto3
@@ -103,7 +104,8 @@ def list_bedrock_models() -> dict:
103104
model_list = {}
104105
try:
105106
profile_list = []
106-
app_profile_dict = {}
107+
# Map foundation model_id -> set of application inference profile ARNs
108+
app_profiles_by_model = defaultdict(set)
107109

108110
if ENABLE_CROSS_REGION_INFERENCE:
109111
# List system defined inference profile IDs
@@ -128,7 +130,7 @@ def list_bedrock_models() -> dict:
128130
if model_arn:
129131
model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
130132
if model_id:
131-
app_profile_dict[model_id] = profile_arn
133+
app_profiles_by_model[model_id].add(profile_arn)
132134
except Exception as e:
133135
logger.warning(f"Error processing application profile: {e}")
134136
continue
@@ -156,9 +158,10 @@ def list_bedrock_models() -> dict:
156158
if profile_id in profile_list:
157159
model_list[profile_id] = {"modalities": input_modalities}
158160

159-
# Add application inference profiles
160-
if model_id in app_profile_dict:
161-
model_list[app_profile_dict[model_id]] = {"modalities": input_modalities}
161+
# Add application inference profiles (emit all profiles for this model)
162+
if model_id in app_profiles_by_model:
163+
for profile_arn in app_profiles_by_model[model_id]:
164+
model_list[profile_arn] = {"modalities": input_modalities}
162165

163166
except Exception as e:
164167
logger.error(f"Unable to list models: {str(e)}")

0 commit comments

Comments
 (0)