Skip to content

Commit 3a86855

Browse files
KludexDouweM
andauthored
feat(gateway): support AWS Bedrock (#3203)
Co-authored-by: Douwe Maan <[email protected]>
1 parent b06bd6c commit 3a86855

File tree

5 files changed

+147
-20
lines changed

5 files changed

+147
-20
lines changed

pydantic_ai_slim/pydantic_ai/providers/bedrock.py

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import re
55
from collections.abc import Callable
66
from dataclasses import dataclass
7-
from typing import Literal, overload
7+
from typing import Any, Literal, overload
88

99
from pydantic_ai import ModelProfile
1010
from pydantic_ai.exceptions import UserError
@@ -21,6 +21,8 @@
2121
from botocore.client import BaseClient
2222
from botocore.config import Config
2323
from botocore.exceptions import NoRegionError
24+
from botocore.session import Session
25+
from botocore.tokens import FrozenAuthToken
2426
except ImportError as _import_error:
2527
raise ImportError(
2628
'Please install the `boto3` package to use the Bedrock provider, '
@@ -117,10 +119,23 @@ def __init__(self, *, bedrock_client: BaseClient) -> None: ...
117119
def __init__(
118120
self,
119121
*,
122+
api_key: str,
123+
base_url: str | None = None,
120124
region_name: str | None = None,
125+
profile_name: str | None = None,
126+
aws_read_timeout: float | None = None,
127+
aws_connect_timeout: float | None = None,
128+
) -> None: ...
129+
130+
@overload
131+
def __init__(
132+
self,
133+
*,
121134
aws_access_key_id: str | None = None,
122135
aws_secret_access_key: str | None = None,
123136
aws_session_token: str | None = None,
137+
base_url: str | None = None,
138+
region_name: str | None = None,
124139
profile_name: str | None = None,
125140
aws_read_timeout: float | None = None,
126141
aws_connect_timeout: float | None = None,
@@ -130,42 +145,71 @@ def __init__(
130145
self,
131146
*,
132147
bedrock_client: BaseClient | None = None,
133-
region_name: str | None = None,
134148
aws_access_key_id: str | None = None,
135149
aws_secret_access_key: str | None = None,
136150
aws_session_token: str | None = None,
151+
base_url: str | None = None,
152+
region_name: str | None = None,
137153
profile_name: str | None = None,
154+
api_key: str | None = None,
138155
aws_read_timeout: float | None = None,
139156
aws_connect_timeout: float | None = None,
140157
) -> None:
141158
"""Initialize the Bedrock provider.
142159
143160
Args:
144161
bedrock_client: A boto3 client for Bedrock Runtime. If provided, other arguments are ignored.
145-
region_name: The AWS region name.
146-
aws_access_key_id: The AWS access key ID.
147-
aws_secret_access_key: The AWS secret access key.
148-
aws_session_token: The AWS session token.
162+
aws_access_key_id: The AWS access key ID. If not set, the `AWS_ACCESS_KEY_ID` environment variable will be used if available.
163+
aws_secret_access_key: The AWS secret access key. If not set, the `AWS_SECRET_ACCESS_KEY` environment variable will be used if available.
164+
aws_session_token: The AWS session token. If not set, the `AWS_SESSION_TOKEN` environment variable will be used if available.
165+
api_key: The API key for Bedrock client. Can be used instead of `aws_access_key_id`, `aws_secret_access_key`, and `aws_session_token`. If not set, the `AWS_BEARER_TOKEN_BEDROCK` environment variable will be used if available.
166+
base_url: The base URL for the Bedrock client.
167+
region_name: The AWS region name. If not set, the `AWS_DEFAULT_REGION` environment variable will be used if available.
149168
profile_name: The AWS profile name.
150169
aws_read_timeout: The read timeout for Bedrock client.
151170
aws_connect_timeout: The connect timeout for Bedrock client.
152171
"""
153172
if bedrock_client is not None:
154173
self._client = bedrock_client
155174
else:
175+
read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
176+
connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
177+
config: dict[str, Any] = {
178+
'read_timeout': read_timeout,
179+
'connect_timeout': connect_timeout,
180+
}
156181
try:
157-
read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
158-
connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
159-
session = boto3.Session(
160-
aws_access_key_id=aws_access_key_id,
161-
aws_secret_access_key=aws_secret_access_key,
162-
aws_session_token=aws_session_token,
163-
region_name=region_name,
164-
profile_name=profile_name,
165-
)
182+
if api_key is not None:
183+
session = boto3.Session(
184+
botocore_session=_BearerTokenSession(api_key),
185+
region_name=region_name,
186+
profile_name=profile_name,
187+
)
188+
config['signature_version'] = 'bearer'
189+
else:
190+
session = boto3.Session(
191+
aws_access_key_id=aws_access_key_id,
192+
aws_secret_access_key=aws_secret_access_key,
193+
aws_session_token=aws_session_token,
194+
region_name=region_name,
195+
profile_name=profile_name,
196+
)
166197
self._client = session.client( # type: ignore[reportUnknownMemberType]
167198
'bedrock-runtime',
168-
config=Config(read_timeout=read_timeout, connect_timeout=connect_timeout),
199+
config=Config(**config),
200+
endpoint_url=base_url,
169201
)
170202
except NoRegionError as exc: # pragma: no cover
171203
raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc
204+
205+
206+
class _BearerTokenSession(Session):
207+
def __init__(self, token: str):
208+
super().__init__()
209+
self.token = token
210+
211+
def get_auth_token(self, **_kwargs: Any) -> FrozenAuthToken:
212+
return FrozenAuthToken(self.token)
213+
214+
def get_credentials(self) -> None: # type: ignore[reportIncompatibleMethodOverride]
215+
return None

pydantic_ai_slim/pydantic_ai/providers/gateway.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pydantic_ai.models import Model, cached_async_http_client, get_user_agent
1212

1313
if TYPE_CHECKING:
14+
from botocore.client import BaseClient
1415
from google.genai import Client as GoogleClient
1516
from groq import AsyncGroq
1617
from openai import AsyncOpenAI
@@ -57,13 +58,25 @@ def gateway_provider(
5758
) -> Provider[AsyncAnthropicClient]: ...
5859

5960

61+
@overload
6062
def gateway_provider(
61-
upstream_provider: Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic'] | str,
63+
upstream_provider: Literal['bedrock'],
64+
*,
65+
api_key: str | None = None,
66+
base_url: str | None = None,
67+
) -> Provider[BaseClient]: ...
68+
69+
70+
UpstreamProvider = Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic', 'bedrock']
71+
72+
73+
def gateway_provider(
74+
upstream_provider: UpstreamProvider | str,
6275
*,
6376
# Every provider
6477
api_key: str | None = None,
6578
base_url: str | None = None,
66-
# OpenAI & Groq
79+
# OpenAI, Groq & Anthropic
6780
http_client: httpx.AsyncClient | None = None,
6881
) -> Provider[Any]:
6982
"""Create a new Gateway provider.
@@ -73,7 +86,7 @@ def gateway_provider(
7386
api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
7487
environment variable will be used if available.
7588
base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
76-
environment variable will be used if available. Otherwise, defaults to `http://localhost:8787/`.
89+
environment variable will be used if available. Otherwise, defaults to `https://gateway.pydantic.dev/proxy`.
7790
http_client: The HTTP client to use for the Gateway.
7891
"""
7992
api_key = api_key or os.getenv('PYDANTIC_AI_GATEWAY_API_KEY')
@@ -111,6 +124,14 @@ def gateway_provider(
111124
http_client=http_client,
112125
)
113126
)
127+
elif upstream_provider == 'bedrock':
128+
from .bedrock import BedrockProvider
129+
130+
return BedrockProvider(
131+
api_key=api_key,
132+
base_url=_merge_url_path(base_url, 'bedrock'),
133+
region_name='pydantic-ai-gateway', # Fake region name to avoid NoRegionError
134+
)
114135
elif upstream_provider == 'google-vertex':
115136
from google.genai import Client as GoogleClient
116137

tests/providers/cassettes/test_gateway/test_gateway_provider_with_anthropic.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ interactions:
4040
content:
4141
- text: The capital of France is Paris.
4242
type: text
43-
id: msg_015tco2dv5oh9rFq1PcZAduv
43+
id: msg_0116L5r52AZ42YhvvdUuHEsk
4444
model: claude-3-5-sonnet-20241022
4545
role: assistant
4646
stop_reason: end_turn
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
interactions:
2+
- request:
3+
body: '{"messages": [{"role": "user", "content": [{"text": "What is the capital of France?"}]}], "system": [], "inferenceConfig":
4+
{}}'
5+
headers:
6+
amz-sdk-invocation-id:
7+
- !!binary |
8+
MWYwNDlkMTQtMjVmMC00YTRhLWJhYmMtNTQ0MDdhMmRlNjgw
9+
amz-sdk-request:
10+
- !!binary |
11+
YXR0ZW1wdD0x
12+
content-length:
13+
- '126'
14+
content-type:
15+
- !!binary |
16+
YXBwbGljYXRpb24vanNvbg==
17+
method: POST
18+
uri: http://localhost:8787/bedrock/model/amazon.nova-micro-v1%3A0/converse
19+
response:
20+
headers:
21+
content-length:
22+
- '741'
23+
content-type:
24+
- application/json
25+
pydantic-ai-gateway-price-estimate:
26+
- 0.0000USD
27+
parsed_body:
28+
metrics:
29+
latencyMs: 668
30+
output:
31+
message:
32+
content:
33+
- text: The capital of France is Paris. Paris is not only the capital city but also the most populous city in France,
34+
and it is a major center for culture, commerce, fashion, and international diplomacy. The city is known for
35+
its historical and architectural landmarks, including the Eiffel Tower, the Louvre Museum, Notre-Dame Cathedral,
36+
and the Champs-Élysées. Paris plays a significant role in the global arts, fashion, research, technology, education,
37+
and entertainment scenes.
38+
role: assistant
39+
stopReason: end_turn
40+
usage:
41+
inputTokens: 7
42+
outputTokens: 96
43+
pydantic_ai_gateway:
44+
cost_estimate: 1.3685000000000002e-05
45+
serverToolUsage: {}
46+
totalTokens: 103
47+
status:
48+
code: 200
49+
message: OK
50+
version: 1

tests/providers/test_gateway.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
with try_import() as imports_successful:
1616
from pydantic_ai.models.anthropic import AnthropicModel
17+
from pydantic_ai.models.bedrock import BedrockConverseModel
1718
from pydantic_ai.models.google import GoogleModel
1819
from pydantic_ai.models.groq import GroqModel
1920
from pydantic_ai.models.openai import OpenAIChatModel, OpenAIResponsesModel
@@ -150,3 +151,14 @@ async def test_gateway_provider_with_anthropic(allow_model_requests: None, gatew
150151

151152
result = await agent.run('What is the capital of France?')
152153
assert result.output == snapshot('The capital of France is Paris.')
154+
155+
156+
async def test_gateway_provider_with_bedrock(allow_model_requests: None, gateway_api_key: str):
157+
provider = gateway_provider('bedrock', api_key=gateway_api_key, base_url='http://localhost:8787')
158+
model = BedrockConverseModel('amazon.nova-micro-v1:0', provider=provider)
159+
agent = Agent(model)
160+
161+
result = await agent.run('What is the capital of France?')
162+
assert result.output == snapshot(
163+
'The capital of France is Paris. Paris is not only the capital city but also the most populous city in France, and it is a major center for culture, commerce, fashion, and international diplomacy. The city is known for its historical and architectural landmarks, including the Eiffel Tower, the Louvre Museum, Notre-Dame Cathedral, and the Champs-Élysées. Paris plays a significant role in the global arts, fashion, research, technology, education, and entertainment scenes.'
164+
)

0 commit comments

Comments
 (0)