Skip to content
24 changes: 24 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def gateway_provider(
api_type: Literal['chat', 'responses'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
http_client: httpx.AsyncClient | None = None,
Expand All @@ -40,6 +42,8 @@ def gateway_provider(
api_type: Literal['groq'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
http_client: httpx.AsyncClient | None = None,
Expand All @@ -51,6 +55,8 @@ def gateway_provider(
api_type: Literal['anthropic'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
http_client: httpx.AsyncClient | None = None,
Expand All @@ -62,6 +68,8 @@ def gateway_provider(
api_type: Literal['converse'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
) -> Provider[BaseClient]: ...
Expand All @@ -72,6 +80,8 @@ def gateway_provider(
api_type: Literal['gemini'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
http_client: httpx.AsyncClient | None = None,
Expand All @@ -83,6 +93,8 @@ def gateway_provider(
api_type: str,
/,
*,
routing_group: str | None = None,
profile: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
) -> Provider[Any]: ...
Expand All @@ -96,6 +108,8 @@ def gateway_provider(
/,
*,
# Every provider
routing_group: str | None = None,
profile: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
# OpenAI, Groq, Anthropic & Gemini - Only Bedrock doesn't have an HTTPX client.
Expand All @@ -105,6 +119,10 @@ def gateway_provider(

Args:
api_type: Determines the API type to use.
routing_group: The group of APIs that support the same models - the idea is that you can route the requests to
any provider in a routing group. The `pydantic-ai-gateway-routing-group` header will be added.
profile: A provider may have a profile, which is a unique identifier for the provider.
The `pydantic-ai-gateway-profile` header will be added.
api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
environment variable will be used if available.
base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
Expand All @@ -122,6 +140,12 @@ def gateway_provider(
http_client = http_client or cached_async_http_client(provider=f'gateway/{api_type}')
http_client.event_hooks = {'request': [_request_hook(api_key)]}

if profile is not None:
http_client.headers.setdefault('pydantic-ai-gateway-profile', profile)

if routing_group is not None:
http_client.headers.setdefault('pydantic-ai-gateway-routing-group', routing_group)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Cached Client Headers: Stale Values Persist

Using setdefault to add profile and routing_group headers causes incorrect behavior when the cached HTTP client is reused. Since cached_async_http_client returns the same client instance for the same api_type, subsequent calls with different profile or routing_group values won't update the headers because setdefault only sets values for keys that don't exist. This means the second call will incorrectly use headers from the first call instead of the newly specified values.

Fix in Cursor Fix in Web

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair.


if api_type in ('chat', 'responses'):
from .openai import OpenAIProvider

Expand Down
12 changes: 12 additions & 0 deletions tests/providers/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,15 @@ async def test_model_provider_argument():

model = BedrockConverseModel('amazon.nova-micro-v1:0', provider='gateway')
assert GATEWAY_BASE_URL in model._provider.base_url # type: ignore[reportPrivateUsage]


async def test_gateway_provider_routing_group_header(gateway_api_key: str):
provider = gateway_provider('chat', routing_group='openai', api_key=gateway_api_key)
httpx_client = provider.client._client # type: ignore[reportPrivateUsage]
assert httpx_client.headers['pydantic-ai-gateway-routing-group'] == 'openai'


async def test_gateway_provider_profile_header(gateway_api_key: str):
provider = gateway_provider('chat', profile='openai', api_key=gateway_api_key)
httpx_client = provider.client._client # type: ignore[reportPrivateUsage]
assert httpx_client.headers['pydantic-ai-gateway-profile'] == 'openai'