Skip to content

Commit b16c8ad

Browse files
authored
Merge branch 'main' into gateway-string-docs
2 parents a70a74d + 1df9ca6 commit b16c8ad

21 files changed

+431
-183
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ jobs:
202202
strategy:
203203
fail-fast: false
204204
matrix:
205-
python-version: ["3.10", "3.11", "3.12", "3.13"]
205+
# TODO(Marcelo): Enable 3.11 again.
206+
python-version: ["3.10", "3.12", "3.13"]
206207
env:
207208
CI: true
208209
COVERAGE_PROCESS_START: ./pyproject.toml

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@
103103
'bedrock:us.anthropic.claude-opus-4-20250514-v1:0',
104104
'bedrock:anthropic.claude-sonnet-4-20250514-v1:0',
105105
'bedrock:us.anthropic.claude-sonnet-4-20250514-v1:0',
106+
'bedrock:eu.anthropic.claude-sonnet-4-20250514-v1:0',
107+
'bedrock:anthropic.claude-sonnet-4-5-20250929-v1:0',
108+
'bedrock:us.anthropic.claude-sonnet-4-5-20250929-v1:0',
109+
'bedrock:eu.anthropic.claude-sonnet-4-5-20250929-v1:0',
110+
'bedrock:anthropic.claude-haiku-4-5-20251001-v1:0',
111+
'bedrock:us.anthropic.claude-haiku-4-5-20251001-v1:0',
112+
'bedrock:eu.anthropic.claude-haiku-4-5-20251001-v1:0',
106113
'bedrock:cohere.command-text-v14',
107114
'bedrock:cohere.command-r-v1:0',
108115
'bedrock:cohere.command-r-plus-v1:0',
@@ -773,9 +780,10 @@ def infer_model( # noqa: C901
773780

774781
model_kind = provider_name
775782
if model_kind.startswith('gateway/'):
776-
from ..providers.gateway import infer_gateway_model
783+
from ..providers.gateway import normalize_gateway_provider
777784

778-
return infer_gateway_model(model_kind.removeprefix('gateway/'), model_name=model_name)
785+
model_kind = provider_name.removeprefix('gateway/')
786+
model_kind = normalize_gateway_provider(model_kind)
779787
if model_kind in (
780788
'openai',
781789
'azure',

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
DocumentUrl,
2323
FinishReason,
2424
ImageUrl,
25-
ModelHTTPError,
2625
ModelMessage,
2726
ModelProfileSpec,
2827
ModelRequest,
@@ -41,7 +40,7 @@
4140
usage,
4241
)
4342
from pydantic_ai._run_context import RunContext
44-
from pydantic_ai.exceptions import UserError
43+
from pydantic_ai.exceptions import ModelHTTPError, UserError
4544
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item
4645
from pydantic_ai.providers import Provider, infer_provider
4746
from pydantic_ai.providers.bedrock import BedrockModelProfile
@@ -61,6 +60,7 @@
6160
ConverseStreamMetadataEventTypeDef,
6261
ConverseStreamOutputTypeDef,
6362
ConverseStreamResponseTypeDef,
63+
CountTokensRequestTypeDef,
6464
DocumentBlockTypeDef,
6565
GuardrailConfigurationTypeDef,
6666
ImageBlockTypeDef,
@@ -77,7 +77,6 @@
7777
VideoBlockTypeDef,
7878
)
7979

