Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,9 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901

model_kind = provider_name
if model_kind.startswith('gateway/'):
model_kind = provider_name.removeprefix('gateway/')
from ..providers.gateway import infer_gateway_model

return infer_gateway_model(model_kind.removeprefix('gateway/'), model_name=model_name)
if model_kind in (
'openai',
'azure',
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def __init__(
self._model_name = model_name

if isinstance(provider, str):
provider = infer_provider('gateway/bedrock' if provider == 'gateway' else provider)
provider = infer_provider('gateway/converse' if provider == 'gateway' else provider)
self._provider = provider
self.client = cast('BedrockRuntimeClient', provider.client)

Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def __init__(
self._model_name = model_name

if isinstance(provider, str):
provider = infer_provider('gateway/google-vertex' if provider == 'gateway' else provider)
provider = infer_provider('gateway/gemini' if provider == 'gateway' else provider)
self._provider = provider
self.client = provider.client

Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def __init__(
self._model_name = model_name

if isinstance(provider, str):
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
provider = infer_provider('gateway/chat' if provider == 'gateway' else provider)
self._provider = provider
self.client = provider.client

Expand Down Expand Up @@ -944,7 +944,7 @@ def __init__(
self._model_name = model_name

if isinstance(provider, str):
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
provider = infer_provider('gateway/responses' if provider == 'gateway' else provider)
self._provider = provider
self.client = provider.client

Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def infer_provider(provider: str) -> Provider[Any]:
if provider.startswith('gateway/'):
from .gateway import gateway_provider

provider = provider.removeprefix('gateway/')
return gateway_provider(provider)
api_type = provider.removeprefix('gateway/')
return gateway_provider(api_type)
elif provider in ('google-vertex', 'google-gla'):
from .google import GoogleProvider

Expand Down
87 changes: 63 additions & 24 deletions pydantic_ai_slim/pydantic_ai/providers/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from groq import AsyncGroq
from openai import AsyncOpenAI

from pydantic_ai.models import Model
from pydantic_ai.models.anthropic import AsyncAnthropicClient
from pydantic_ai.providers import Provider

Expand All @@ -25,7 +26,8 @@

@overload
def gateway_provider(
upstream_provider: Literal['openai', 'openai-chat', 'openai-responses'],
api_type: Literal['chat', 'responses'],
/,
*,
api_key: str | None = None,
base_url: str | None = None,
Expand All @@ -35,66 +37,73 @@ def gateway_provider(

@overload
def gateway_provider(
upstream_provider: Literal['groq'],
api_key: Literal['groq'],
/,
*,
api_key: str | None = None,
base_url: str | None = None,
http_client: httpx.AsyncClient | None = None,
) -> Provider[AsyncGroq]: ...


@overload
def gateway_provider(
upstream_provider: Literal['google-vertex'],
api_type: Literal['anthropic'],
/,
*,
api_key: str | None = None,
base_url: str | None = None,
) -> Provider[GoogleClient]: ...
http_client: httpx.AsyncClient | None = None,
) -> Provider[AsyncAnthropicClient]: ...


@overload
def gateway_provider(
upstream_provider: Literal['anthropic'],
api_type: Literal['converse'],
/,
*,
api_key: str | None = None,
base_url: str | None = None,
) -> Provider[AsyncAnthropicClient]: ...
) -> Provider[BaseClient]: ...


@overload
def gateway_provider(
upstream_provider: Literal['bedrock'],
api_type: Literal['gemini'],
/,
*,
api_key: str | None = None,
base_url: str | None = None,
) -> Provider[BaseClient]: ...
http_client: httpx.AsyncClient | None = None,
) -> Provider[GoogleClient]: ...


@overload
def gateway_provider(
upstream_provider: str,
api_type: str,
/,
*,
api_key: str | None = None,
base_url: str | None = None,
) -> Provider[Any]: ...


UpstreamProvider = Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic', 'bedrock']
APIType = Literal['chat', 'responses', 'gemini', 'converse', 'anthropic', 'groq']


def gateway_provider(
upstream_provider: UpstreamProvider | str,
api_type: APIType | str,
/,
*,
# Every provider
api_key: str | None = None,
base_url: str | None = None,
# OpenAI, Groq & Anthropic
# OpenAI, Groq, Anthropic & Gemini - Only Bedrock doesn't have an HTTPX client.
http_client: httpx.AsyncClient | None = None,
) -> Provider[Any]:
"""Create a new Gateway provider.

Args:
upstream_provider: The upstream provider to use.
api_type: Determines the API type to use.
api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
environment variable will be used if available.
base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
Expand All @@ -109,18 +118,18 @@ def gateway_provider(
)

base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', GATEWAY_BASE_URL)
http_client = http_client or cached_async_http_client(provider=f'gateway/{upstream_provider}')
http_client = http_client or cached_async_http_client(provider=f'gateway/{api_type}')
http_client.event_hooks = {'request': [_request_hook(api_key)]}

if upstream_provider in ('openai', 'openai-chat', 'openai-responses'):
if api_type in ('chat', 'responses'):
from .openai import OpenAIProvider

return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'openai'), http_client=http_client)
elif upstream_provider == 'groq':
return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, api_type), http_client=http_client)
elif api_type == 'groq':
from .groq import GroqProvider

return GroqProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'groq'), http_client=http_client)
elif upstream_provider == 'anthropic':
elif api_type == 'anthropic':
from anthropic import AsyncAnthropic

from .anthropic import AnthropicProvider
Expand All @@ -132,25 +141,25 @@ def gateway_provider(
http_client=http_client,
)
)
elif upstream_provider == 'bedrock':
elif api_type in 'converse':
from .bedrock import BedrockProvider

return BedrockProvider(
api_key=api_key,
base_url=_merge_url_path(base_url, 'bedrock'),
base_url=_merge_url_path(base_url, api_type),
region_name='pydantic-ai-gateway', # Fake region name to avoid NoRegionError
)
elif upstream_provider == 'google-vertex':
elif api_type in 'gemini':
from .google import GoogleProvider

return GoogleProvider(
vertexai=True,
api_key=api_key,
base_url=_merge_url_path(base_url, 'google-vertex'),
base_url=_merge_url_path(base_url, 'gemini'),
http_client=http_client,
)
else:
raise UserError(f'Unknown upstream provider: {upstream_provider}')
raise UserError(f'Unknown API type: {api_type}')


def _request_hook(api_key: str) -> Callable[[httpx.Request], Awaitable[httpx.Request]]:
Expand Down Expand Up @@ -182,3 +191,33 @@ def _merge_url_path(base_url: str, path: str) -> str:
path: The path to merge.
"""
return base_url.rstrip('/') + '/' + path.lstrip('/')


def infer_gateway_model(api_type: APIType | str, *, model_name: str) -> Model:
"""Infer the model class for a given API type."""
if api_type == 'chat':
from pydantic_ai.models.openai import OpenAIChatModel

return OpenAIChatModel(model_name=model_name, provider='gateway')
elif api_type == 'groq':
from pydantic_ai.models.groq import GroqModel

return GroqModel(model_name=model_name, provider='gateway')
elif api_type == 'responses':
from pydantic_ai.models.openai import OpenAIResponsesModel

return OpenAIResponsesModel(model_name=model_name, provider='gateway')
elif api_type == 'gemini':
from pydantic_ai.models.google import GoogleModel

return GoogleModel(model_name=model_name, provider='gateway')
elif api_type == 'converse':
from pydantic_ai.models.bedrock import BedrockConverseModel

return BedrockConverseModel(model_name=model_name, provider='gateway')
elif api_type == 'anthropic':
from pydantic_ai.models.anthropic import AnthropicModel

return AnthropicModel(model_name=model_name, provider='gateway')
else:
raise ValueError(f'Unknown API type: {api_type}') # pragma: no cover
Original file line number Diff line number Diff line change
Expand Up @@ -134,31 +134,31 @@ interactions:
parsed_body:
data:
- created: 0
id: qwen-3-32b
id: qwen-3-235b-a22b-thinking-2507
object: model
owned_by: Cerebras
- created: 0
id: qwen-3-235b-a22b-instruct-2507
id: llama-3.3-70b
object: model
owned_by: Cerebras
- created: 0
id: gpt-oss-120b
id: qwen-3-235b-a22b-instruct-2507
object: model
owned_by: Cerebras
- created: 0
id: zai-glm-4.6
id: qwen-3-32b
object: model
owned_by: Cerebras
- created: 0
id: llama3.1-8b
id: zai-glm-4.6
object: model
owned_by: Cerebras
- created: 0
id: llama-3.3-70b
id: gpt-oss-120b
object: model
owned_by: Cerebras
- created: 0
id: qwen-3-235b-a22b-thinking-2507
id: llama3.1-8b
object: model
owned_by: Cerebras
object: list
Expand Down
25 changes: 8 additions & 17 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,21 @@
TEST_CASES = [
pytest.param(
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
'gateway/openai:gpt-5',
'gateway/chat:gpt-5',
'gpt-5',
'openai',
'openai',
OpenAIChatModel,
id='gateway/openai:gpt-5',
id='gateway/chat:gpt-5',
),
pytest.param(
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
'gateway/openai-chat:gpt-5',
'gpt-5',
'openai',
'openai',
OpenAIChatModel,
id='gateway/openai-chat:gpt-5',
),
pytest.param(
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
'gateway/openai-responses:gpt-5',
'gateway/responses:gpt-5',
'gpt-5',
'openai',
'openai',
OpenAIResponsesModel,
id='gateway/openai-responses:gpt-5',
id='gateway/responses:gpt-5',
),
pytest.param(
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
Expand All @@ -65,12 +56,12 @@
),
pytest.param(
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
'gateway/google-vertex:gemini-1.5-flash',
'gateway/gemini:gemini-1.5-flash',
'gemini-1.5-flash',
'google-vertex',
'google',
GoogleModel,
id='gateway/google-vertex:gemini-1.5-flash',
id='gateway/gemini:gemini-1.5-flash',
),
pytest.param(
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
Expand All @@ -83,12 +74,12 @@
),
pytest.param(
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
'gateway/bedrock:amazon.nova-micro-v1:0',
'gateway/converse:amazon.nova-micro-v1:0',
'amazon.nova-micro-v1:0',
'bedrock',
'bedrock',
BedrockConverseModel,
id='gateway/bedrock:amazon.nova-micro-v1:0',
id='gateway/converse:amazon.nova-micro-v1:0',
),
pytest.param(
{'OPENAI_API_KEY': 'openai-api-key'},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ interactions:
headers:
amz-sdk-invocation-id:
- !!binary |
MWYwNDlkMTQtMjVmMC00YTRhLWJhYmMtNTQ0MDdhMmRlNjgw
MmEzMzkzMGUtNzI3YS00YzFhLWFmYWQtYzFhYWMyMTI3NDlj
amz-sdk-request:
- !!binary |
YXR0ZW1wdD0x
Expand All @@ -15,7 +15,7 @@ interactions:
- !!binary |
YXBwbGljYXRpb24vanNvbg==
method: POST
uri: http://localhost:8787/bedrock/model/amazon.nova-micro-v1%3A0/converse
uri: http://localhost:8787/converse/model/amazon.nova-micro-v1%3A0/converse
response:
headers:
content-length:
Expand All @@ -26,7 +26,7 @@ interactions:
- 0.0000USD
parsed_body:
metrics:
latencyMs: 668
latencyMs: 682
output:
message:
content:
Expand Down
Loading