Skip to content

Commit a1495b9

Browse files
authored
Merge branch 'main' into retry_docs
2 parents fe81968 + a8f7067 commit a1495b9

File tree

16 files changed

+214
-119
lines changed

16 files changed

+214
-119
lines changed

Makefile

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,16 @@ typecheck-both: typecheck-pyright typecheck-mypy
5353
.PHONY: test
5454
test: ## Run tests and collect coverage data
5555
@# To test using a specific version of python, run 'make install-all-python' then set environment variable PYTEST_PYTHON=3.10 or similar
56-
$(if $(PYTEST_PYTHON),UV_PROJECT_ENVIRONMENT=.venv$(subst .,,$(PYTEST_PYTHON))) uv run $(if $(PYTEST_PYTHON),--python $(PYTEST_PYTHON)) coverage run -m pytest -n auto --dist=loadgroup --durations=20
56+
COLUMNS=150 $(if $(PYTEST_PYTHON),UV_PROJECT_ENVIRONMENT=.venv$(subst .,,$(PYTEST_PYTHON))) uv run $(if $(PYTEST_PYTHON),--python $(PYTEST_PYTHON)) coverage run -m pytest -n auto --dist=loadgroup --durations=20
5757
@uv run coverage combine
5858
@uv run coverage report
5959

6060
.PHONY: test-all-python
6161
test-all-python: ## Run tests on Python 3.10 to 3.13
62-
UV_PROJECT_ENVIRONMENT=.venv310 uv run --python 3.10 --all-extras --all-packages coverage run -p -m pytest
63-
UV_PROJECT_ENVIRONMENT=.venv311 uv run --python 3.11 --all-extras --all-packages coverage run -p -m pytest
64-
UV_PROJECT_ENVIRONMENT=.venv312 uv run --python 3.12 --all-extras --all-packages coverage run -p -m pytest
65-
UV_PROJECT_ENVIRONMENT=.venv313 uv run --python 3.13 --all-extras --all-packages coverage run -p -m pytest
62+
COLUMNS=150 UV_PROJECT_ENVIRONMENT=.venv310 uv run --python 3.10 --all-extras --all-packages coverage run -p -m pytest
63+
COLUMNS=150 UV_PROJECT_ENVIRONMENT=.venv311 uv run --python 3.11 --all-extras --all-packages coverage run -p -m pytest
64+
COLUMNS=150 UV_PROJECT_ENVIRONMENT=.venv312 uv run --python 3.12 --all-extras --all-packages coverage run -p -m pytest
65+
COLUMNS=150 UV_PROJECT_ENVIRONMENT=.venv313 uv run --python 3.13 --all-extras --all-packages coverage run -p -m pytest
6666
@uv run coverage combine
6767
@uv run coverage report
6868

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,9 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
764764

