Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,9 +780,10 @@ def infer_model( # noqa: C901

model_kind = provider_name
if model_kind.startswith('gateway/'):
from ..providers.gateway import infer_gateway_model
from ..providers.gateway import normalize_gateway_provider

return infer_gateway_model(model_kind.removeprefix('gateway/'), model_name=model_name)
model_kind = provider_name.removeprefix('gateway/')
model_kind = normalize_gateway_provider(model_kind)
if model_kind in (
'openai',
'azure',
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def __init__(
self._model_name = model_name

if isinstance(provider, str):
provider = infer_provider('gateway/converse' if provider == 'gateway' else provider)
provider = infer_provider('gateway/bedrock' if provider == 'gateway' else provider)
self._provider = provider
self.client = cast('BedrockRuntimeClient', provider.client)

Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def __init__(
self._model_name = model_name

if isinstance(provider, str):
provider = infer_provider('gateway/gemini' if provider == 'gateway' else provider)
provider = infer_provider('gateway/google-vertex' if provider == 'gateway' else provider)
self._provider = provider
self.client = provider.client

Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def __init__(
self._model_name = model_name

if isinstance(provider, str):
provider = infer_provider('gateway/chat' if provider == 'gateway' else provider)
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
self._provider = provider
self.client = provider.client

Expand Down Expand Up @@ -944,7 +944,7 @@ def __init__(
self._model_name = model_name

if isinstance(provider, str):
provider = infer_provider('gateway/responses' if provider == 'gateway' else provider)
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
self._provider = provider
self.client = provider.client

Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def infer_provider(provider: str) -> Provider[Any]:
if provider.startswith('gateway/'):
from .gateway import gateway_provider

api_type = provider.removeprefix('gateway/')
return gateway_provider(api_type)
upstream_provider = provider.removeprefix('gateway/')
return gateway_provider(upstream_provider)
elif provider in ('google-vertex', 'google-gla'):
from .google import GoogleProvider

Expand Down
141 changes: 61 additions & 80 deletions pydantic_ai_slim/pydantic_ai/providers/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from groq import AsyncGroq
from openai import AsyncOpenAI

from pydantic_ai.models import Model
from pydantic_ai.models.anthropic import AsyncAnthropicClient
from pydantic_ai.providers import Provider

Expand All @@ -26,11 +25,10 @@

@overload
def gateway_provider(
api_type: Literal['chat', 'responses'],
upstream_provider: Literal['openai', 'openai-chat', 'openai-responses', 'chat', 'responses'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
route: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
http_client: httpx.AsyncClient | None = None,
Expand All @@ -39,11 +37,10 @@ def gateway_provider(

@overload
def gateway_provider(
api_type: Literal['groq'],
upstream_provider: Literal['groq'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
route: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
http_client: httpx.AsyncClient | None = None,
Expand All @@ -52,11 +49,10 @@ def gateway_provider(

@overload
def gateway_provider(
api_type: Literal['anthropic'],
upstream_provider: Literal['anthropic'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
route: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
http_client: httpx.AsyncClient | None = None,
Expand All @@ -65,23 +61,21 @@ def gateway_provider(

@overload
def gateway_provider(
api_type: Literal['converse'],
upstream_provider: Literal['bedrock', 'converse'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
route: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
) -> Provider[BaseClient]: ...


@overload
def gateway_provider(
api_type: Literal['gemini'],
upstream_provider: Literal['gemini', 'google-vertex'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
route: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
http_client: httpx.AsyncClient | None = None,
Expand All @@ -90,26 +84,37 @@ def gateway_provider(

@overload
def gateway_provider(
api_type: str,
upstream_provider: str,
/,
*,
routing_group: str | None = None,
profile: str | None = None,
route: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
) -> Provider[Any]: ...


APIType = Literal['chat', 'responses', 'gemini', 'converse', 'anthropic', 'groq']
UpstreamProvider = Literal[
'openai',
'groq',
'anthropic',
'bedrock',
'google-vertex',
# Those are only API formats, but we still support them for convenience.
'openai-chat',
'openai-responses',
'chat',
'responses',
'converse',
'gemini',
]


def gateway_provider(
api_type: APIType | str,
upstream_provider: UpstreamProvider | str,
/,
*,
# Every provider
routing_group: str | None = None,
profile: str | None = None,
route: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
# OpenAI, Groq, Anthropic & Gemini - Only Bedrock doesn't have an HTTPX client.
Expand All @@ -118,11 +123,9 @@ def gateway_provider(
"""Create a new Gateway provider.

Args:
api_type: Determines the API type to use.
routing_group: The group of APIs that support the same models - the idea is that you can route the requests to
any provider in a routing group. The `pydantic-ai-gateway-routing-group` header will be added.
profile: A provider may have a profile, which is a unique identifier for the provider.
The `pydantic-ai-gateway-profile` header will be added.
upstream_provider: The upstream provider to use.
route: The name of the provider or routing group to use to handle the request. If not provided, the default
routing group for the API format will be used.
api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
environment variable will be used if available.
base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
Expand All @@ -137,54 +140,45 @@ def gateway_provider(
)

base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', GATEWAY_BASE_URL)
http_client = http_client or cached_async_http_client(provider=f'gateway/{api_type}')
http_client = http_client or cached_async_http_client(provider=f'gateway/{upstream_provider}')
http_client.event_hooks = {'request': [_request_hook(api_key)]}

if profile is not None:
http_client.headers.setdefault('pydantic-ai-gateway-profile', profile)
if route is None:
# Use the implied providerId as the default route.
route = normalize_gateway_provider(upstream_provider)

if routing_group is not None:
http_client.headers.setdefault('pydantic-ai-gateway-routing-group', routing_group)
base_url = _merge_url_path(base_url, route)

if api_type in ('chat', 'responses'):
if upstream_provider in ('openai', 'openai-chat', 'openai-responses', 'chat', 'responses'):
from .openai import OpenAIProvider

return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, api_type), http_client=http_client)
elif api_type == 'groq':
return OpenAIProvider(api_key=api_key, base_url=base_url, http_client=http_client)
elif upstream_provider == 'groq':
from .groq import GroqProvider

return GroqProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'groq'), http_client=http_client)
elif api_type == 'anthropic':
return GroqProvider(api_key=api_key, base_url=base_url, http_client=http_client)
elif upstream_provider == 'anthropic':
from anthropic import AsyncAnthropic

from .anthropic import AnthropicProvider

return AnthropicProvider(
anthropic_client=AsyncAnthropic(
auth_token=api_key,
base_url=_merge_url_path(base_url, 'anthropic'),
http_client=http_client,
)
anthropic_client=AsyncAnthropic(auth_token=api_key, base_url=base_url, http_client=http_client)
)
elif api_type == 'converse':
elif upstream_provider in ('bedrock', 'converse'):
from .bedrock import BedrockProvider

return BedrockProvider(
api_key=api_key,
base_url=_merge_url_path(base_url, api_type),
base_url=base_url,
region_name='pydantic-ai-gateway', # Fake region name to avoid NoRegionError
)
elif api_type == 'gemini':
elif upstream_provider in ('google-vertex', 'gemini'):
from .google import GoogleProvider

return GoogleProvider(
vertexai=True,
api_key=api_key,
base_url=_merge_url_path(base_url, 'gemini'),
http_client=http_client,
)
return GoogleProvider(vertexai=True, api_key=api_key, base_url=base_url, http_client=http_client)
else:
raise UserError(f'Unknown API type: {api_type}')
raise UserError(f'Unknown upstream provider: {upstream_provider}')


def _request_hook(api_key: str) -> Callable[[httpx.Request], Awaitable[httpx.Request]]:
Expand Down Expand Up @@ -218,31 +212,18 @@ def _merge_url_path(base_url: str, path: str) -> str:
return base_url.rstrip('/') + '/' + path.lstrip('/')


def infer_gateway_model(api_type: APIType | str, *, model_name: str) -> Model:
"""Infer the model class for a given API type."""
if api_type == 'chat':
from pydantic_ai.models.openai import OpenAIChatModel

return OpenAIChatModel(model_name=model_name, provider='gateway')
elif api_type == 'groq':
from pydantic_ai.models.groq import GroqModel

return GroqModel(model_name=model_name, provider='gateway')
elif api_type == 'responses':
from pydantic_ai.models.openai import OpenAIResponsesModel

return OpenAIResponsesModel(model_name=model_name, provider='gateway')
elif api_type == 'gemini':
from pydantic_ai.models.google import GoogleModel
def normalize_gateway_provider(provider: str) -> str:
"""Normalize a gateway provider name.

return GoogleModel(model_name=model_name, provider='gateway')
elif api_type == 'converse':
from pydantic_ai.models.bedrock import BedrockConverseModel

return BedrockConverseModel(model_name=model_name, provider='gateway')
elif api_type == 'anthropic':
from pydantic_ai.models.anthropic import AnthropicModel

return AnthropicModel(model_name=model_name, provider='gateway')
else:
raise ValueError(f'Unknown API type: {api_type}') # pragma: no cover
Args:
provider: The provider name to normalize.
"""
if provider in ('openai', 'openai-chat', 'chat'):
return 'openai'
elif provider in ('openai-responses', 'responses'):
return 'openai-responses'
elif provider in ('gemini', 'google-vertex'):
return 'google-vertex'
elif provider in ('bedrock', 'converse'):
return 'bedrock'
return provider
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ interactions:
connection:
- keep-alive
content-length:
- '166'
- '159'
content-type:
- application/json
host:
Expand All @@ -32,6 +32,8 @@ interactions:
- application/json
pydantic-ai-gateway-price-estimate:
- 0.0002USD
retry-after:
- '34'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
transfer-encoding:
Expand All @@ -40,7 +42,7 @@ interactions:
content:
- text: The capital of France is Paris.
type: text
id: msg_0116L5r52AZ42YhvvdUuHEsk
id: msg_015jjU4Q5dqhSc9vyfCdoujx
model: claude-sonnet-4-5-20250929
role: assistant
stop_reason: end_turn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ interactions:
headers:
amz-sdk-invocation-id:
- !!binary |
MmEzMzkzMGUtNzI3YS00YzFhLWFmYWQtYzFhYWMyMTI3NDlj
MTU4ODY4OTctOGU4MC00YzJlLWEyZTctMDA2ZmM0NTZjMmYy
amz-sdk-request:
- !!binary |
YXR0ZW1wdD0x
Expand All @@ -15,35 +15,34 @@ interactions:
- !!binary |
YXBwbGljYXRpb24vanNvbg==
method: POST
uri: http://localhost:8787/converse/model/amazon.nova-micro-v1%3A0/converse
uri: http://localhost:8787/bedrock/model/amazon.nova-micro-v1%3A0/converse
response:
headers:
content-length:
- '741'
- '631'
content-type:
- application/json
pydantic-ai-gateway-price-estimate:
- 0.0000USD
parsed_body:
metrics:
latencyMs: 682
latencyMs: 717
output:
message:
content:
- text: The capital of France is Paris. Paris is not only the capital city but also the most populous city in France,
and it is a major center for culture, commerce, fashion, and international diplomacy. The city is known for
its historical and architectural landmarks, including the Eiffel Tower, the Louvre Museum, Notre-Dame Cathedral,
and the Champs-Élysées. Paris plays a significant role in the global arts, fashion, research, technology, education,
and entertainment scenes.
its historical landmarks, such as the Eiffel Tower, the Louvre Museum, Notre-Dame Cathedral, and the Champs-Élysées,
among many other attractions.
role: assistant
stopReason: end_turn
usage:
inputTokens: 7
outputTokens: 96
outputTokens: 78
pydantic_ai_gateway:
cost_estimate: 1.3685000000000002e-05
cost_estimate: 1.1165000000000002e-05
serverToolUsage: {}
totalTokens: 103
totalTokens: 85
status:
code: 200
message: OK
Expand Down
Loading
Loading