Skip to content

Commit 889abfb

Browse files
authored
feat(gateway): support profile and routing_group (#3361)
1 parent 57a3aa7 commit 889abfb

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

pydantic_ai_slim/pydantic_ai/providers/gateway.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def gateway_provider(
2929
api_type: Literal['chat', 'responses'],
3030
/,
3131
*,
32+
routing_group: str | None = None,
33+
profile: str | None = None,
3234
api_key: str | None = None,
3335
base_url: str | None = None,
3436
http_client: httpx.AsyncClient | None = None,
@@ -40,6 +42,8 @@ def gateway_provider(
4042
api_type: Literal['groq'],
4143
/,
4244
*,
45+
routing_group: str | None = None,
46+
profile: str | None = None,
4347
api_key: str | None = None,
4448
base_url: str | None = None,
4549
http_client: httpx.AsyncClient | None = None,
@@ -51,6 +55,8 @@ def gateway_provider(
5155
api_type: Literal['anthropic'],
5256
/,
5357
*,
58+
routing_group: str | None = None,
59+
profile: str | None = None,
5460
api_key: str | None = None,
5561
base_url: str | None = None,
5662
http_client: httpx.AsyncClient | None = None,
@@ -62,6 +68,8 @@ def gateway_provider(
6268
api_type: Literal['converse'],
6369
/,
6470
*,
71+
routing_group: str | None = None,
72+
profile: str | None = None,
6573
api_key: str | None = None,
6674
base_url: str | None = None,
6775
) -> Provider[BaseClient]: ...
@@ -72,6 +80,8 @@ def gateway_provider(
7280
api_type: Literal['gemini'],
7381
/,
7482
*,
83+
routing_group: str | None = None,
84+
profile: str | None = None,
7585
api_key: str | None = None,
7686
base_url: str | None = None,
7787
http_client: httpx.AsyncClient | None = None,
@@ -83,6 +93,8 @@ def gateway_provider(
8393
api_type: str,
8494
/,
8595
*,
96+
routing_group: str | None = None,
97+
profile: str | None = None,
8698
api_key: str | None = None,
8799
base_url: str | None = None,
88100
) -> Provider[Any]: ...
@@ -96,6 +108,8 @@ def gateway_provider(
96108
/,
97109
*,
98110
# Every provider
111+
routing_group: str | None = None,
112+
profile: str | None = None,
99113
api_key: str | None = None,
100114
base_url: str | None = None,
101115
# OpenAI, Groq, Anthropic & Gemini - Only Bedrock doesn't have an HTTPX client.
@@ -105,6 +119,10 @@ def gateway_provider(
105119
106120
Args:
107121
api_type: Determines the API type to use.
122+
routing_group: The group of APIs that support the same models - the idea is that you can route the requests to
123+
any provider in a routing group. The `pydantic-ai-gateway-routing-group` header will be added.
124+
profile: A provider may have a profile, which is a unique identifier for the provider.
125+
The `pydantic-ai-gateway-profile` header will be added.
108126
api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
109127
environment variable will be used if available.
110128
base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
@@ -122,6 +140,12 @@ def gateway_provider(
122140
http_client = http_client or cached_async_http_client(provider=f'gateway/{api_type}')
123141
http_client.event_hooks = {'request': [_request_hook(api_key)]}
124142

143+
if profile is not None:
144+
http_client.headers.setdefault('pydantic-ai-gateway-profile', profile)
145+
146+
if routing_group is not None:
147+
http_client.headers.setdefault('pydantic-ai-gateway-routing-group', routing_group)
148+
125149
if api_type in ('chat', 'responses'):
126150
from .openai import OpenAIProvider
127151

tests/providers/test_gateway.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,15 @@ async def test_model_provider_argument():
180180

181181
model = BedrockConverseModel('amazon.nova-micro-v1:0', provider='gateway')
182182
assert GATEWAY_BASE_URL in model._provider.base_url # type: ignore[reportPrivateUsage]
183+
184+
185+
async def test_gateway_provider_routing_group_header(gateway_api_key: str):
186+
provider = gateway_provider('chat', routing_group='openai', api_key=gateway_api_key)
187+
httpx_client = provider.client._client # type: ignore[reportPrivateUsage]
188+
assert httpx_client.headers['pydantic-ai-gateway-routing-group'] == 'openai'
189+
190+
191+
async def test_gateway_provider_profile_header(gateway_api_key: str):
192+
provider = gateway_provider('chat', profile='openai', api_key=gateway_api_key)
193+
httpx_client = provider.client._client # type: ignore[reportPrivateUsage]
194+
assert httpx_client.headers['pydantic-ai-gateway-profile'] == 'openai'

0 commit comments

Comments
 (0)