765765
model_kind = provider_name
766766
if model_kind.startswith('gateway/'):
767-
model_kind = provider_name.removeprefix('gateway/')
767+
from ..providers.gateway import infer_gateway_model
768+
769+
return infer_gateway_model(model_kind.removeprefix('gateway/'), model_name=model_name)
768770
if model_kind in (
769771
'openai',
770772
'azure',

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def __init__(
226226
self._model_name = model_name
227227

228228
if isinstance(provider, str):
229-
provider = infer_provider('gateway/bedrock' if provider == 'gateway' else provider)
229+
provider = infer_provider('gateway/converse' if provider == 'gateway' else provider)
230230
self._provider = provider
231231
self.client = cast('BedrockRuntimeClient', provider.client)
232232

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def __init__(
204204
self._model_name = model_name
205205

206206
if isinstance(provider, str):
207-
provider = infer_provider('gateway/google-vertex' if provider == 'gateway' else provider)
207+
provider = infer_provider('gateway/gemini' if provider == 'gateway' else provider)
208208
self._provider = provider
209209
self.client = provider.client
210210

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def __init__(
375375
self._model_name = model_name
376376

377377
if isinstance(provider, str):
378-
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
378+
provider = infer_provider('gateway/chat' if provider == 'gateway' else provider)
379379
self._provider = provider
380380
self.client = provider.client
381381

@@ -944,7 +944,7 @@ def __init__(
944944
self._model_name = model_name
945945

946946
if isinstance(provider, str):
947-
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
947+
provider = infer_provider('gateway/responses' if provider == 'gateway' else provider)
948948
self._provider = provider
949949
self.client = provider.client
950950

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ def infer_provider(provider: str) -> Provider[Any]:
158158
if provider.startswith('gateway/'):
159159
from .gateway import gateway_provider
160160

161-
provider = provider.removeprefix('gateway/')
162-
return gateway_provider(provider)
161+
api_type = provider.removeprefix('gateway/')
162+
return gateway_provider(api_type)
163163
elif provider in ('google-vertex', 'google-gla'):
164164
from .google import GoogleProvider
165165

pydantic_ai_slim/pydantic_ai/providers/gateway.py

Lines changed: 87 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from groq import AsyncGroq
1818
from openai import AsyncOpenAI
1919

20+
from pydantic_ai.models import Model
2021
from pydantic_ai.models.anthropic import AsyncAnthropicClient
2122
from pydantic_ai.providers import Provider
2223

@@ -25,8 +26,11 @@
2526

2627
@overload
2728
def gateway_provider(
28-
upstream_provider: Literal['openai', 'openai-chat', 'openai-responses'],
29+
api_type: Literal['chat', 'responses'],
30+
/,
2931
*,
32+
routing_group: str | None = None,
33+
profile: str | None = None,
3034
api_key: str | None = None,
3135
base_url: str | None = None,
3236
http_client: httpx.AsyncClient | None = None,
@@ -35,8 +39,11 @@ def gateway_provider(
3539

3640
@overload
3741
def gateway_provider(
38-
upstream_provider: Literal['groq'],
42+
api_type: Literal['groq'],
43+
/,
3944
*,
45+
routing_group: str | None = None,
46+
profile: str | None = None,
4047
api_key: str | None = None,
4148
base_url: str | None = None,
4249
http_client: httpx.AsyncClient | None = None,
@@ -45,56 +52,77 @@ def gateway_provider(
4552

4653
@overload
4754
def gateway_provider(
48-
upstream_provider: Literal['google-vertex'],
55+
api_type: Literal['anthropic'],
56+
/,
4957
*,
58+
routing_group: str | None = None,
59+
profile: str | None = None,
5060
api_key: str | None = None,
5161
base_url: str | None = None,
52-
) -> Provider[GoogleClient]: ...
62+
http_client: httpx.AsyncClient | None = None,
63+
) -> Provider[AsyncAnthropicClient]: ...
5364

5465

5566
@overload
5667
def gateway_provider(
57-
upstream_provider: Literal['anthropic'],
68+
api_type: Literal['converse'],
69+
/,
5870
*,
71+
routing_group: str | None = None,
72+
profile: str | None = None,
5973
api_key: str | None = None,
6074
base_url: str | None = None,
61-
) -> Provider[AsyncAnthropicClient]: ...
75+
) -> Provider[BaseClient]: ...
6276

6377

6478
@overload
6579
def gateway_provider(
66-
upstream_provider: Literal['bedrock'],
80+
api_type: Literal['gemini'],
81+
/,
6782
*,
83+
routing_group: str | None = None,
84+
profile: str | None = None,
6885
api_key: str | None = None,
6986
base_url: str | None = None,
70-
) -> Provider[BaseClient]: ...
87+
http_client: httpx.AsyncClient | None = None,
88+
) -> Provider[GoogleClient]: ...
7189

7290

7391
@overload
7492
def gateway_provider(
75-
upstream_provider: str,
93+
api_type: str,
94+
/,
7695
*,
96+
routing_group: str | None = None,
97+
profile: str | None = None,
7798
api_key: str | None = None,
7899
base_url: str | None = None,
79100
) -> Provider[Any]: ...
80101

81102

82-
UpstreamProvider = Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic', 'bedrock']
103+
APIType = Literal['chat', 'responses', 'gemini', 'converse', 'anthropic', 'groq']
83104

84105

85106
def gateway_provider(
86-
upstream_provider: UpstreamProvider | str,
107+
api_type: APIType | str,
108+
/,
87109
*,
88110
# Every provider
111+
routing_group: str | None = None,
112+
profile: str | None = None,
89113
api_key: str | None = None,
90114
base_url: str | None = None,
91-
# OpenAI, Groq & Anthropic
115+
# OpenAI, Groq, Anthropic & Gemini - Only Bedrock doesn't have an HTTPX client.
92116
http_client: httpx.AsyncClient | None = None,
93117
) -> Provider[Any]:
94118
"""Create a new Gateway provider.
95119
96120
Args:
97-
upstream_provider: The upstream provider to use.
121+
api_type: Determines the API type to use.
122+
routing_group: The group of APIs that support the same models - the idea is that you can route the requests to
123+
any provider in a routing group. The `pydantic-ai-gateway-routing-group` header will be added.
124+
profile: A provider may have a profile, which is a unique identifier for the provider.
125+
The `pydantic-ai-gateway-profile` header will be added.
98126
api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
99127
environment variable will be used if available.
100128
base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
@@ -109,18 +137,24 @@ def gateway_provider(
109137
)
110138

111139
base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', GATEWAY_BASE_URL)
112-
http_client = http_client or cached_async_http_client(provider=f'gateway/{upstream_provider}')
140+
http_client = http_client or cached_async_http_client(provider=f'gateway/{api_type}')
113141
http_client.event_hooks = {'request': [_request_hook(api_key)]}
114142

115-
if upstream_provider in ('openai', 'openai-chat', 'openai-responses'):
143+
if profile is not None:
144+
http_client.headers.setdefault('pydantic-ai-gateway-profile', profile)
145+
146+
if routing_group is not None:
147+
http_client.headers.setdefault('pydantic-ai-gateway-routing-group', routing_group)
148+
149+
if api_type in ('chat', 'responses'):
116150
from .openai import OpenAIProvider
117151

118-
return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'openai'), http_client=http_client)
119-
elif upstream_provider == 'groq':
152+
return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, api_type), http_client=http_client)
153+
elif api_type == 'groq':
120154
from .groq import GroqProvider
121155

122156
return GroqProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'groq'), http_client=http_client)
123-
elif upstream_provider == 'anthropic':
157+
elif api_type == 'anthropic':
124158
from anthropic import AsyncAnthropic
125159

126160
from .anthropic import AnthropicProvider
@@ -132,25 +166,25 @@ def gateway_provider(
132166
http_client=http_client,
133167
)
134168
)
135-
elif upstream_provider == 'bedrock':
169+
elif api_type == 'converse':
136170
from .bedrock import BedrockProvider
137171

138172
return BedrockProvider(
139173
api_key=api_key,
140-
base_url=_merge_url_path(base_url, 'bedrock'),
174+
base_url=_merge_url_path(base_url, api_type),
141175
region_name='pydantic-ai-gateway', # Fake region name to avoid NoRegionError
142176
)
143-
elif upstream_provider == 'google-vertex':
177+
elif api_type == 'gemini':
144178
from .google import GoogleProvider
145179

146180
return GoogleProvider(
147181
vertexai=True,
148182
api_key=api_key,
149-
base_url=_merge_url_path(base_url, 'google-vertex'),
183+
base_url=_merge_url_path(base_url, 'gemini'),
150184
http_client=http_client,
151185
)
152186
else:
153-
raise UserError(f'Unknown upstream provider: {upstream_provider}')
187+
raise UserError(f'Unknown API type: {api_type}')
154188

155189

156190
def _request_hook(api_key: str) -> Callable[[httpx.Request], Awaitable[httpx.Request]]:
@@ -182,3 +216,33 @@ def _merge_url_path(base_url: str, path: str) -> str:
182216
path: The path to merge.
183217
"""
184218
return base_url.rstrip('/') + '/' + path.lstrip('/')
219+
220+
221+
def infer_gateway_model(api_type: APIType | str, *, model_name: str) -> Model:
222+
"""Infer the model class for a given API type."""
223+
if api_type == 'chat':
224+
from pydantic_ai.models.openai import OpenAIChatModel
225+
226+
return OpenAIChatModel(model_name=model_name, provider='gateway')
227+
elif api_type == 'groq':
228+
from pydantic_ai.models.groq import GroqModel
229+
230+
return GroqModel(model_name=model_name, provider='gateway')
231+
elif api_type == 'responses':
232+
from pydantic_ai.models.openai import OpenAIResponsesModel
233+
234+
return OpenAIResponsesModel(model_name=model_name, provider='gateway')
235+
elif api_type == 'gemini':
236+
from pydantic_ai.models.google import GoogleModel
237+
238+
return GoogleModel(model_name=model_name, provider='gateway')
239+
elif api_type == 'converse':
240+
from pydantic_ai.models.bedrock import BedrockConverseModel
241+
242+
return BedrockConverseModel(model_name=model_name, provider='gateway')
243+
elif api_type == 'anthropic':
244+
from pydantic_ai.models.anthropic import AnthropicModel
245+
246+
return AnthropicModel(model_name=model_name, provider='gateway')
247+
else:
248+
raise ValueError(f'Unknown API type: {api_type}') # pragma: no cover

pydantic_ai_slim/pydantic_ai/ui/_adapter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from pydantic_ai import DeferredToolRequests, DeferredToolResults
2121
from pydantic_ai.agent import AbstractAgent
22+
from pydantic_ai.agent.abstract import Instructions
2223
from pydantic_ai.builtin_tools import AbstractBuiltinTool
2324
from pydantic_ai.exceptions import UserError
2425
from pydantic_ai.messages import ModelMessage
@@ -200,6 +201,7 @@ def run_stream_native(
200201
message_history: Sequence[ModelMessage] | None = None,
201202
deferred_tool_results: DeferredToolResults | None = None,
202203
model: Model | KnownModelName | str | None = None,
204+
instructions: Instructions[AgentDepsT] = None,
203205
deps: AgentDepsT = None,
204206
model_settings: ModelSettings | None = None,
205207
usage_limits: UsageLimits | None = None,
@@ -216,6 +218,7 @@ def run_stream_native(
216218
message_history: History of the conversation so far.
217219
deferred_tool_results: Optional results for deferred tool calls in the message history.
218220
model: Optional model to use for this run, required if `model` was not set when creating the agent.
221+
instructions: Optional additional instructions to use for this run.
219222
deps: Optional dependencies to use for this run.
220223
model_settings: Optional settings to use for this model's request.
221224
usage_limits: Optional limits on model request count or token usage.
@@ -251,6 +254,7 @@ def run_stream_native(
251254
model=model,
252255
deps=deps,
253256
model_settings=model_settings,
257+
instructions=instructions,
254258
usage_limits=usage_limits,
255259
usage=usage,
256260
infer_name=infer_name,
@@ -265,6 +269,7 @@ def run_stream(
265269
message_history: Sequence[ModelMessage] | None = None,
266270
deferred_tool_results: DeferredToolResults | None = None,
267271
model: Model | KnownModelName | str | None = None,
272+
instructions: Instructions[AgentDepsT] = None,
268273
deps: AgentDepsT = None,
269274
model_settings: ModelSettings | None = None,
270275
usage_limits: UsageLimits | None = None,
@@ -282,6 +287,7 @@ def run_stream(
282287
message_history: History of the conversation so far.
283288
deferred_tool_results: Optional results for deferred tool calls in the message history.
284289
model: Optional model to use for this run, required if `model` was not set when creating the agent.
290+
instructions: Optional additional instructions to use for this run.
285291
deps: Optional dependencies to use for this run.
286292
model_settings: Optional settings to use for this model's request.
287293
usage_limits: Optional limits on model request count or token usage.
@@ -298,6 +304,7 @@ def run_stream(
298304
message_history=message_history,
299305
deferred_tool_results=deferred_tool_results,
300306
model=model,
307+
instructions=instructions,
301308
deps=deps,
302309
model_settings=model_settings,
303310
usage_limits=usage_limits,
@@ -318,6 +325,7 @@ async def dispatch_request(
318325
message_history: Sequence[ModelMessage] | None = None,
319326
deferred_tool_results: DeferredToolResults | None = None,
320327
model: Model | KnownModelName | str | None = None,
328+
instructions: Instructions[AgentDepsT] = None,
321329
deps: AgentDepsT = None,
322330
output_type: OutputSpec[Any] | None = None,
323331
model_settings: ModelSettings | None = None,
@@ -338,6 +346,7 @@ async def dispatch_request(
338346
message_history: History of the conversation so far.
339347
deferred_tool_results: Optional results for deferred tool calls in the message history.
340348
model: Optional model to use for this run, required if `model` was not set when creating the agent.
349+
instructions: Optional additional instructions to use for this run.
341350
deps: Optional dependencies to use for this run.
342351
model_settings: Optional settings to use for this model's request.
343352
usage_limits: Optional limits on model request count or token usage.
@@ -375,6 +384,7 @@ async def dispatch_request(
375384
deps=deps,
376385
output_type=output_type,
377386
model=model,
387+
instructions=instructions,
378388
model_settings=model_settings,
379389
usage_limits=usage_limits,
380390
usage=usage,

0 commit comments

Comments
 (0)