diff --git a/pydantic_ai_slim/pydantic_ai/providers/gateway.py b/pydantic_ai_slim/pydantic_ai/providers/gateway.py index 7f28aeb606..588c982209 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/gateway.py +++ b/pydantic_ai_slim/pydantic_ai/providers/gateway.py @@ -29,6 +29,8 @@ def gateway_provider( api_type: Literal['chat', 'responses'], /, *, + routing_group: str | None = None, + profile: str | None = None, api_key: str | None = None, base_url: str | None = None, http_client: httpx.AsyncClient | None = None, @@ -40,6 +42,8 @@ def gateway_provider( api_type: Literal['groq'], /, *, + routing_group: str | None = None, + profile: str | None = None, api_key: str | None = None, base_url: str | None = None, http_client: httpx.AsyncClient | None = None, @@ -51,6 +55,8 @@ def gateway_provider( api_type: Literal['anthropic'], /, *, + routing_group: str | None = None, + profile: str | None = None, api_key: str | None = None, base_url: str | None = None, http_client: httpx.AsyncClient | None = None, @@ -62,6 +68,8 @@ def gateway_provider( api_type: Literal['converse'], /, *, + routing_group: str | None = None, + profile: str | None = None, api_key: str | None = None, base_url: str | None = None, ) -> Provider[BaseClient]: ... @@ -72,6 +80,8 @@ def gateway_provider( api_type: Literal['gemini'], /, *, + routing_group: str | None = None, + profile: str | None = None, api_key: str | None = None, base_url: str | None = None, http_client: httpx.AsyncClient | None = None, @@ -83,6 +93,8 @@ def gateway_provider( api_type: str, /, *, + routing_group: str | None = None, + profile: str | None = None, api_key: str | None = None, base_url: str | None = None, ) -> Provider[Any]: ... @@ -96,6 +108,8 @@ def gateway_provider( /, *, # Every provider + routing_group: str | None = None, + profile: str | None = None, api_key: str | None = None, base_url: str | None = None, # OpenAI, Groq, Anthropic & Gemini - Only Bedrock doesn't have an HTTPX client. @@ -105,6 +119,10 @@ def gateway_provider( Args: api_type: Determines the API type to use. + routing_group: The group of APIs that support the same models - the idea is that you can route the requests to + any provider in a routing group. The `pydantic-ai-gateway-routing-group` header will be added. + profile: A provider may have a profile, which is a unique identifier for the provider. + The `pydantic-ai-gateway-profile` header will be added. api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY` environment variable will be used if available. base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL` @@ -122,6 +140,12 @@ def gateway_provider( http_client = http_client or cached_async_http_client(provider=f'gateway/{api_type}') http_client.event_hooks = {'request': [_request_hook(api_key)]} + if profile is not None: + http_client.headers.setdefault('pydantic-ai-gateway-profile', profile) + + if routing_group is not None: + http_client.headers.setdefault('pydantic-ai-gateway-routing-group', routing_group) + if api_type in ('chat', 'responses'): from .openai import OpenAIProvider diff --git a/tests/providers/test_gateway.py b/tests/providers/test_gateway.py index 6b3b16001d..76fbdf5a00 100644 --- a/tests/providers/test_gateway.py +++ b/tests/providers/test_gateway.py @@ -180,3 +180,15 @@ async def test_model_provider_argument(): model = BedrockConverseModel('amazon.nova-micro-v1:0', provider='gateway') assert GATEWAY_BASE_URL in model._provider.base_url # type: ignore[reportPrivateUsage] + + +async def test_gateway_provider_routing_group_header(gateway_api_key: str): + provider = gateway_provider('chat', routing_group='openai', api_key=gateway_api_key) + httpx_client = provider.client._client # type: ignore[reportPrivateUsage] + assert httpx_client.headers['pydantic-ai-gateway-routing-group'] == 'openai' + + +async def test_gateway_provider_profile_header(gateway_api_key: str): + provider = gateway_provider('chat', profile='openai', api_key=gateway_api_key) + httpx_client = provider.client._client # type: ignore[reportPrivateUsage] + assert httpx_client.headers['pydantic-ai-gateway-profile'] == 'openai'