Commit a7eab4a
Python: Remove model info check in Bedrock connectors (#12395)
### Motivation and Context

Closes #10941.

### Description

To support more ways to address Bedrock models (via `model_id`, an inference profile, or a model ARN), we have removed the checks on whether a model supports certain features, i.e. streaming and embedding generation. Checking whether a model supports a feature requires the model id, but given the flexibility in specifying the target model, it is no longer possible to retrieve the model information reliably.

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
1 parent 0e7556e commit a7eab4a
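In practice, this means the Bedrock services no longer gate requests on a metadata lookup. Below is a minimal sketch of the identifier flexibility the change targets, assuming (as the unit tests below do) that every identifier form is passed through the `model_id` parameter; the ids and ARN are illustrative:

```python
# Sketch only: the ids/ARN below are illustrative, and passing inference
# profiles and ARNs through model_id is an assumption based on this PR.
from semantic_kernel.connectors.ai.bedrock.services.bedrock_chat_completion import BedrockChatCompletion

# A plain foundation model id:
chat_by_id = BedrockChatCompletion(model_id="anthropic.claude-3-5-sonnet-20240620-v1:0")

# A cross-region inference profile id:
chat_by_profile = BedrockChatCompletion(model_id="us.anthropic.claude-3-5-sonnet-20240620-v1:0")

# A full model ARN:
chat_by_arn = BedrockChatCompletion(
    model_id="arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20240620-v1:0"
)
```

Capability validation (e.g., whether the model can stream) now happens at request time in the Bedrock runtime rather than up front in the connector.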

File tree

8 files changed: +3 −120 lines changed

python/semantic_kernel/connectors/ai/bedrock/README.md

Lines changed: 0 additions & 4 deletions

```diff
@@ -56,10 +56,6 @@ Not all models in Bedrock support tools. Refer to the [AWS documentation](https:
 
 Not all models in Bedrock support streaming. You can use the boto3 client to check if a model supports streaming. Refer to the [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html) and the [Boto3 documentation](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/get_foundation_model.html) for more information.
 
-You can also directly call the `get_foundation_model_info("model_id")` method from the Bedrock connector to check if a model supports streaming.
-
-> Note: The bedrock connector will check if a model supports streaming before making a streaming request to the model.
-
 ## Model specific parameters
 
 Foundation models can have specific parameters that are unique to the model or the model provider. You can refer to this [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html) for more information.
```
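For reference, a minimal sketch of the check the README now delegates to the boto3 client; it works only for identifiers that `get_foundation_model` can resolve (a foundation model id or ARN, not an inference profile), and the model id shown is illustrative:

```python
import boto3

bedrock = boto3.client("bedrock")

# Look up the foundation model's metadata; this raises if the identifier
# cannot be resolved (e.g. an inference profile id).
details = bedrock.get_foundation_model(
    modelIdentifier="amazon.titan-text-premier-v1:0"  # illustrative id
)["modelDetails"]

if not details.get("responseStreamingSupported", False):
    print("This model does not support streaming responses.")
```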

python/semantic_kernel/connectors/ai/bedrock/services/bedrock_base.py

Lines changed: 0 additions & 14 deletions
```diff
@@ -1,13 +1,11 @@
 # Copyright (c) Microsoft. All rights reserved.
 
 from abc import ABC
-from functools import partial
 from typing import Any, ClassVar
 
 import boto3
 
 from semantic_kernel.kernel_pydantic import KernelBaseModel
-from semantic_kernel.utils.async_utils import run_in_executor
 
 
 class BedrockBase(KernelBaseModel, ABC):
@@ -40,15 +38,3 @@ def __init__(
             bedrock_client=client or boto3.client("bedrock"),
             **kwargs,
         )
-
-    async def get_foundation_model_info(self, model_id: str) -> dict[str, Any]:
-        """Get the foundation model information."""
-        response = await run_in_executor(
-            None,
-            partial(
-                self.bedrock_client.get_foundation_model,
-                modelIdentifier=model_id,
-            ),
-        )
-
-        return response.get("modelDetails")
```
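If you depended on the removed `get_foundation_model_info` helper, an equivalent is easy to keep in your own code. A minimal sketch, assuming you construct your own boto3 `bedrock` client and using `asyncio.to_thread` in place of the connector's `run_in_executor` utility:

```python
import asyncio
from typing import Any

import boto3


async def get_foundation_model_info(model_id: str) -> dict[str, Any]:
    """Fetch foundation model metadata without blocking the event loop."""
    client = boto3.client("bedrock")
    # get_foundation_model is a blocking call, so offload it to a thread.
    response = await asyncio.to_thread(
        client.get_foundation_model, modelIdentifier=model_id
    )
    return response.get("modelDetails", {})
```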

python/semantic_kernel/connectors/ai/bedrock/services/bedrock_chat_completion.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -38,7 +38,6 @@
 from semantic_kernel.contents.utils.finish_reason import FinishReason
 from semantic_kernel.exceptions.service_exceptions import (
     ServiceInitializationError,
-    ServiceInvalidRequestError,
     ServiceInvalidResponseError,
 )
 from semantic_kernel.utils.async_utils import run_in_executor
@@ -127,11 +126,6 @@ async def _inner_get_streaming_chat_message_contents(
         settings: "PromptExecutionSettings",
         function_invoke_attempt: int = 0,
     ) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
-        # Not all models support streaming: check if the model supports streaming before proceeding
-        model_info = await self.get_foundation_model_info(self.ai_model_id)
-        if not model_info.get("responseStreamingSupported"):
-            raise ServiceInvalidRequestError(f"The model {self.ai_model_id} does not support streaming.")
-
         if not isinstance(settings, BedrockChatPromptExecutionSettings):
             settings = self.get_prompt_execution_settings_from_settings(settings)
         assert isinstance(settings, BedrockChatPromptExecutionSettings)  # nosec
```
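With the pre-check removed (here and in the text completion service below), an unsupported streaming request is no longer rejected up front with `ServiceInvalidRequestError`; it fails when the runtime call is made. A hedged sketch of handling that at the call site, assuming the underlying botocore `ClientError` (typically a `ValidationException`) propagates to the caller:

```python
from botocore.exceptions import ClientError


async def try_stream(chat_service, chat_history, settings):
    try:
        # The service yields lists of StreamingChatMessageContent chunks.
        async for chunks in chat_service.get_streaming_chat_message_contents(
            chat_history=chat_history, settings=settings
        ):
            for chunk in chunks:
                print(chunk.content or "", end="")
    except ClientError as exc:
        # The error shape/code here is an assumption; inspect the exception
        # raised in your environment before relying on it.
        if exc.response.get("Error", {}).get("Code") == "ValidationException":
            print("The target model likely does not support streaming.")
        else:
            raise
```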

python/semantic_kernel/connectors/ai/bedrock/services/bedrock_text_completion.py

Lines changed: 1 addition & 6 deletions
```diff
@@ -24,7 +24,7 @@
 from semantic_kernel.connectors.ai.text_completion_client_base import TextCompletionClientBase
 from semantic_kernel.contents.streaming_text_content import StreamingTextContent
 from semantic_kernel.contents.text_content import TextContent
-from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceInvalidRequestError
+from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
 from semantic_kernel.utils.async_utils import run_in_executor
 from semantic_kernel.utils.telemetry.model_diagnostics.decorators import (
     trace_streaming_text_completion,
@@ -108,11 +108,6 @@ async def _inner_get_streaming_text_contents(
         prompt: str,
         settings: "PromptExecutionSettings",
     ) -> AsyncGenerator[list[StreamingTextContent], Any]:
-        # Not all models support streaming: check if the model supports streaming before proceeding
-        model_info = await self.get_foundation_model_info(self.ai_model_id)
-        if not model_info.get("responseStreamingSupported"):
-            raise ServiceInvalidRequestError(f"The model {self.ai_model_id} does not support streaming.")
-
         if not isinstance(settings, BedrockTextPromptExecutionSettings):
             settings = self.get_prompt_execution_settings_from_settings(settings)
         assert isinstance(settings, BedrockTextPromptExecutionSettings)  # nosec
```

python/semantic_kernel/connectors/ai/bedrock/services/bedrock_text_embedding.py

Lines changed: 1 addition & 8 deletions
```diff
@@ -25,7 +25,7 @@
 )
 from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase
 from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
-from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceInvalidRequestError
+from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
 from semantic_kernel.utils.async_utils import run_in_executor
 
 if TYPE_CHECKING:
@@ -80,13 +80,6 @@ async def generate_embeddings(
         settings: "PromptExecutionSettings | None" = None,
         **kwargs: Any,
     ) -> ndarray:
-        model_info = await self.get_foundation_model_info(self.ai_model_id)
-        if "TEXT" not in model_info.get("inputModalities", []):
-            # Image embedding is not supported yet in SK
-            raise ServiceInvalidRequestError(f"The model {self.ai_model_id} does not support text input.")
-        if "EMBEDDING" not in model_info.get("outputModalities", []):
-            raise ServiceInvalidRequestError(f"The model {self.ai_model_id} does not support embedding output.")
-
         if not settings:
             settings = BedrockEmbeddingPromptExecutionSettings()
         elif not isinstance(settings, BedrockEmbeddingPromptExecutionSettings):
```
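Callers who want the old safety net can run the modality check themselves before calling `generate_embeddings`. A minimal sketch, applicable only when the model is addressed by an id or ARN that `get_foundation_model` can resolve; the model id is illustrative:

```python
import boto3

details = boto3.client("bedrock").get_foundation_model(
    modelIdentifier="amazon.titan-embed-text-v2:0"  # illustrative id
)["modelDetails"]

# Mirror the checks the connector used to perform internally.
if "TEXT" not in details.get("inputModalities", []):
    raise ValueError("Model does not accept text input.")
if "EMBEDDING" not in details.get("outputModalities", []):
    raise ValueError("Model does not produce embedding output.")
```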

python/tests/unit/connectors/ai/bedrock/services/test_bedrock_chat_completion.py

Lines changed: 0 additions & 25 deletions
```diff
@@ -18,7 +18,6 @@
 from semantic_kernel.contents.utils.finish_reason import FinishReason
 from semantic_kernel.exceptions.service_exceptions import (
     ServiceInitializationError,
-    ServiceInvalidRequestError,
     ServiceInvalidResponseError,
 )
 from tests.unit.connectors.ai.bedrock.conftest import MockBedrockClient, MockBedrockRuntimeClient
@@ -281,30 +280,6 @@ async def test_bedrock_streaming_chat_completion(
     assert response.finish_reason == FinishReason.STOP
 
 
-async def test_bedrock_streaming_chat_completion_with_unsupported_model(
-    model_id,
-    chat_history: ChatHistory,
-) -> None:
-    """Test Amazon Bedrock Streaming Chat Completion complete method"""
-    with patch.object(
-        MockBedrockClient, "get_foundation_model", return_value={"modelDetails": {"responseStreamingSupported": False}}
-    ):
-        # Setup
-        bedrock_chat_completion = BedrockChatCompletion(
-            model_id=model_id,
-            runtime_client=MockBedrockRuntimeClient(),
-            client=MockBedrockClient(),
-        )
-
-        # Act
-        settings = BedrockChatPromptExecutionSettings()
-        with pytest.raises(ServiceInvalidRequestError):
-            async for chunk in bedrock_chat_completion.get_streaming_chat_message_contents(
-                chat_history=chat_history, settings=settings
-            ):
-                pass
-
-
 @pytest.mark.parametrize(
     # These are fake model ids with the supported prefixes
     "model_id",
```

python/tests/unit/connectors/ai/bedrock/services/test_bedrock_text_completion.py

Lines changed: 1 addition & 22 deletions
```diff
@@ -14,7 +14,7 @@
 )
 from semantic_kernel.contents.streaming_text_content import StreamingTextContent
 from semantic_kernel.contents.text_content import TextContent
-from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceInvalidRequestError
+from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
 from tests.unit.connectors.ai.bedrock.conftest import MockBedrockClient, MockBedrockRuntimeClient
 
 # region init
@@ -213,25 +213,4 @@ async def test_bedrock_streaming_text_completion(
     assert isinstance(response.inner_content, list)
 
 
-async def test_bedrock_streaming_text_completion_with_unsupported_model(
-    model_id,
-) -> None:
-    """Test Amazon Bedrock Streaming Chat Completion complete method"""
-    with patch.object(
-        MockBedrockClient, "get_foundation_model", return_value={"modelDetails": {"responseStreamingSupported": False}}
-    ):
-        # Setup
-        bedrock_text_completion = BedrockTextCompletion(
-            model_id=model_id,
-            runtime_client=MockBedrockRuntimeClient(),
-            client=MockBedrockClient(),
-        )
-
-        # Act
-        settings = BedrockTextPromptExecutionSettings()
-        with pytest.raises(ServiceInvalidRequestError):
-            async for chunk in bedrock_text_completion.get_streaming_text_contents("Hello", settings=settings):
-                pass
-
-
 # endregion
```

python/tests/unit/connectors/ai/bedrock/services/test_bedrock_text_embedding_generation.py

Lines changed: 0 additions & 35 deletions
```diff
@@ -11,7 +11,6 @@
 from semantic_kernel.connectors.ai.bedrock.services.bedrock_text_embedding import BedrockTextEmbedding
 from semantic_kernel.exceptions.service_exceptions import (
     ServiceInitializationError,
-    ServiceInvalidRequestError,
     ServiceInvalidResponseError,
 )
 from tests.unit.connectors.ai.bedrock.conftest import MockBedrockClient, MockBedrockRuntimeClient
@@ -149,40 +148,6 @@ async def test_bedrock_text_embedding(model_id, mock_bedrock_text_embedding_resp
     assert len(response) == 2
 
 
-async def test_bedrock_text_embedding_with_unsupported_model_input_modality(model_id) -> None:
-    """Test Bedrock text embedding generation with unsupported model"""
-    with patch.object(
-        MockBedrockClient, "get_foundation_model", return_value={"modelDetails": {"inputModalities": ["IMAGE"]}}
-    ):
-        # Setup
-        bedrock_text_embedding = BedrockTextEmbedding(
-            model_id=model_id,
-            runtime_client=MockBedrockRuntimeClient(),
-            client=MockBedrockClient(),
-        )
-
-        with pytest.raises(ServiceInvalidRequestError):
-            await bedrock_text_embedding.generate_embeddings(["hello", "world"])
-
-
-async def test_bedrock_text_embedding_with_unsupported_model_output_modality(model_id) -> None:
-    """Test Bedrock text embedding generation with unsupported model"""
-    with patch.object(
-        MockBedrockClient,
-        "get_foundation_model",
-        return_value={"modelDetails": {"inputModalities": ["TEXT"], "outputModalities": ["TEXT"]}},
-    ):
-        # Setup
-        bedrock_text_embedding = BedrockTextEmbedding(
-            model_id=model_id,
-            runtime_client=MockBedrockRuntimeClient(),
-            client=MockBedrockClient(),
-        )
-
-        with pytest.raises(ServiceInvalidRequestError):
-            await bedrock_text_embedding.generate_embeddings(["hello", "world"])
-
-
 @pytest.mark.parametrize(
     # These are fake model ids with the supported prefixes
     "model_id",
```