80-
8180
LatestBedrockModelNames = Literal[
8281
'amazon.titan-tg1-large',
8382
'amazon.titan-text-lite-v1',
@@ -106,6 +105,13 @@
106105
'us.anthropic.claude-opus-4-20250514-v1:0',
107106
'anthropic.claude-sonnet-4-20250514-v1:0',
108107
'us.anthropic.claude-sonnet-4-20250514-v1:0',
108+
'eu.anthropic.claude-sonnet-4-20250514-v1:0',
109+
'anthropic.claude-sonnet-4-5-20250929-v1:0',
110+
'us.anthropic.claude-sonnet-4-5-20250929-v1:0',
111+
'eu.anthropic.claude-sonnet-4-5-20250929-v1:0',
112+
'anthropic.claude-haiku-4-5-20251001-v1:0',
113+
'us.anthropic.claude-haiku-4-5-20251001-v1:0',
114+
'eu.anthropic.claude-haiku-4-5-20251001-v1:0',
109115
'cohere.command-text-v14',
110116
'cohere.command-r-v1:0',
111117
'cohere.command-r-plus-v1:0',
@@ -136,7 +142,6 @@
136142
See [the Bedrock docs](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) for a full list.
137143
"""
138144

139-
140145
P = ParamSpec('P')
141146
T = typing.TypeVar('T')
142147

@@ -149,6 +154,13 @@
149154
'tool_use': 'tool_call',
150155
}
151156

157+
_AWS_BEDROCK_INFERENCE_GEO_PREFIXES: tuple[str, ...] = ('us.', 'eu.', 'apac.', 'jp.', 'au.', 'ca.')
158+
"""Geo prefixes for Bedrock inference profile IDs (e.g., 'eu.', 'us.').
159+
160+
Used to strip the geo prefix so we can pass a pure foundation model ID/ARN to CountTokens,
161+
which does not accept profile IDs. Extend if new geos appear (e.g., 'global.', 'us-gov.').
162+
"""
163+
152164

153165
class BedrockModelSettings(ModelSettings, total=False):
154166
"""Settings for Bedrock models.
@@ -228,7 +240,7 @@ def __init__(
228240
self._model_name = model_name
229241

230242
if isinstance(provider, str):
231-
provider = infer_provider('gateway/converse' if provider == 'gateway' else provider)
243+
provider = infer_provider('gateway/bedrock' if provider == 'gateway' else provider)
232244
self._provider = provider
233245
self.client = cast('BedrockRuntimeClient', provider.client)
234246

@@ -275,6 +287,34 @@ async def request(
275287
model_response = await self._process_response(response)
276288
return model_response
277289

290+
async def count_tokens(
291+
self,
292+
messages: list[ModelMessage],
293+
model_settings: ModelSettings | None,
294+
model_request_parameters: ModelRequestParameters,
295+
) -> usage.RequestUsage:
296+
"""Count the number of tokens, works with limited models.
297+
298+
Check the actual supported models on <https://docs.aws.amazon.com/bedrock/latest/userguide/count-tokens.html>
299+
"""
300+
model_settings, model_request_parameters = self.prepare_request(model_settings, model_request_parameters)
301+
system_prompt, bedrock_messages = await self._map_messages(messages, model_request_parameters)
302+
params: CountTokensRequestTypeDef = {
303+
'modelId': self._remove_inference_geo_prefix(self.model_name),
304+
'input': {
305+
'converse': {
306+
'messages': bedrock_messages,
307+
'system': system_prompt,
308+
},
309+
},
310+
}
311+
try:
312+
response = await anyio.to_thread.run_sync(functools.partial(self.client.count_tokens, **params))
313+
except ClientError as e:
314+
status_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode', 500)
315+
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.response) from e
316+
return usage.RequestUsage(input_tokens=response['inputTokens'])
317+
278318
@asynccontextmanager
279319
async def request_stream(
280320
self,
@@ -642,6 +682,14 @@ def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
642682
'toolUse': {'toolUseId': _utils.guard_tool_call_id(t=t), 'name': t.tool_name, 'input': t.args_as_dict()}
643683
}
644684

685+
@staticmethod
686+
def _remove_inference_geo_prefix(model_name: BedrockModelName) -> BedrockModelName:
687+
"""Remove inference geographic prefix from model ID if present."""
688+
for prefix in _AWS_BEDROCK_INFERENCE_GEO_PREFIXES:
689+
if model_name.startswith(prefix):
690+
return model_name.removeprefix(prefix)
691+
return model_name
692+
645693

646694
@dataclass
647695
class BedrockStreamedResponse(StreamedResponse):

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def __init__(
204204
self._model_name = model_name
205205

206206
if isinstance(provider, str):
207-
provider = infer_provider('gateway/gemini' if provider == 'gateway' else provider)
207+
provider = infer_provider('gateway/google-vertex' if provider == 'gateway' else provider)
208208
self._provider = provider
209209
self.client = provider.client
210210

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def __init__(
375375
self._model_name = model_name
376376

377377
if isinstance(provider, str):
378-
provider = infer_provider('gateway/chat' if provider == 'gateway' else provider)
378+
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
379379
self._provider = provider
380380
self.client = provider.client
381381

@@ -944,7 +944,7 @@ def __init__(
944944
self._model_name = model_name
945945

946946
if isinstance(provider, str):
947-
provider = infer_provider('gateway/responses' if provider == 'gateway' else provider)
947+
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
948948
self._provider = provider
949949
self.client = provider.client
950950

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ def infer_provider(provider: str) -> Provider[Any]:
158158
if provider.startswith('gateway/'):
159159
from .gateway import gateway_provider
160160

161-
api_type = provider.removeprefix('gateway/')
162-
return gateway_provider(api_type)
161+
upstream_provider = provider.removeprefix('gateway/')
162+
return gateway_provider(upstream_provider)
163163
elif provider in ('google-vertex', 'google-gla'):
164164
from .google import GoogleProvider
165165

0 commit comments

Comments (0)