Skip to content

Commit 99581d7

Browse files
Kludex and samuelcolvin authored
Add PydanticAI Gateway provider (#2816)
Co-authored-by: Samuel Colvin <[email protected]>
1 parent bee76e6 commit 99581d7

26 files changed

+830
-96
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,11 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
718718
)
719719
provider = 'google-vertex'
720720

721-
if provider == 'cohere':
721+
if provider == 'gateway':
722+
from ..providers.gateway import infer_model as infer_model_from_gateway
723+
724+
return infer_model_from_gateway(model_name)
725+
elif provider == 'cohere':
722726
from .cohere import CohereModel
723727

724728
return CohereModel(model_name, provider=provider)

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
from mistralai.models.usermessage import UserMessage as MistralUserMessage
8383
from mistralai.types.basemodel import Unset as MistralUnset
8484
from mistralai.utils.eventstreaming import EventStreamAsync as MistralEventStreamAsync
85-
except ImportError as e:
85+
except ImportError as e: # pragma: lax no cover
8686
raise ImportError(
8787
'Please install `mistral` to use the Mistral model, '
8888
'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`'

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def model_profile(self, model_name: str) -> ModelProfile | None:
4747
"""The model profile for the named model, if available."""
4848
return None # pragma: no cover
4949

50+
def __repr__(self) -> str:
    """Return a debugging representation showing the concrete class, provider name and base URL."""
    class_name = type(self).__name__
    return f'{class_name}(name={self.name}, base_url={self.base_url})'
52+
5053

5154
def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
5255
"""Infers the provider class from the provider name."""

pydantic_ai_slim/pydantic_ai/providers/anthropic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(
6868
assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
6969
self._client = anthropic_client
7070
else:
71-
api_key = api_key or os.environ.get('ANTHROPIC_API_KEY')
71+
api_key = api_key or os.getenv('ANTHROPIC_API_KEY')
7272
if not api_key:
7373
raise UserError(
7474
'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`'

pydantic_ai_slim/pydantic_ai/providers/cohere.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,14 @@ def __init__(
6060
assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
6161
self._client = cohere_client
6262
else:
63-
api_key = api_key or os.environ.get('CO_API_KEY')
63+
api_key = api_key or os.getenv('CO_API_KEY')
6464
if not api_key:
6565
raise UserError(
6666
'Set the `CO_API_KEY` environment variable or pass it via `CohereProvider(api_key=...)`'
6767
'to use the Cohere provider.'
6868
)
6969

70-
base_url = os.environ.get('CO_BASE_URL')
70+
base_url = os.getenv('CO_BASE_URL')
7171
if http_client is not None:
7272
self._client = AsyncClientV2(api_key=api_key, httpx_client=http_client, base_url=base_url)
7373
else:
pydantic_ai_slim/pydantic_ai/providers/gateway.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
"""This module implements the Pydantic AI Gateway provider."""
2+
3+
from __future__ import annotations as _annotations
4+
5+
import os
6+
from typing import TYPE_CHECKING, Any, Literal, overload
7+
from urllib.parse import urljoin
8+
9+
import httpx
10+
11+
from pydantic_ai.exceptions import UserError
12+
from pydantic_ai.models import Model, cached_async_http_client, get_user_agent
13+
14+
if TYPE_CHECKING:
15+
from google.genai import Client as GoogleClient
16+
from groq import AsyncGroq
17+
from openai import AsyncOpenAI
18+
19+
from pydantic_ai.providers import Provider
20+
21+
22+
@overload
def gateway_provider(
    upstream_provider: Literal['openai', 'openai-chat', 'openai-responses'],
    *,
    api_key: str | None = None,
    base_url: str | None = None,
    http_client: httpx.AsyncClient | None = None,
) -> Provider[AsyncOpenAI]: ...


@overload
def gateway_provider(
    upstream_provider: Literal['groq'],
    *,
    api_key: str | None = None,
    base_url: str | None = None,
    http_client: httpx.AsyncClient | None = None,
) -> Provider[AsyncGroq]: ...


@overload
def gateway_provider(
    upstream_provider: Literal['google-vertex'],
    *,
    api_key: str | None = None,
    base_url: str | None = None,
) -> Provider[GoogleClient]: ...


def gateway_provider(
    upstream_provider: Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex'] | str,
    *,
    # Every provider
    api_key: str | None = None,
    base_url: str | None = None,
    # OpenAI & Groq
    http_client: httpx.AsyncClient | None = None,
) -> Provider[Any]:
    """Create a new Gateway provider.

    Args:
        upstream_provider: The upstream provider to use.
        api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
            environment variable will be used if available.
        base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
            environment variable will be used if available. Otherwise, defaults to `http://localhost:8787/`.
            A trailing slash is optional; any path prefix in the URL is preserved.
        http_client: The HTTP client to use for the Gateway. The gateway request hook is added to the
            client's existing event hooks; hooks that are already installed are kept.

    Returns:
        A provider for `upstream_provider` whose requests are routed through the Gateway.

    Raises:
        UserError: If no API key is available, or if `upstream_provider` is not supported.
    """
    api_key = api_key or os.getenv('PYDANTIC_AI_GATEWAY_API_KEY')
    if not api_key:
        raise UserError(
            'Set the `PYDANTIC_AI_GATEWAY_API_KEY` environment variable or pass it via `gateway_provider(api_key=...)`'
            ' to use the Pydantic AI Gateway provider.'
        )

    # `or` (rather than a `getenv` default) so that an empty environment variable also falls back to the default.
    base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL') or 'http://localhost:8787'
    # Normalise to exactly one trailing slash: `urljoin` would otherwise replace the last path segment of
    # a base URL like `https://host/prefix`, and plain concatenation would yield `//` for `https://host/`.
    gateway_root = base_url.rstrip('/') + '/'

    http_client = http_client or cached_async_http_client(provider=f'gateway-{upstream_provider}')
    # Add our hook without discarding hooks the caller already installed, and without
    # registering it twice when the same (cached) client is reused.
    request_hooks = http_client.event_hooks['request']
    if _request_hook not in request_hooks:
        http_client.event_hooks['request'] = [*request_hooks, _request_hook]

    if upstream_provider in ('openai', 'openai-chat', 'openai-responses'):
        from .openai import OpenAIProvider

        return OpenAIProvider(api_key=api_key, base_url=urljoin(gateway_root, 'openai'), http_client=http_client)
    elif upstream_provider == 'groq':
        from .groq import GroqProvider

        return GroqProvider(api_key=api_key, base_url=urljoin(gateway_root, 'groq'), http_client=http_client)
    elif upstream_provider == 'google-vertex':
        from google.genai import Client as GoogleClient

        from .google import GoogleProvider

        return GoogleProvider(
            client=GoogleClient(
                vertexai=True,
                # The Gateway key is sent via the `Authorization` header below, so a placeholder is passed here.
                api_key='unset',
                http_options={
                    'base_url': urljoin(gateway_root, 'google-vertex'),
                    'headers': {'User-Agent': get_user_agent(), 'Authorization': api_key},
                    # TODO(Marcelo): Until https://github.com/googleapis/python-genai/issues/1357 is solved.
                    'async_client_args': {
                        'transport': httpx.AsyncHTTPTransport(),
                        'event_hooks': {'request': [_request_hook]},
                    },
                },
            )
        )
    else:  # pragma: no cover
        raise UserError(f'Unknown provider: {upstream_provider}')
115+
116+
117+
def infer_model(model_name: str) -> Model:
    """Infer the model class that will be used to make requests to the gateway.

    Args:
        model_name: The name of the model to infer. Must be in the format "provider/model_name".
            Only the first `/` separates the provider; the rest is passed on as the model name.

    Returns:
        The model class that will be used to make requests to the gateway.

    Raises:
        UserError: If `model_name` contains no `/`, or if the upstream provider is not supported.
    """
    # `partition` splits on the first '/' only and lets us report a malformed name directly,
    # without a chained `ValueError` from tuple unpacking cluttering the traceback.
    upstream_provider, separator, upstream_model_name = model_name.partition('/')
    if not separator:
        raise UserError(f'The model name "{model_name}" is not in the format "provider/model_name".')

    if upstream_provider in ('openai', 'openai-chat'):
        from pydantic_ai.models.openai import OpenAIChatModel

        return OpenAIChatModel(upstream_model_name, provider=gateway_provider('openai'))
    elif upstream_provider == 'openai-responses':
        from pydantic_ai.models.openai import OpenAIResponsesModel

        return OpenAIResponsesModel(upstream_model_name, provider=gateway_provider('openai'))
    elif upstream_provider == 'groq':
        from pydantic_ai.models.groq import GroqModel

        return GroqModel(upstream_model_name, provider=gateway_provider('groq'))
    elif upstream_provider == 'google-vertex':
        from pydantic_ai.models.google import GoogleModel

        return GoogleModel(upstream_model_name, provider=gateway_provider('google-vertex'))
    raise UserError(f'Unknown upstream provider: {upstream_provider}')
148+
149+
150+
async def _request_hook(request: httpx.Request) -> httpx.Request:
    """Propagate the active OpenTelemetry context on an outgoing gateway request.

    The configured propagator writes its headers (such as `"traceparent"`) into a
    scratch carrier, which is then merged into the request's headers.

    Args:
        request: The outgoing request; its headers are updated in place.

    Returns:
        The same request object, after its headers have been updated.
    """
    # Imported lazily so that OpenTelemetry is only loaded once a request is actually sent.
    from opentelemetry.propagate import inject

    carrier: dict[str, Any] = {}
    inject(carrier)
    request.headers.update(carrier)
    return request

pydantic_ai_slim/pydantic_ai/providers/google.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,13 @@ def __init__(
106106
else:
107107
self._client = Client(
108108
vertexai=vertexai,
109-
project=project or os.environ.get('GOOGLE_CLOUD_PROJECT'),
109+
project=project or os.getenv('GOOGLE_CLOUD_PROJECT'),
110110
# From https://github.com/pydantic/pydantic-ai/pull/2031/files#r2169682149:
111111
# Currently `us-central1` supports the most models by far of any region including `global`, but not
112112
# all of them. `us-central1` has all google models but is missing some Anthropic partner models,
113113
# which use `us-east5` instead. `global` has fewer models but higher availability.
114114
# For more details, check: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
115-
location=location or os.environ.get('GOOGLE_CLOUD_LOCATION') or 'us-central1',
115+
location=location or os.getenv('GOOGLE_CLOUD_LOCATION') or 'us-central1',
116116
credentials=credentials,
117117
http_options=http_options,
118118
)

pydantic_ai_slim/pydantic_ai/providers/google_gla.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, api_key: str | None = None, http_client: httpx.AsyncClient |
3939
will be used if available.
4040
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
4141
"""
42-
api_key = api_key or os.environ.get('GEMINI_API_KEY')
42+
api_key = api_key or os.getenv('GEMINI_API_KEY')
4343
if not api_key:
4444
raise UserError(
4545
'Set the `GEMINI_API_KEY` environment variable or pass it via `GoogleGLAProvider(api_key=...)`'

pydantic_ai_slim/pydantic_ai/providers/groq.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def name(self) -> str:
5353

5454
@property
5555
def base_url(self) -> str:
56-
return os.environ.get('GROQ_BASE_URL', 'https://api.groq.com')
56+
return str(self.client.base_url)
5757

5858
@property
5959
def client(self) -> AsyncGroq:
@@ -85,12 +85,15 @@ def model_profile(self, model_name: str) -> ModelProfile | None:
8585
def __init__(self, *, groq_client: AsyncGroq | None = None) -> None: ...
8686

8787
@overload
88-
def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ...
88+
def __init__(
89+
self, *, api_key: str | None = None, base_url: str | None = None, http_client: httpx.AsyncClient | None = None
90+
) -> None: ...
8991

9092
def __init__(
9193
self,
9294
*,
9395
api_key: str | None = None,
96+
base_url: str | None = None,
9497
groq_client: AsyncGroq | None = None,
9598
http_client: httpx.AsyncClient | None = None,
9699
) -> None:
@@ -99,6 +102,8 @@ def __init__(
99102
Args:
100103
api_key: The API key to use for authentication, if not provided, the `GROQ_API_KEY` environment variable
101104
will be used if available.
105+
base_url: The base url for the Groq requests. If not provided, the `GROQ_BASE_URL` environment variable
106+
will be used if available. Otherwise, defaults to Groq's base url.
102107
groq_client: An existing
103108
[`AsyncGroq`](https://github.com/groq/groq-python?tab=readme-ov-file#async-usage)
104109
client to use. If provided, `api_key` and `http_client` must be `None`.
@@ -107,17 +112,19 @@ def __init__(
107112
if groq_client is not None:
108113
assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
109114
assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
115+
assert base_url is None, 'Cannot provide both `groq_client` and `base_url`'
110116
self._client = groq_client
111117
else:
112-
api_key = api_key or os.environ.get('GROQ_API_KEY')
118+
api_key = api_key or os.getenv('GROQ_API_KEY')
119+
base_url = base_url or os.getenv('GROQ_BASE_URL', 'https://api.groq.com')
113120

114121
if not api_key:
115122
raise UserError(
116123
'Set the `GROQ_API_KEY` environment variable or pass it via `GroqProvider(api_key=...)`'
117124
'to use the Groq provider.'
118125
)
119126
elif http_client is not None:
120-
self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=http_client)
127+
self._client = AsyncGroq(base_url=base_url, api_key=api_key, http_client=http_client)
121128
else:
122129
http_client = cached_async_http_client(provider='groq')
123-
self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=http_client)
130+
self._client = AsyncGroq(base_url=base_url, api_key=api_key, http_client=http_client)

pydantic_ai_slim/pydantic_ai/providers/heroku.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@ def __init__(
6565
assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
6666
self._client = openai_client
6767
else:
68-
api_key = api_key or os.environ.get('HEROKU_INFERENCE_KEY')
68+
api_key = api_key or os.getenv('HEROKU_INFERENCE_KEY')
6969
if not api_key:
7070
raise UserError(
7171
'Set the `HEROKU_INFERENCE_KEY` environment variable or pass it via `HerokuProvider(api_key=...)`'
7272
'to use the Heroku provider.'
7373
)
7474

75-
base_url = base_url or os.environ.get('HEROKU_INFERENCE_URL', 'https://us.inference.heroku.com')
75+
base_url = base_url or os.getenv('HEROKU_INFERENCE_URL', 'https://us.inference.heroku.com')
7676
base_url = base_url.rstrip('/') + '/v1'
7777

7878
if http_client is not None:

0 commit comments

Comments
 (0)