4 changes: 1 addition & 3 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -773,9 +773,7 @@ def infer_model( # noqa: C901

model_kind = provider_name
if model_kind.startswith('gateway/'):
from ..providers.gateway import infer_gateway_model

return infer_gateway_model(model_kind.removeprefix('gateway/'), model_name=model_name)
model_kind = provider_name.removeprefix('gateway/')
if model_kind in (
'openai',
'azure',
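The net effect of this hunk: a `gateway/`-prefixed provider name now falls through to the standard provider branches of `infer_model` instead of delegating to the removed `infer_gateway_model` helper. A minimal sketch of the resulting dispatch (the helper name `resolve_provider_kind` is illustrative only, not part of the library):

```python
# Simplified sketch of the new flow inside infer_model; not the library's real code.
def resolve_provider_kind(provider_name: str) -> str:
    model_kind = provider_name
    if model_kind.startswith('gateway/'):
        # Previously this delegated to infer_gateway_model(); now the prefix is
        # stripped and the usual per-provider branches handle the model.
        model_kind = provider_name.removeprefix('gateway/')
    return model_kind

assert resolve_provider_kind('gateway/openai') == 'openai'
assert resolve_provider_kind('groq') == 'groq'
```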
106 changes: 33 additions & 73 deletions pydantic_ai_slim/pydantic_ai/providers/gateway.py
@@ -17,7 +17,6 @@
from groq import AsyncGroq
from openai import AsyncOpenAI

from pydantic_ai.models import Model
from pydantic_ai.models.anthropic import AsyncAnthropicClient
from pydantic_ai.providers import Provider

@@ -26,11 +25,10 @@

@overload
def gateway_provider(
api_type: Literal['chat', 'responses'],
upstream_provider: Literal['openai', 'openai-chat', 'openai-responses'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
route: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
http_client: httpx.AsyncClient | None = None,
@@ -39,11 +37,10 @@ def gateway_provider(

@overload
def gateway_provider(
api_type: Literal['groq'],
upstream_provider: Literal['groq'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
route: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
http_client: httpx.AsyncClient | None = None,
@@ -52,11 +49,10 @@ def gateway_provider(

@overload
def gateway_provider(
api_type: Literal['anthropic'],
upstream_provider: Literal['anthropic'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
route: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
http_client: httpx.AsyncClient | None = None,
@@ -65,23 +61,21 @@ def gateway_provider(

@overload
def gateway_provider(
api_type: Literal['converse'],
upstream_provider: Literal['bedrock'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
route: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
) -> Provider[BaseClient]: ...


@overload
def gateway_provider(
api_type: Literal['gemini'],
upstream_provider: Literal['gemini'],
/,
*,
routing_group: str | None = None,
profile: str | None = None,
route: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
http_client: httpx.AsyncClient | None = None,
@@ -90,26 +84,24 @@ def gateway_provider(

@overload
def gateway_provider(
api_type: str,
upstream_provider: str,
/,
*,
routing_group: str | None = None,
profile: str | None = None,
route: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
) -> Provider[Any]: ...


APIType = Literal['chat', 'responses', 'gemini', 'converse', 'anthropic', 'groq']
UpstreamProvider = Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'anthropic', 'bedrock', 'google-vertex']


def gateway_provider(
api_type: APIType | str,
upstream_provider: UpstreamProvider | str,
/,
*,
# Every provider
routing_group: str | None = None,
profile: str | None = None,
route: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
# OpenAI, Groq, Anthropic & Gemini - Only Bedrock doesn't have an HTTPX client.
@@ -118,11 +110,8 @@ def gateway_provider(
"""Create a new Gateway provider.

Args:
api_type: Determines the API type to use.
routing_group: The group of APIs that support the same models - the idea is that you can route the requests to
any provider in a routing group. The `pydantic-ai-gateway-routing-group` header will be added.
profile: A provider may have a profile, which is a unique identifier for the provider.
The `pydantic-ai-gateway-profile` header will be added.
upstream_provider: The upstream provider to use.
route: The route to use. If provided, the `pydantic-ai-gateway-route` header will be added.
api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
environment variable will be used if available.
base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
@@ -137,24 +126,25 @@
)

base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', GATEWAY_BASE_URL)
http_client = http_client or cached_async_http_client(provider=f'gateway/{api_type}')
http_client = http_client or cached_async_http_client(provider=f'gateway/{upstream_provider}')
http_client.event_hooks = {'request': [_request_hook(api_key)]}

if profile is not None:
http_client.headers.setdefault('pydantic-ai-gateway-profile', profile)
if route is not None:
http_client.headers.setdefault('pydantic-ai-gateway-route', route)

if routing_group is not None:
http_client.headers.setdefault('pydantic-ai-gateway-routing-group', routing_group)

if api_type in ('chat', 'responses'):
if upstream_provider in ('openai', 'openai-chat', 'openai-responses'):
from .openai import OpenAIProvider

return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, api_type), http_client=http_client)
elif api_type == 'groq':
return OpenAIProvider(
api_key=api_key,
base_url=_merge_url_path(base_url, upstream_provider),
http_client=http_client,
)
elif upstream_provider == 'groq':
from .groq import GroqProvider

return GroqProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'groq'), http_client=http_client)
elif api_type == 'anthropic':
elif upstream_provider == 'anthropic':
from anthropic import AsyncAnthropic

from .anthropic import AnthropicProvider
@@ -166,25 +156,25 @@
http_client=http_client,
)
)
elif api_type == 'converse':
elif upstream_provider == 'bedrock':
from .bedrock import BedrockProvider

return BedrockProvider(
api_key=api_key,
base_url=_merge_url_path(base_url, api_type),
base_url=_merge_url_path(base_url, upstream_provider),
region_name='pydantic-ai-gateway', # Fake region name to avoid NoRegionError
)
elif api_type == 'gemini':
elif upstream_provider == 'google-vertex':
from .google import GoogleProvider

return GoogleProvider(
vertexai=True,
api_key=api_key,
base_url=_merge_url_path(base_url, 'gemini'),
base_url=_merge_url_path(base_url, upstream_provider),
http_client=http_client,
)
else:
raise UserError(f'Unknown API type: {api_type}')
raise UserError(f'Unknown upstream provider: {upstream_provider}')
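For reviewers, a hedged usage sketch of the renamed signature (`upstream_provider` positional, `route` keyword); the route value and API key below are placeholders, and `PYDANTIC_AI_GATEWAY_API_KEY`/`PYDANTIC_AI_GATEWAY_BASE_URL` can be used instead of passing them explicitly:

```python
from pydantic_ai.providers.gateway import gateway_provider

# 'openai', 'openai-chat' and 'openai-responses' all resolve to an OpenAIProvider
# pointed at the gateway; `route` sets the 'pydantic-ai-gateway-route' header.
provider = gateway_provider(
    'openai',
    route='my-route',                # placeholder route name
    api_key='your-gateway-api-key',  # placeholder; the env var also works
)
```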


def _request_hook(api_key: str) -> Callable[[httpx.Request], Awaitable[httpx.Request]]:
@@ -216,33 +206,3 @@ def _merge_url_path(base_url: str, path: str) -> str:
path: The path to merge.
"""
return base_url.rstrip('/') + '/' + path.lstrip('/')
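A quick sanity check of the slash handling (URLs are illustrative; the helper is copied here so the snippet is self-contained):

```python
def _merge_url_path(base_url: str, path: str) -> str:
    # Copied from the diff above: join the two parts with exactly one slash.
    return base_url.rstrip('/') + '/' + path.lstrip('/')

assert _merge_url_path('https://gateway.example.com/', '/groq') == 'https://gateway.example.com/groq'
assert _merge_url_path('https://gateway.example.com', 'openai') == 'https://gateway.example.com/openai'
```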


def infer_gateway_model(api_type: APIType | str, *, model_name: str) -> Model:
"""Infer the model class for a given API type."""
if api_type == 'chat':
from pydantic_ai.models.openai import OpenAIChatModel

return OpenAIChatModel(model_name=model_name, provider='gateway')
elif api_type == 'groq':
from pydantic_ai.models.groq import GroqModel

return GroqModel(model_name=model_name, provider='gateway')
elif api_type == 'responses':
from pydantic_ai.models.openai import OpenAIResponsesModel

return OpenAIResponsesModel(model_name=model_name, provider='gateway')
elif api_type == 'gemini':
from pydantic_ai.models.google import GoogleModel

return GoogleModel(model_name=model_name, provider='gateway')
elif api_type == 'converse':
from pydantic_ai.models.bedrock import BedrockConverseModel

return BedrockConverseModel(model_name=model_name, provider='gateway')
elif api_type == 'anthropic':
from pydantic_ai.models.anthropic import AnthropicModel

return AnthropicModel(model_name=model_name, provider='gateway')
else:
raise ValueError(f'Unknown API type: {api_type}') # pragma: no cover
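With `infer_gateway_model` removed, `gateway/`-prefixed models are expected to resolve through the regular `infer_model` branches. A hedged sketch of the equivalent explicit construction (the model name is a placeholder; whether the old `provider='gateway'` string form still applies depends on the rest of the PR):

```python
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.gateway import gateway_provider

# Roughly what a 'gateway/openai:<model>' string now maps to via the standard
# OpenAI branch, with the gateway provider supplied explicitly.
model = OpenAIChatModel('gpt-4o', provider=gateway_provider('openai'))
```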