Skip to content

Commit 8336c32

Browse files
Python: Add bedrock model provider parameter (#12853)
### Motivation and Context <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> Fixes #12833 ### Description <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> Add an optional setting to the Bedrock AI connectors to allow users to specify the Bedrock model providers (i.e. amazon, anthropic, etc) so that they can use application inference profiles. ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄 --------- Co-authored-by: Dmytro Struk <[email protected]>
1 parent 1198b57 commit 8336c32

File tree

12 files changed

+370
-38
lines changed

12 files changed

+370
-38
lines changed

python/samples/concepts/setup/chat_completion_services.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def get_bedrock_chat_completion_service_and_request_settings() -> tuple[
213213
"""
214214
from semantic_kernel.connectors.ai.bedrock import BedrockChatCompletion, BedrockChatPromptExecutionSettings
215215

216-
chat_service = BedrockChatCompletion(service_id=service_id, model_id="anthropic.claude-3-sonnet-20240229-v1:0")
216+
chat_service = BedrockChatCompletion(service_id=service_id)
217217
request_settings = BedrockChatPromptExecutionSettings(
218218
# For model specific settings, specify them in the extension_data dictionary.
219219
# For example, for Cohere Command specific settings, refer to:

python/semantic_kernel/connectors/ai/bedrock/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ bedrock_chat_completion_service = BedrockChatCompletion(runtime_client=runtime_c
3838

3939
To find which models are supported in each AWS region, refer to this [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/models-regions.html).
4040

41+
### Inference profiles
42+
43+
You can create inference profiles in AWS Bedrock to monitor and optimize the performance of your foundation models. Refer to the [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles.html) for more information.
44+
45+
When you are using an Application Inference Profile, you must set the `BEDROCK_MODEL_PROVIDER` environment variable to the model provider you are using. For example, if you are using Amazon Titan, you must set `BEDROCK_MODEL_PROVIDER=amazon`. This is because an Application Inference Profile doesn't contain the model provider information, and the Bedrock connector needs to know which model provider to use so that it can create the correct request body for the Bedrock API.
46+
47+
> An Application Inference Profile ARN is usually formatted as follows: `arn:aws:bedrock:<region>:<account-id>:application-inference-profile/<profile-id>`.
48+
4149
### Input & Output Modalities
4250

4351
Foundational models in Bedrock support multiple modalities, including text, image, and embedding. However, not all models support the same modalities. Refer to the [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) for more information.

python/semantic_kernel/connectors/ai/bedrock/bedrock_settings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import ClassVar
44

5+
from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import BedrockModelProvider
56
from semantic_kernel.kernel_pydantic import KernelBaseSettings
67
from semantic_kernel.utils.feature_stage_decorator import experimental
78

@@ -25,10 +26,16 @@ class BedrockSettings(KernelBaseSettings):
2526
(Env var BEDROCK_TEXT_MODEL_ID)
2627
- embedding_model_id: str | None - The Amazon Bedrock embedding model ID to use.
2728
(Env var BEDROCK_EMBEDDING_MODEL_ID)
29+
- model_provider: BedrockModelProvider | None - The Bedrock model provider to use.
30+
If not provided, the model provider will be extracted from the model ID.
31+
When using an Application Inference Profile where the model provider is not part
32+
of the model ID, this setting must be provided.
33+
(Env var BEDROCK_MODEL_PROVIDER)
2834
"""
2935

3036
env_prefix: ClassVar[str] = "BEDROCK_"
3137

3238
chat_model_id: str | None = None
3339
text_model_id: str | None = None
3440
embedding_model_id: str | None = None
41+
model_provider: BedrockModelProvider | None = None

python/semantic_kernel/connectors/ai/bedrock/services/bedrock_base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import boto3
77

8+
from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import BedrockModelProvider
89
from semantic_kernel.kernel_pydantic import KernelBaseModel
910

1011

@@ -19,22 +20,30 @@ class BedrockBase(KernelBaseModel, ABC):
1920
# Client: Use for model management
2021
bedrock_client: Any
2122

23+
bedrock_model_provider: BedrockModelProvider | None = None
24+
2225
def __init__(
2326
self,
2427
*,
2528
runtime_client: Any | None = None,
2629
client: Any | None = None,
30+
bedrock_model_provider: BedrockModelProvider | None = None,
2731
**kwargs: Any,
2832
) -> None:
2933
"""Initialize the Amazon Bedrock Base Class.
3034
3135
Args:
3236
runtime_client: The Amazon Bedrock runtime client to use.
3337
client: The Amazon Bedrock client to use.
38+
bedrock_model_provider: The Bedrock model provider to use.
39+
If not provided, the model provider will be extracted from the model ID.
40+
When using an Application Inference Profile where the model provider is not part
41+
of the model ID, this setting must be provided.
3442
**kwargs: Additional keyword arguments.
3543
"""
3644
super().__init__(
3745
bedrock_runtime_client=runtime_client or boto3.client("bedrock-runtime"),
3846
bedrock_client=client or boto3.client("bedrock"),
47+
bedrock_model_provider=bedrock_model_provider,
3948
**kwargs,
4049
)

python/semantic_kernel/connectors/ai/bedrock/services/bedrock_chat_completion.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from semantic_kernel.connectors.ai.bedrock.bedrock_settings import BedrockSettings
1717
from semantic_kernel.connectors.ai.bedrock.services.bedrock_base import BedrockBase
1818
from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import (
19+
BedrockModelProvider,
1920
get_chat_completion_additional_model_request_fields,
2021
)
2122
from semantic_kernel.connectors.ai.bedrock.services.model_provider.utils import (
@@ -36,10 +37,7 @@
3637
from semantic_kernel.contents.text_content import TextContent
3738
from semantic_kernel.contents.utils.author_role import AuthorRole
3839
from semantic_kernel.contents.utils.finish_reason import FinishReason
39-
from semantic_kernel.exceptions.service_exceptions import (
40-
ServiceInitializationError,
41-
ServiceInvalidResponseError,
42-
)
40+
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceInvalidResponseError
4341
from semantic_kernel.utils.async_utils import run_in_executor
4442
from semantic_kernel.utils.telemetry.model_diagnostics.decorators import (
4543
trace_chat_completion,
@@ -60,6 +58,7 @@ class BedrockChatCompletion(BedrockBase, ChatCompletionClientBase):
6058
def __init__(
6159
self,
6260
model_id: str | None = None,
61+
model_provider: BedrockModelProvider | None = None,
6362
service_id: str | None = None,
6463
runtime_client: Any | None = None,
6564
client: Any | None = None,
@@ -70,6 +69,7 @@ def __init__(
7069
7170
Args:
7271
model_id: The Amazon Bedrock chat model ID to use.
72+
model_provider: The Bedrock model provider to use.
7373
service_id: The Service ID for the completion service.
7474
runtime_client: The Amazon Bedrock runtime client to use.
7575
client: The Amazon Bedrock client to use.
@@ -79,6 +79,7 @@ def __init__(
7979
try:
8080
bedrock_settings = BedrockSettings(
8181
chat_model_id=model_id,
82+
model_provider=model_provider,
8283
env_file_path=env_file_path,
8384
env_file_encoding=env_file_encoding,
8485
)
@@ -93,6 +94,7 @@ def __init__(
9394
service_id=service_id or bedrock_settings.chat_model_id,
9495
runtime_client=runtime_client,
9596
client=client,
97+
bedrock_model_provider=bedrock_settings.model_provider,
9698
)
9799

98100
# region Overriding base class methods
@@ -212,7 +214,7 @@ def _prepare_settings_for_request(
212214
"stopSequences": settings.stop,
213215
}),
214216
"additionalModelRequestFields": get_chat_completion_additional_model_request_fields(
215-
self.ai_model_id, settings
217+
self.ai_model_id, settings, model_provider=self.bedrock_model_provider
216218
),
217219
}
218220

python/semantic_kernel/connectors/ai/bedrock/services/bedrock_text_completion.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from semantic_kernel.connectors.ai.bedrock.bedrock_settings import BedrockSettings
1818
from semantic_kernel.connectors.ai.bedrock.services.bedrock_base import BedrockBase
1919
from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import (
20+
BedrockModelProvider,
2021
get_text_completion_request_body,
2122
parse_streaming_text_completion_response,
2223
parse_text_completion_response,
@@ -41,6 +42,7 @@ class BedrockTextCompletion(BedrockBase, TextCompletionClientBase):
4142
def __init__(
4243
self,
4344
model_id: str | None = None,
45+
model_provider: BedrockModelProvider | None = None,
4446
service_id: str | None = None,
4547
runtime_client: Any | None = None,
4648
client: Any | None = None,
@@ -51,6 +53,7 @@ def __init__(
5153
5254
Args:
5355
model_id: The Amazon Bedrock text model ID to use.
56+
model_provider: The Bedrock model provider to use.
5457
service_id: The Service ID for the text completion service.
5558
runtime_client: The Amazon Bedrock runtime client to use.
5659
client: The Amazon Bedrock client to use.
@@ -60,6 +63,7 @@ def __init__(
6063
try:
6164
bedrock_settings = BedrockSettings(
6265
text_model_id=model_id,
66+
model_provider=model_provider,
6367
env_file_path=env_file_path,
6468
env_file_encoding=env_file_encoding,
6569
)
@@ -74,6 +78,7 @@ def __init__(
7478
service_id=service_id or bedrock_settings.text_model_id,
7579
runtime_client=runtime_client,
7680
client=client,
81+
bedrock_model_provider=bedrock_settings.model_provider,
7782
)
7883

7984
# region Overriding base class methods
@@ -94,11 +99,17 @@ async def _inner_get_text_contents(
9499
settings = self.get_prompt_execution_settings_from_settings(settings)
95100
assert isinstance(settings, BedrockTextPromptExecutionSettings) # nosec
96101

97-
request_body = get_text_completion_request_body(self.ai_model_id, prompt, settings)
102+
request_body = get_text_completion_request_body(
103+
self.ai_model_id,
104+
prompt,
105+
settings,
106+
model_provider=self.bedrock_model_provider,
107+
)
98108
response_body = await self._async_invoke_model(request_body)
99109
return parse_text_completion_response(
100110
self.ai_model_id,
101111
json.loads(response_body.get("body").read()),
112+
model_provider=self.bedrock_model_provider,
102113
)
103114

104115
@override
@@ -112,14 +123,20 @@ async def _inner_get_streaming_text_contents(
112123
settings = self.get_prompt_execution_settings_from_settings(settings)
113124
assert isinstance(settings, BedrockTextPromptExecutionSettings) # nosec
114125

115-
request_body = get_text_completion_request_body(self.ai_model_id, prompt, settings)
126+
request_body = get_text_completion_request_body(
127+
self.ai_model_id,
128+
prompt,
129+
settings,
130+
model_provider=self.bedrock_model_provider,
131+
)
116132
response_stream = await self._async_invoke_model_stream(request_body)
117133
for event in response_stream.get("body"):
118134
chunk = event.get("chunk")
119135
yield [
120136
parse_streaming_text_completion_response(
121137
self.ai_model_id,
122138
json.loads(chunk.get("bytes").decode()),
139+
model_provider=self.bedrock_model_provider,
123140
)
124141
]
125142

python/semantic_kernel/connectors/ai/bedrock/services/bedrock_text_embedding.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from semantic_kernel.connectors.ai.bedrock.bedrock_settings import BedrockSettings
2121
from semantic_kernel.connectors.ai.bedrock.services.bedrock_base import BedrockBase
2222
from semantic_kernel.connectors.ai.bedrock.services.model_provider.bedrock_model_provider import (
23+
BedrockModelProvider,
2324
get_text_embedding_request_body,
2425
parse_text_embedding_response,
2526
)
@@ -38,6 +39,7 @@ class BedrockTextEmbedding(BedrockBase, EmbeddingGeneratorBase):
3839
def __init__(
3940
self,
4041
model_id: str | None = None,
42+
model_provider: BedrockModelProvider | None = None,
4143
service_id: str | None = None,
4244
runtime_client: Any | None = None,
4345
client: Any | None = None,
@@ -48,6 +50,7 @@ def __init__(
4850
4951
Args:
5052
model_id: The Amazon Bedrock text embedding model ID to use.
53+
model_provider: The Bedrock model provider to use.
5154
service_id: The Service ID for the text embedding service.
5255
runtime_client: The Amazon Bedrock runtime client to use.
5356
client: The Amazon Bedrock client to use.
@@ -57,6 +60,7 @@ def __init__(
5760
try:
5861
bedrock_settings = BedrockSettings(
5962
embedding_model_id=model_id,
63+
model_provider=model_provider,
6064
env_file_path=env_file_path,
6165
env_file_encoding=env_file_encoding,
6266
)
@@ -71,6 +75,7 @@ def __init__(
7175
service_id=service_id or bedrock_settings.embedding_model_id,
7276
runtime_client=runtime_client,
7377
client=client,
78+
bedrock_model_provider=bedrock_settings.model_provider,
7479
)
7580

7681
@override
@@ -87,12 +92,25 @@ async def generate_embeddings(
8792
assert isinstance(settings, BedrockEmbeddingPromptExecutionSettings) # nosec
8893

8994
results = await asyncio.gather(*[
90-
self._async_invoke_model(get_text_embedding_request_body(self.ai_model_id, text, settings))
95+
self._async_invoke_model(
96+
get_text_embedding_request_body(
97+
self.ai_model_id,
98+
text,
99+
settings,
100+
model_provider=self.bedrock_model_provider,
101+
)
102+
)
91103
for text in texts
92104
])
93105

94106
return array([
95-
array(parse_text_embedding_response(self.ai_model_id, json.loads(result.get("body").read())))
107+
array(
108+
parse_text_embedding_response(
109+
self.ai_model_id,
110+
json.loads(result.get("body").read()),
111+
model_provider=self.bedrock_model_provider,
112+
)
113+
)
96114
for result in results
97115
])
98116

python/semantic_kernel/connectors/ai/bedrock/services/model_provider/bedrock_model_provider.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,21 +73,34 @@ def to_model_provider(cls, model_id: str) -> "BedrockModelProvider":
7373
}
7474

7575

76-
def get_text_completion_request_body(model_id: str, prompt: str, settings: BedrockTextPromptExecutionSettings) -> dict:
76+
def get_text_completion_request_body(
77+
model_id: str,
78+
prompt: str,
79+
settings: BedrockTextPromptExecutionSettings,
80+
model_provider: BedrockModelProvider | None = None,
81+
) -> dict:
7782
"""Get the request body for text completion for Amazon Bedrock models."""
78-
model_provider = BedrockModelProvider.to_model_provider(model_id)
83+
model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id)
7984
return TEXT_COMPLETION_REQUEST_BODY_MAPPING[model_provider](prompt, settings)
8085

8186

82-
def parse_text_completion_response(model_id: str, response: dict) -> list[TextContent]:
87+
def parse_text_completion_response(
88+
model_id: str,
89+
response: dict,
90+
model_provider: BedrockModelProvider | None = None,
91+
) -> list[TextContent]:
8392
"""Parse the response from text completion for Amazon Bedrock models."""
84-
model_provider = BedrockModelProvider.to_model_provider(model_id)
93+
model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id)
8594
return TEXT_COMPLETION_RESPONSE_MAPPING[model_provider](response, model_id)
8695

8796

88-
def parse_streaming_text_completion_response(model_id: str, chunk: dict) -> StreamingTextContent:
97+
def parse_streaming_text_completion_response(
98+
model_id: str,
99+
chunk: dict,
100+
model_provider: BedrockModelProvider | None = None,
101+
) -> StreamingTextContent:
89102
"""Parse the response from streaming text completion for Amazon Bedrock models."""
90-
model_provider = BedrockModelProvider.to_model_provider(model_id)
103+
model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id)
91104
return STREAMING_TEXT_COMPLETION_RESPONSE_MAPPING[model_provider](chunk, model_id)
92105

93106

@@ -109,10 +122,12 @@ def parse_streaming_text_completion_response(model_id: str, chunk: dict) -> Stre
109122

110123

111124
def get_chat_completion_additional_model_request_fields(
112-
model_id: str, settings: BedrockChatPromptExecutionSettings
125+
model_id: str,
126+
settings: BedrockChatPromptExecutionSettings,
127+
model_provider: BedrockModelProvider | None = None,
113128
) -> dict[str, Any] | None:
114129
"""Get the additional model request fields for chat completion for Amazon Bedrock models."""
115-
model_provider = BedrockModelProvider.to_model_provider(model_id)
130+
model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id)
116131
return CHAT_COMPLETION_ADDITIONAL_MODEL_REQUEST_FIELDS_MAPPING[model_provider](settings)
117132

118133

@@ -134,16 +149,23 @@ def get_chat_completion_additional_model_request_fields(
134149

135150

136151
def get_text_embedding_request_body(
137-
model_id: str, text: str, settings: BedrockEmbeddingPromptExecutionSettings
152+
model_id: str,
153+
text: str,
154+
settings: BedrockEmbeddingPromptExecutionSettings,
155+
model_provider: BedrockModelProvider | None = None,
138156
) -> dict:
139157
"""Get the request body for text embedding for Amazon Bedrock models."""
140-
model_provider = BedrockModelProvider.to_model_provider(model_id)
158+
model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id)
141159
return TEXT_EMBEDDING_REQUEST_BODY_MAPPING[model_provider](text, settings)
142160

143161

144-
def parse_text_embedding_response(model_id: str, response: dict) -> list[float]:
162+
def parse_text_embedding_response(
163+
model_id: str,
164+
response: dict,
165+
model_provider: BedrockModelProvider | None = None,
166+
) -> list[float]:
145167
"""Parse the response from text embedding for Amazon Bedrock models."""
146-
model_provider = BedrockModelProvider.to_model_provider(model_id)
168+
model_provider = model_provider or BedrockModelProvider.to_model_provider(model_id)
147169
return TEXT_EMBEDDING_RESPONSE_MAPPING[model_provider](response)
148170

149171

0 commit comments

Comments
 (0)