Skip to content

Commit 0a22005

Browse files
ybalbert001 (Yuanbo Li) and gemini-code-assist[bot]
authored
FixBug: Inference profile isn't working for bedrock embedding model (#2144)
* FixBug: Inference profile isn't working for bedrock embedding model

* Update models/bedrock/models/text_embedding/text_embedding.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Fix logic issue

---------

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent a11dc8b commit 0a22005

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

models/bedrock/manifest.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
version: 0.0.52
1+
version: 0.0.53
22
type: plugin
33
author: langgenius
44
name: bedrock

models/bedrock/models/text_embedding/text_embedding.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ def _invoke(
6262
if inference_profile_id:
6363
# Get the full ARN from the profile ID
6464
profile_info = get_inference_profile_info(inference_profile_id, credentials)
65-
model_id = profile_info.get("inferenceProfileArn")
66-
if not model_id:
65+
model_package_arn = profile_info.get("inferenceProfileArn")
66+
if not model_package_arn:
6767
raise InvokeError(f"Could not get ARN for inference profile {inference_profile_id}")
68-
logger.info(f"Using inference profile ARN: {model_id}")
68+
logger.info(f"Using inference profile ARN: {model_package_arn}")
6969

7070
# Determine model prefix from underlying models
7171
underlying_models = profile_info.get("models", [])
@@ -80,6 +80,7 @@ def _invoke(
8080
raise InvokeError(f"No underlying models found in inference profile")
8181
else:
8282
# Traditional model - use model directly
83+
model_package_arn = model
8384
model_prefix = model.split(".")[0]
8485

8586
bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials)
@@ -102,7 +103,7 @@ def _invoke(
102103
}
103104
}
104105
}
105-
response_body = self._invoke_bedrock_embedding(model_id, bedrock_runtime, body)
106+
response_body = self._invoke_bedrock_embedding(model_package_arn, bedrock_runtime, body)
106107
embedding_data = response_body.get("embeddings", [{}])[0]
107108
embeddings.extend([embedding_data.get("embedding")])
108109
token_usage += len(text.split())
@@ -120,7 +121,7 @@ def _invoke(
120121
body = {
121122
"inputText": text,
122123
}
123-
response_body = self._invoke_bedrock_embedding(model_id, bedrock_runtime, body)
124+
response_body = self._invoke_bedrock_embedding(model_package_arn, bedrock_runtime, body)
124125
embeddings.extend([response_body.get("embedding")])
125126
token_usage += response_body.get("inputTextTokenCount")
126127
logger.warning(f"Total Tokens: {token_usage}")
@@ -138,7 +139,7 @@ def _invoke(
138139
"texts": [text],
139140
"input_type": input_type,
140141
}
141-
response_body = self._invoke_bedrock_embedding(model_id, bedrock_runtime, body)
142+
response_body = self._invoke_bedrock_embedding(model_package_arn, bedrock_runtime, body)
142143
embeddings.extend(response_body.get("embeddings"))
143144
token_usage += len(text)
144145
result = TextEmbeddingResult(

0 commit comments

Comments (0)