Skip to content

Commit dafa104

Browse files
committed
Support api_type
1 parent 2217218 commit dafa104

File tree

10 files changed

+133
-122
lines changed

10 files changed

+133
-122
lines changed

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -226,7 +226,7 @@ def __init__(
226226
self._model_name = model_name
227227

228228
if isinstance(provider, str):
229-
provider = infer_provider('gateway/bedrock' if provider == 'gateway' else provider)
229+
provider = infer_provider('gateway/converse' if provider == 'gateway' else provider)
230230
self._provider = provider
231231
self.client = cast('BedrockRuntimeClient', provider.client)
232232

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -204,7 +204,7 @@ def __init__(
204204
self._model_name = model_name
205205

206206
if isinstance(provider, str):
207-
provider = infer_provider('gateway/google-vertex' if provider == 'gateway' else provider)
207+
provider = infer_provider('gateway/gemini' if provider == 'gateway' else provider)
208208
self._provider = provider
209209
self.client = provider.client
210210

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -375,7 +375,7 @@ def __init__(
375375
self._model_name = model_name
376376

377377
if isinstance(provider, str):
378-
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
378+
provider = infer_provider('gateway/chat' if provider == 'gateway' else provider)
379379
self._provider = provider
380380
self.client = provider.client
381381

@@ -944,7 +944,7 @@ def __init__(
944944
self._model_name = model_name
945945

946946
if isinstance(provider, str):
947-
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
947+
provider = infer_provider('gateway/responses' if provider == 'gateway' else provider)
948948
self._provider = provider
949949
self.client = provider.client
950950

pydantic_ai_slim/pydantic_ai/providers/gateway.py

Lines changed: 24 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
from typing import TYPE_CHECKING, Any, Literal, overload
88

99
import httpx
10-
from typing_extensions import deprecated
1110

1211
from pydantic_ai.exceptions import UserError
1312
from pydantic_ai.models import cached_async_http_client
@@ -25,9 +24,8 @@
2524

2625

2726
@overload
28-
@deprecated('Use `chat` or `responses` API type instead of `openai` provider type.')
2927
def gateway_provider(
30-
upstream_provider: Literal['openai', 'openai-chat', 'openai-responses'],
28+
api_type: Literal['chat', 'responses'],
3129
/,
3230
*,
3331
api_key: str | None = None,
@@ -38,76 +36,61 @@ def gateway_provider(
3836

3937
@overload
4038
def gateway_provider(
41-
upstream_provider: Literal['groq'],
39+
api_key: Literal['groq'],
4240
/,
4341
*,
44-
api_key: str | None = None,
4542
base_url: str | None = None,
4643
http_client: httpx.AsyncClient | None = None,
4744
) -> Provider[AsyncGroq]: ...
4845

4946

5047
@overload
51-
@deprecated('Use `gemini` or `anthropic` API type instead of `google-vertex` provider type.')
5248
def gateway_provider(
53-
upstream_provider: Literal['google-vertex'],
49+
api_type: Literal['anthropic'],
5450
/,
5551
*,
5652
api_key: str | None = None,
5753
base_url: str | None = None,
58-
) -> Provider[GoogleClient]: ...
54+
http_client: httpx.AsyncClient | None = None,
55+
) -> Provider[AsyncAnthropicClient]: ...
5956

6057

6158
@overload
6259
def gateway_provider(
63-
upstream_provider: Literal['anthropic'],
60+
api_type: Literal['converse'],
6461
/,
6562
*,
6663
api_key: str | None = None,
6764
base_url: str | None = None,
68-
) -> Provider[AsyncAnthropicClient]: ...
65+
) -> Provider[BaseClient]: ...
6966

7067

7168
@overload
72-
@deprecated('Use `converse` or `anthropic` API type instead of `bedrock` provider type.')
7369
def gateway_provider(
74-
upstream_provider: Literal['bedrock'],
70+
api_type: Literal['gemini'],
7571
/,
7672
*,
7773
api_key: str | None = None,
7874
base_url: str | None = None,
79-
) -> Provider[BaseClient]: ...
75+
http_client: httpx.AsyncClient | None = None,
76+
) -> Provider[GoogleClient]: ...
8077

8178

8279
@overload
8380
def gateway_provider(
84-
upstream_provider: str,
81+
api_type: str,
8582
/,
8683
*,
8784
api_key: str | None = None,
8885
base_url: str | None = None,
8986
) -> Provider[Any]: ...
9087

9188

92-
UpstreamProvider = Literal[
93-
'chat',
94-
'responses',
95-
'gemini',
96-
'converse',
97-
# Deprecated
98-
'openai',
99-
'openai-chat',
100-
'openai-responses',
101-
'google-vertex',
102-
'bedrock',
103-
# Those two are actually the same on both.
104-
'anthropic',
105-
'groq',
106-
]
89+
APIType = Literal['chat', 'responses', 'gemini', 'converse', 'anthropic', 'groq']
10790

10891

10992
def gateway_provider(
110-
upstream_provider: UpstreamProvider | str,
93+
api_type: APIType | str,
11194
/,
11295
*,
11396
# Every provider
@@ -119,7 +102,7 @@ def gateway_provider(
119102
"""Create a new Gateway provider.
120103
121104
Args:
122-
upstream_provider: The upstream provider to use.
105+
api_type: Determines the API type to use.
123106
api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
124107
environment variable will be used if available.
125108
base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
@@ -134,23 +117,18 @@ def gateway_provider(
134117
)
135118

136119
base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', GATEWAY_BASE_URL)
137-
http_client = http_client or cached_async_http_client(provider=f'gateway/{upstream_provider}')
120+
http_client = http_client or cached_async_http_client(provider=f'gateway/{api_type}')
138121
http_client.event_hooks = {'request': [_request_hook(api_key)]}
139122

140-
if upstream_provider in ('openai', 'openai-chat', 'openai-responses'):
141-
from .openai import OpenAIProvider
142-
143-
return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'openai'), http_client=http_client)
144-
elif upstream_provider in ('chat', 'responses'):
123+
if api_type in ('chat', 'responses'):
145124
from .openai import OpenAIProvider
146125

147-
base_url = _merge_url_path(base_url, upstream_provider)
148-
return OpenAIProvider(api_key=api_key, base_url=base_url, http_client=http_client)
149-
elif upstream_provider == 'groq':
126+
return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, api_type), http_client=http_client)
127+
elif api_type == 'groq':
150128
from .groq import GroqProvider
151129

152130
return GroqProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'groq'), http_client=http_client)
153-
elif upstream_provider == 'anthropic':
131+
elif api_type == 'anthropic':
154132
from anthropic import AsyncAnthropic
155133

156134
from .anthropic import AnthropicProvider
@@ -162,25 +140,25 @@ def gateway_provider(
162140
http_client=http_client,
163141
)
164142
)
165-
elif upstream_provider in ('bedrock', 'converse'):
143+
elif api_type in 'converse':
166144
from .bedrock import BedrockProvider
167145

168146
return BedrockProvider(
169147
api_key=api_key,
170-
base_url=_merge_url_path(base_url, upstream_provider),
148+
base_url=_merge_url_path(base_url, api_type),
171149
region_name='pydantic-ai-gateway', # Fake region name to avoid NoRegionError
172150
)
173-
elif upstream_provider in ('google-vertex', 'gemini'):
151+
elif api_type in 'gemini':
174152
from .google import GoogleProvider
175153

176154
return GoogleProvider(
177155
vertexai=True,
178156
api_key=api_key,
179-
base_url=_merge_url_path(base_url, 'google-vertex'),
157+
base_url=_merge_url_path(base_url, 'gemini'),
180158
http_client=http_client,
181159
)
182160
else:
183-
raise UserError(f'Unknown upstream provider: {upstream_provider}')
161+
raise UserError(f'Unknown API type: {api_type}')
184162

185163

186164
def _request_hook(api_key: str) -> Callable[[httpx.Request], Awaitable[httpx.Request]]:

tests/models/cassettes/test_model_names/test_known_model_names.yaml

Lines changed: 32 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ interactions:
1313
response:
1414
headers:
1515
content-length:
16-
- '760'
16+
- '1080'
1717
content-security-policy:
1818
- default-src 'none'; frame-ancestors 'none'
1919
content-type:
@@ -25,6 +25,12 @@ interactions:
2525
strict-transport-security:
2626
- max-age=63072000
2727
parsed_body:
28+
- model_id: amazon-rerank-1-0
29+
regions:
30+
- eu
31+
- us
32+
type:
33+
- text-to-text
2834
- model_id: claude-3-5-haiku
2935
regions:
3036
- us
@@ -46,6 +52,18 @@ interactions:
4652
- eu
4753
type:
4854
- text-to-text
55+
- model_id: claude-4-5-haiku
56+
regions:
57+
- eu
58+
- us
59+
type:
60+
- text-to-text
61+
- model_id: claude-4-5-sonnet
62+
regions:
63+
- eu
64+
- us
65+
type:
66+
- text-to-text
4967
- model_id: claude-4-sonnet
5068
regions:
5169
- eu
@@ -58,8 +76,15 @@ interactions:
5876
- us
5977
type:
6078
- text-to-embedding
79+
- model_id: cohere-rerank-3-5
80+
regions:
81+
- eu
82+
- us
83+
type:
84+
- text-to-text
6185
- model_id: gpt-oss-120b
6286
regions:
87+
- eu
6388
- us
6489
type:
6590
- text-to-text
@@ -99,7 +124,7 @@ interactions:
99124
alt-svc:
100125
- h3=":443"; ma=86400
101126
content-length:
102-
- '762'
127+
- '570'
103128
content-type:
104129
- application/json
105130
referrer-policy:
@@ -109,41 +134,33 @@ interactions:
109134
parsed_body:
110135
data:
111136
- created: 0
112-
id: llama-4-scout-17b-16e-instruct
137+
id: zai-glm-4.6
113138
object: model
114139
owned_by: Cerebras
115140
- created: 0
116-
id: qwen-3-32b
141+
id: qwen-3-235b-a22b-instruct-2507
117142
object: model
118143
owned_by: Cerebras
119144
- created: 0
120145
id: llama-3.3-70b
121146
object: model
122147
owned_by: Cerebras
123148
- created: 0
124-
id: qwen-3-235b-a22b-instruct-2507
125-
object: model
126-
owned_by: Cerebras
127-
- created: 0
128-
id: llama-4-maverick-17b-128e-instruct
149+
id: gpt-oss-120b
129150
object: model
130151
owned_by: Cerebras
131152
- created: 0
132-
id: qwen-3-coder-480b
153+
id: qwen-3-32b
133154
object: model
134155
owned_by: Cerebras
135156
- created: 0
136-
id: gpt-oss-120b
157+
id: llama3.1-8b
137158
object: model
138159
owned_by: Cerebras
139160
- created: 0
140161
id: qwen-3-235b-a22b-thinking-2507
141162
object: model
142163
owned_by: Cerebras
143-
- created: 0
144-
id: llama3.1-8b
145-
object: model
146-
owned_by: Cerebras
147164
object: list
148165
status:
149166
code: 200

tests/providers/cassettes/test_gateway/test_gateway_provider_with_bedrock.yaml

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ interactions:
55
headers:
66
amz-sdk-invocation-id:
77
- !!binary |
8-
MWYwNDlkMTQtMjVmMC00YTRhLWJhYmMtNTQ0MDdhMmRlNjgw
8+
MmEzMzkzMGUtNzI3YS00YzFhLWFmYWQtYzFhYWMyMTI3NDlj
99
amz-sdk-request:
1010
- !!binary |
1111
YXR0ZW1wdD0x
@@ -15,7 +15,7 @@ interactions:
1515
- !!binary |
1616
YXBwbGljYXRpb24vanNvbg==
1717
method: POST
18-
uri: http://localhost:8787/bedrock/model/amazon.nova-micro-v1%3A0/converse
18+
uri: http://localhost:8787/converse/model/amazon.nova-micro-v1%3A0/converse
1919
response:
2020
headers:
2121
content-length:
@@ -26,7 +26,7 @@ interactions:
2626
- 0.0000USD
2727
parsed_body:
2828
metrics:
29-
latencyMs: 668
29+
latencyMs: 682
3030
output:
3131
message:
3232
content:

0 commit comments

Comments
 (0)