Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,9 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901

model_kind = provider_name
if model_kind.startswith('gateway/'):
model_kind = provider_name.removeprefix('gateway/')
from ..providers.gateway import infer_gateway_model

return infer_gateway_model(model_kind.removeprefix('gateway/'), model_name=model_name)
if model_kind in (
'openai',
'azure',
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def __init__(
self._model_name = model_name

if isinstance(provider, str):
provider = infer_provider('gateway/bedrock' if provider == 'gateway' else provider)
provider = infer_provider('gateway/converse' if provider == 'gateway' else provider)
self._provider = provider
self.client = cast('BedrockRuntimeClient', provider.client)

Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def __init__(
self._model_name = model_name

if isinstance(provider, str):
provider = infer_provider('gateway/google-vertex' if provider == 'gateway' else provider)
provider = infer_provider('gateway/gemini' if provider == 'gateway' else provider)
self._provider = provider
self.client = provider.client

Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def __init__(
self._model_name = model_name

if isinstance(provider, str):
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
provider = infer_provider('gateway/chat' if provider == 'gateway' else provider)
self._provider = provider
self.client = provider.client

Expand Down Expand Up @@ -944,7 +944,7 @@ def __init__(
self._model_name = model_name

if isinstance(provider, str):
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
provider = infer_provider('gateway/responses' if provider == 'gateway' else provider)
self._provider = provider
self.client = provider.client

Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def infer_provider(provider: str) -> Provider[Any]:
if provider.startswith('gateway/'):
from .gateway import gateway_provider

provider = provider.removeprefix('gateway/')
return gateway_provider(provider)
api_type = provider.removeprefix('gateway/')
return gateway_provider(api_type)
elif provider in ('google-vertex', 'google-gla'):
from .google import GoogleProvider

Expand Down
86 changes: 63 additions & 23 deletions pydantic_ai_slim/pydantic_ai/providers/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from groq import AsyncGroq
from openai import AsyncOpenAI

from pydantic_ai.models import Model
from pydantic_ai.models.anthropic import AsyncAnthropicClient
from pydantic_ai.providers import Provider

Expand All @@ -25,7 +26,8 @@

@overload
def gateway_provider(
upstream_provider: Literal['openai', 'openai-chat', 'openai-responses'],
api_type: Literal['chat', 'responses'],
/,
*,
api_key: str | None = None,
base_url: str | None = None,
Expand All @@ -35,7 +37,8 @@ def gateway_provider(

@overload
def gateway_provider(
upstream_provider: Literal['groq'],
api_type: Literal['groq'],
/,
*,
api_key: str | None = None,
base_url: str | None = None,
Expand All @@ -45,56 +48,63 @@ def gateway_provider(

@overload
def gateway_provider(
upstream_provider: Literal['google-vertex'],
api_type: Literal['anthropic'],
/,
*,
api_key: str | None = None,
base_url: str | None = None,
) -> Provider[GoogleClient]: ...
http_client: httpx.AsyncClient | None = None,
) -> Provider[AsyncAnthropicClient]: ...


@overload
def gateway_provider(
upstream_provider: Literal['anthropic'],
api_type: Literal['converse'],
/,
*,
api_key: str | None = None,
base_url: str | None = None,
) -> Provider[AsyncAnthropicClient]: ...
) -> Provider[BaseClient]: ...


@overload
def gateway_provider(
upstream_provider: Literal['bedrock'],
api_type: Literal['gemini'],
/,
*,
api_key: str | None = None,
base_url: str | None = None,
) -> Provider[BaseClient]: ...
http_client: httpx.AsyncClient | None = None,
) -> Provider[GoogleClient]: ...


@overload
def gateway_provider(
upstream_provider: str,
api_type: str,
/,
*,
api_key: str | None = None,
base_url: str | None = None,
) -> Provider[Any]: ...


UpstreamProvider = Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic', 'bedrock']
APIType = Literal['chat', 'responses', 'gemini', 'converse', 'anthropic', 'groq']


def gateway_provider(
upstream_provider: UpstreamProvider | str,
api_type: APIType | str,
/,
*,
# Every provider
api_key: str | None = None,
base_url: str | None = None,
# OpenAI, Groq & Anthropic
# OpenAI, Groq, Anthropic & Gemini - Only Bedrock doesn't have an HTTPX client.
http_client: httpx.AsyncClient | None = None,
) -> Provider[Any]:
"""Create a new Gateway provider.

Args:
upstream_provider: The upstream provider to use.
api_type: Determines the API type to use.
api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
environment variable will be used if available.
base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
Expand All @@ -109,18 +119,18 @@ def gateway_provider(
)

base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', GATEWAY_BASE_URL)
http_client = http_client or cached_async_http_client(provider=f'gateway/{upstream_provider}')
http_client = http_client or cached_async_http_client(provider=f'gateway/{api_type}')
http_client.event_hooks = {'request': [_request_hook(api_key)]}

if upstream_provider in ('openai', 'openai-chat', 'openai-responses'):
if api_type in ('chat', 'responses'):
from .openai import OpenAIProvider

return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'openai'), http_client=http_client)
elif upstream_provider == 'groq':
return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, api_type), http_client=http_client)
elif api_type == 'groq':
from .groq import GroqProvider

return GroqProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'groq'), http_client=http_client)
elif upstream_provider == 'anthropic':
elif api_type == 'anthropic':
from anthropic import AsyncAnthropic

from .anthropic import AnthropicProvider
Expand All @@ -132,25 +142,25 @@ def gateway_provider(
http_client=http_client,
)
)
elif upstream_provider == 'bedrock':
elif api_type == 'converse':
from .bedrock import BedrockProvider

return BedrockProvider(
api_key=api_key,
base_url=_merge_url_path(base_url, 'bedrock'),
base_url=_merge_url_path(base_url, api_type),
region_name='pydantic-ai-gateway', # Fake region name to avoid NoRegionError
)
elif upstream_provider == 'google-vertex':
elif api_type == 'gemini':
from .google import GoogleProvider

return GoogleProvider(
vertexai=True,
api_key=api_key,
base_url=_merge_url_path(base_url, 'google-vertex'),
base_url=_merge_url_path(base_url, 'gemini'),
http_client=http_client,
)
else:
raise UserError(f'Unknown upstream provider: {upstream_provider}')
raise UserError(f'Unknown API type: {api_type}')


def _request_hook(api_key: str) -> Callable[[httpx.Request], Awaitable[httpx.Request]]:
Expand Down Expand Up @@ -182,3 +192,33 @@ def _merge_url_path(base_url: str, path: str) -> str:
path: The path to merge.
"""
return base_url.rstrip('/') + '/' + path.lstrip('/')


def infer_gateway_model(api_type: APIType | str, *, model_name: str) -> Model:
    """Infer the model class for a given API type.

    Args:
        api_type: The gateway API type to dispatch on
            ('chat', 'responses', 'groq', 'gemini', 'converse' or 'anthropic').
        model_name: The name of the model to instantiate.

    Returns:
        A model instance configured with the `'gateway'` provider.

    Raises:
        ValueError: If `api_type` is not one of the known API types.
    """

    # Each loader defers its import so only the selected model's
    # dependencies are actually loaded.
    def _chat() -> Model:
        from pydantic_ai.models.openai import OpenAIChatModel

        return OpenAIChatModel(model_name=model_name, provider='gateway')

    def _groq() -> Model:
        from pydantic_ai.models.groq import GroqModel

        return GroqModel(model_name=model_name, provider='gateway')

    def _responses() -> Model:
        from pydantic_ai.models.openai import OpenAIResponsesModel

        return OpenAIResponsesModel(model_name=model_name, provider='gateway')

    def _gemini() -> Model:
        from pydantic_ai.models.google import GoogleModel

        return GoogleModel(model_name=model_name, provider='gateway')

    def _converse() -> Model:
        from pydantic_ai.models.bedrock import BedrockConverseModel

        return BedrockConverseModel(model_name=model_name, provider='gateway')

    def _anthropic() -> Model:
        from pydantic_ai.models.anthropic import AnthropicModel

        return AnthropicModel(model_name=model_name, provider='gateway')

    loaders: dict[str, Callable[[], Model]] = {
        'chat': _chat,
        'groq': _groq,
        'responses': _responses,
        'gemini': _gemini,
        'converse': _converse,
        'anthropic': _anthropic,
    }
    loader = loaders.get(api_type)
    if loader is None:
        raise ValueError(f'Unknown API type: {api_type}')  # pragma: no cover
    return loader()
Original file line number Diff line number Diff line change
Expand Up @@ -134,31 +134,31 @@ interactions:
parsed_body:
data:
- created: 0
id: qwen-3-32b
id: qwen-3-235b-a22b-thinking-2507
object: model
owned_by: Cerebras
- created: 0
id: qwen-3-235b-a22b-instruct-2507
id: llama-3.3-70b
object: model
owned_by: Cerebras
- created: 0
id: gpt-oss-120b
id: qwen-3-235b-a22b-instruct-2507
object: model
owned_by: Cerebras
- created: 0
id: zai-glm-4.6
id: qwen-3-32b
object: model
owned_by: Cerebras
- created: 0
id: llama3.1-8b
id: zai-glm-4.6
object: model
owned_by: Cerebras
- created: 0
id: llama-3.3-70b
id: gpt-oss-120b
object: model
owned_by: Cerebras
- created: 0
id: qwen-3-235b-a22b-thinking-2507
id: llama3.1-8b
object: model
owned_by: Cerebras
object: list
Expand Down
25 changes: 8 additions & 17 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,21 @@
TEST_CASES = [
pytest.param(
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
'gateway/openai:gpt-5',
'gateway/chat:gpt-5',
'gpt-5',
'openai',
'openai',
OpenAIChatModel,
id='gateway/openai:gpt-5',
id='gateway/chat:gpt-5',
),
pytest.param(
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
'gateway/openai-chat:gpt-5',
'gpt-5',
'openai',
'openai',
OpenAIChatModel,
id='gateway/openai-chat:gpt-5',
),
pytest.param(
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
'gateway/openai-responses:gpt-5',
'gateway/responses:gpt-5',
'gpt-5',
'openai',
'openai',
OpenAIResponsesModel,
id='gateway/openai-responses:gpt-5',
id='gateway/responses:gpt-5',
),
pytest.param(
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
Expand All @@ -65,12 +56,12 @@
),
pytest.param(
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
'gateway/google-vertex:gemini-1.5-flash',
'gateway/gemini:gemini-1.5-flash',
'gemini-1.5-flash',
'google-vertex',
'google',
GoogleModel,
id='gateway/google-vertex:gemini-1.5-flash',
id='gateway/gemini:gemini-1.5-flash',
),
pytest.param(
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
Expand All @@ -83,12 +74,12 @@
),
pytest.param(
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
'gateway/bedrock:amazon.nova-micro-v1:0',
'gateway/converse:amazon.nova-micro-v1:0',
'amazon.nova-micro-v1:0',
'bedrock',
'bedrock',
BedrockConverseModel,
id='gateway/bedrock:amazon.nova-micro-v1:0',
id='gateway/converse:amazon.nova-micro-v1:0',
),
pytest.param(
{'OPENAI_API_KEY': 'openai-api-key'},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ interactions:
headers:
amz-sdk-invocation-id:
- !!binary |
MWYwNDlkMTQtMjVmMC00YTRhLWJhYmMtNTQ0MDdhMmRlNjgw
MmEzMzkzMGUtNzI3YS00YzFhLWFmYWQtYzFhYWMyMTI3NDlj
amz-sdk-request:
- !!binary |
YXR0ZW1wdD0x
Expand All @@ -15,7 +15,7 @@ interactions:
- !!binary |
YXBwbGljYXRpb24vanNvbg==
method: POST
uri: http://localhost:8787/bedrock/model/amazon.nova-micro-v1%3A0/converse
uri: http://localhost:8787/converse/model/amazon.nova-micro-v1%3A0/converse
response:
headers:
content-length:
Expand All @@ -26,7 +26,7 @@ interactions:
- 0.0000USD
parsed_body:
metrics:
latencyMs: 668
latencyMs: 682
output:
message:
content:
Expand Down
Loading