@@ -62,10 +62,10 @@ def _invoke(
6262 if inference_profile_id :
6363 # Get the full ARN from the profile ID
6464 profile_info = get_inference_profile_info (inference_profile_id , credentials )
65- model_id = profile_info .get ("inferenceProfileArn" )
66- if not model_id :
65+ model_package_arn = profile_info .get ("inferenceProfileArn" )
66+ if not model_package_arn :
6767 raise InvokeError (f"Could not get ARN for inference profile { inference_profile_id } " )
68- logger .info (f"Using inference profile ARN: { model_id } " )
68+ logger .info (f"Using inference profile ARN: { model_package_arn } " )
6969
7070 # Determine model prefix from underlying models
7171 underlying_models = profile_info .get ("models" , [])
@@ -80,6 +80,7 @@ def _invoke(
8080 raise InvokeError (f"No underlying models found in inference profile" )
8181 else :
8282 # Traditional model - use model directly
83+ model_package_arn = model
8384 model_prefix = model .split ("." )[0 ]
8485
8586 bedrock_runtime = get_bedrock_client ("bedrock-runtime" , credentials )
@@ -102,7 +103,7 @@ def _invoke(
102103 }
103104 }
104105 }
105- response_body = self ._invoke_bedrock_embedding (model_id , bedrock_runtime , body )
106+ response_body = self ._invoke_bedrock_embedding (model_package_arn , bedrock_runtime , body )
106107 embedding_data = response_body .get ("embeddings" , [{}])[0 ]
107108 embeddings .extend ([embedding_data .get ("embedding" )])
108109 token_usage += len (text .split ())
@@ -120,7 +121,7 @@ def _invoke(
120121 body = {
121122 "inputText" : text ,
122123 }
123- response_body = self ._invoke_bedrock_embedding (model_id , bedrock_runtime , body )
124+ response_body = self ._invoke_bedrock_embedding (model_package_arn , bedrock_runtime , body )
124125 embeddings .extend ([response_body .get ("embedding" )])
125126 token_usage += response_body .get ("inputTextTokenCount" )
126127 logger .warning (f"Total Tokens: { token_usage } " )
@@ -138,7 +139,7 @@ def _invoke(
138139 "texts" : [text ],
139140 "input_type" : input_type ,
140141 }
141- response_body = self ._invoke_bedrock_embedding (model_id , bedrock_runtime , body )
142+ response_body = self ._invoke_bedrock_embedding (model_package_arn , bedrock_runtime , body )
142143 embeddings .extend (response_body .get ("embeddings" ))
143144 token_usage += len (text )
144145 result = TextEmbeddingResult (
0 commit comments