Commit f5cc9e3

Use gateway/<upstream_provider>: as provider name prefix for Gateway (#3229)
1 parent 59a7c70 commit f5cc9e3

13 files changed, +218 -154 lines

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 33 additions & 27 deletions
@@ -43,6 +43,7 @@
 )
 from ..output import OutputMode
 from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
+from ..providers import infer_provider
 from ..settings import ModelSettings, merge_model_settings
 from ..tools import ToolDefinition
 from ..usage import RequestUsage
@@ -637,41 +638,39 @@ def infer_model(model: Model | KnownModelName | str) -> Model:  # noqa: C901
         return TestModel()

     try:
-        provider, model_name = model.split(':', maxsplit=1)
+        provider_name, model_name = model.split(':', maxsplit=1)
     except ValueError:
-        provider = None
+        provider_name = None
         model_name = model
         if model_name.startswith(('gpt', 'o1', 'o3')):
-            provider = 'openai'
+            provider_name = 'openai'
         elif model_name.startswith('claude'):
-            provider = 'anthropic'
+            provider_name = 'anthropic'
         elif model_name.startswith('gemini'):
-            provider = 'google-gla'
+            provider_name = 'google-gla'

-        if provider is not None:
+        if provider_name is not None:
             warnings.warn(
-                f"Specifying a model name without a provider prefix is deprecated. Instead of {model_name!r}, use '{provider}:{model_name}'.",
+                f"Specifying a model name without a provider prefix is deprecated. Instead of {model_name!r}, use '{provider_name}:{model_name}'.",
                 DeprecationWarning,
             )
         else:
             raise UserError(f'Unknown model: {model}')

-    if provider == 'vertexai':  # pragma: no cover
+    if provider_name == 'vertexai':  # pragma: no cover
         warnings.warn(
             "The 'vertexai' provider name is deprecated. Use 'google-vertex' instead.",
             DeprecationWarning,
         )
-        provider = 'google-vertex'
+        provider_name = 'google-vertex'

-    if provider == 'gateway':
-        from ..providers.gateway import infer_model as infer_model_from_gateway
+    provider = infer_provider(provider_name)

-        return infer_model_from_gateway(model_name)
-    elif provider == 'cohere':
-        from .cohere import CohereModel
-
-        return CohereModel(model_name, provider=provider)
-    elif provider in (
+    model_kind = provider_name
+    if model_kind.startswith('gateway/'):
+        model_kind = provider_name.removeprefix('gateway/')
+    if model_kind in (
+        'openai',
         'azure',
         'deepseek',
         'cerebras',
@@ -681,43 +680,50 @@ def infer_model(model: Model | KnownModelName | str) -> Model:  # noqa: C901
         'heroku',
         'moonshotai',
         'ollama',
-        'openai',
-        'openai-chat',
         'openrouter',
         'together',
         'vercel',
         'litellm',
         'nebius',
         'ovhcloud',
     ):
+        model_kind = 'openai-chat'
+    elif model_kind in ('google-gla', 'google-vertex'):
+        model_kind = 'google'
+
+    if model_kind == 'openai-chat':
         from .openai import OpenAIChatModel

         return OpenAIChatModel(model_name, provider=provider)
-    elif provider == 'openai-responses':
+    elif model_kind == 'openai-responses':
         from .openai import OpenAIResponsesModel

-        return OpenAIResponsesModel(model_name, provider='openai')
-    elif provider in ('google-gla', 'google-vertex'):
+        return OpenAIResponsesModel(model_name, provider=provider)
+    elif model_kind == 'google':
         from .google import GoogleModel

         return GoogleModel(model_name, provider=provider)
-    elif provider == 'groq':
+    elif model_kind == 'groq':
         from .groq import GroqModel

         return GroqModel(model_name, provider=provider)
-    elif provider == 'mistral':
+    elif model_kind == 'cohere':
+        from .cohere import CohereModel
+
+        return CohereModel(model_name, provider=provider)
+    elif model_kind == 'mistral':
         from .mistral import MistralModel

         return MistralModel(model_name, provider=provider)
-    elif provider == 'anthropic':
+    elif model_kind == 'anthropic':
         from .anthropic import AnthropicModel

         return AnthropicModel(model_name, provider=provider)
-    elif provider == 'bedrock':
+    elif model_kind == 'bedrock':
         from .bedrock import BedrockConverseModel

         return BedrockConverseModel(model_name, provider=provider)
-    elif provider == 'huggingface':
+    elif model_kind == 'huggingface':
         from .huggingface import HuggingFaceModel

         return HuggingFaceModel(model_name, provider=provider)
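
A minimal usage sketch of the reworked routing, assuming the relevant API keys and gateway credentials are configured in the environment (model names below are just examples): infer_model now resolves the provider prefix with infer_provider, including the new 'gateway/<upstream_provider>' form, and then dispatches on the upstream model kind.

from pydantic_ai.models import infer_model

# Direct providers resolve exactly as before.
chat_model = infer_model('openai:gpt-4o')  # OpenAIChatModel with the regular OpenAIProvider

# Gateway-routed models reuse the same model classes, but the provider name is
# 'gateway/<upstream_provider>', so infer_provider returns a gateway-backed provider.
gateway_openai = infer_model('gateway/openai:gpt-4o')  # OpenAIChatModel via gateway_provider('openai')
gateway_groq = infer_model('gateway/groq:llama-3.3-70b-versatile')  # GroqModel via gateway_provider('groq')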

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 2 additions & 2 deletions
@@ -162,7 +162,7 @@ def __init__(
         self,
         model_name: AnthropicModelName,
         *,
-        provider: Literal['anthropic'] | Provider[AsyncAnthropicClient] = 'anthropic',
+        provider: Literal['anthropic', 'gateway'] | Provider[AsyncAnthropicClient] = 'anthropic',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -179,7 +179,7 @@ def __init__(
         self._model_name = model_name

         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/anthropic' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client
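
The 'gateway' string now maps to 'gateway/anthropic' before infer_provider is called; bedrock.py, groq.py and openai.py below apply the same pattern for their own upstream names. A hedged sketch of the call site (model name is a placeholder, gateway credentials assumed to be configured):

from pydantic_ai.models.anthropic import AnthropicModel

direct = AnthropicModel('claude-3-5-sonnet-latest', provider='anthropic')  # AnthropicProvider, as before
routed = AnthropicModel('claude-3-5-sonnet-latest', provider='gateway')  # provider from infer_provider('gateway/anthropic')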

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 2 additions & 2 deletions
@@ -207,7 +207,7 @@ def __init__(
         self,
         model_name: BedrockModelName,
         *,
-        provider: Literal['bedrock'] | Provider[BaseClient] = 'bedrock',
+        provider: Literal['bedrock', 'gateway'] | Provider[BaseClient] = 'bedrock',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -226,7 +226,7 @@ def __init__(
         self._model_name = model_name

         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/bedrock' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = cast('BedrockRuntimeClient', provider.client)

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 9 additions & 2 deletions
@@ -38,7 +38,7 @@
     VideoUrl,
 )
 from ..profiles import ModelProfileSpec
-from ..providers import Provider, infer_provider
+from ..providers import Provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent
@@ -131,7 +131,14 @@ def __init__(
         self._model_name = model_name

         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            if provider == 'google-gla':
+                from pydantic_ai.providers.google_gla import GoogleGLAProvider  # type: ignore[reportDeprecated]
+
+                provider = GoogleGLAProvider()  # type: ignore[reportDeprecated]
+            else:
+                from pydantic_ai.providers.google_vertex import GoogleVertexProvider  # type: ignore[reportDeprecated]
+
+                provider = GoogleVertexProvider()  # type: ignore[reportDeprecated]
         self._provider = provider
         self.client = provider.client
         self._url = str(self.client.base_url)
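
Because infer_provider('google-gla') and infer_provider('google-vertex') now return the google-genai based GoogleProvider (see providers/__init__.py below), the legacy GeminiModel builds the deprecated httpx-based providers itself so its behaviour is unchanged. A hedged sketch, with a placeholder model name:

from pydantic_ai.models.gemini import GeminiModel

# Still backed by the deprecated GoogleGLAProvider / GoogleVertexProvider,
# not the GoogleProvider that infer_provider() would now return.
legacy_gla = GeminiModel('gemini-1.5-flash', provider='google-gla')
legacy_vertex = GeminiModel('gemini-1.5-flash', provider='google-vertex')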

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 5 additions & 7 deletions
@@ -37,7 +37,7 @@
     VideoUrl,
 )
 from ..profiles import ModelProfileSpec
-from ..providers import Provider
+from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -85,8 +85,6 @@
         UrlContextDict,
         VideoMetadataDict,
     )
-
-    from ..providers.google import GoogleProvider
 except ImportError as _import_error:
     raise ImportError(
         'Please install `google-genai` to use the Google model, '
@@ -187,7 +185,7 @@ def __init__(
         self,
         model_name: GoogleModelName,
         *,
-        provider: Literal['google-gla', 'google-vertex'] | Provider[Client] = 'google-gla',
+        provider: Literal['google-gla', 'google-vertex', 'gateway'] | Provider[Client] = 'google-gla',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -196,15 +194,15 @@ def __init__(
         Args:
             model_name: The name of the model to use.
             provider: The provider to use for authentication and API access. Can be either the string
-                'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
-                If not provided, a new provider will be created using the other parameters.
+                'google-gla' or 'google-vertex' or an instance of `Provider[google.genai.AsyncClient]`.
+                Defaults to 'google-gla'.
             profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
             settings: The model settings to use. Defaults to None.
         """
         self._model_name = model_name

         if isinstance(provider, str):
-            provider = GoogleProvider(vertexai=provider == 'google-vertex')
+            provider = infer_provider('gateway/google-vertex' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client
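
GoogleModel now routes every string provider name through infer_provider, with 'gateway' mapped to 'gateway/google-vertex'. A hedged sketch, assuming credentials are configured and using a placeholder model name:

from pydantic_ai.models.google import GoogleModel

gla = GoogleModel('gemini-2.0-flash', provider='google-gla')  # GoogleProvider(vertexai=False)
vertex = GoogleModel('gemini-2.0-flash', provider='google-vertex')  # GoogleProvider(vertexai=True)
routed = GoogleModel('gemini-2.0-flash', provider='gateway')  # provider from gateway_provider('google-vertex')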

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 2 additions & 2 deletions
@@ -141,7 +141,7 @@ def __init__(
         self,
         model_name: GroqModelName,
         *,
-        provider: Literal['groq'] | Provider[AsyncGroq] = 'groq',
+        provider: Literal['groq', 'gateway'] | Provider[AsyncGroq] = 'groq',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
     ):
@@ -159,7 +159,7 @@ def __init__(
         self._model_name = model_name

         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/groq' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 15 additions & 3 deletions
@@ -286,6 +286,7 @@ def __init__(
             'litellm',
             'nebius',
             'ovhcloud',
+            'gateway',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -316,6 +317,7 @@ def __init__(
             'litellm',
             'nebius',
             'ovhcloud',
+            'gateway',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -345,6 +347,7 @@ def __init__(
             'litellm',
             'nebius',
             'ovhcloud',
+            'gateway',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -366,7 +369,7 @@ def __init__(
         self._model_name = model_name

         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client

@@ -907,7 +910,16 @@ def __init__(
         model_name: OpenAIModelName,
         *,
         provider: Literal[
-            'openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'nebius', 'ovhcloud'
+            'openai',
+            'deepseek',
+            'azure',
+            'openrouter',
+            'grok',
+            'fireworks',
+            'together',
+            'nebius',
+            'ovhcloud',
+            'gateway',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -924,7 +936,7 @@ def __init__(
         self._model_name = model_name

         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client
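
Both OpenAI model classes accept the 'gateway' provider string and map it to 'gateway/openai'. A hedged sketch (placeholder model name, gateway credentials assumed to be configured):

from pydantic_ai.models.openai import OpenAIChatModel, OpenAIResponsesModel

chat = OpenAIChatModel('gpt-4o', provider='gateway')  # infer_provider('gateway/openai')
responses = OpenAIResponsesModel('gpt-4o', provider='gateway')  # same mapping for the Responses API model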

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 17 additions & 12 deletions
@@ -8,7 +8,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Generic, TypeVar

-from pydantic_ai import ModelProfile
+from ..profiles import ModelProfile

 InterfaceClient = TypeVar('InterfaceClient')

@@ -53,7 +53,7 @@ def __repr__(self) -> str:

 def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
     """Infers the provider class from the provider name."""
-    if provider == 'openai':
+    if provider in ('openai', 'openai-chat', 'openai-responses'):
         from .openai import OpenAIProvider

         return OpenAIProvider
@@ -73,15 +73,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .azure import AzureProvider

         return AzureProvider
-    elif provider == 'google-vertex':
-        from .google_vertex import GoogleVertexProvider  # type: ignore[reportDeprecated]
+    elif provider in ('google-vertex', 'google-gla'):
+        from .google import GoogleProvider

-        return GoogleVertexProvider  # type: ignore[reportDeprecated]
-    elif provider == 'google-gla':
-        from .google_gla import GoogleGLAProvider  # type: ignore[reportDeprecated]
-
-        return GoogleGLAProvider  # type: ignore[reportDeprecated]
-    # NOTE: We don't test because there are many ways the `boto3.client` can retrieve the credentials.
+        return GoogleProvider
     elif provider == 'bedrock':
         from .bedrock import BedrockProvider

@@ -156,5 +151,15 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901

 def infer_provider(provider: str) -> Provider[Any]:
     """Infer the provider from the provider name."""
-    provider_class = infer_provider_class(provider)
-    return provider_class()
+    if provider.startswith('gateway/'):
+        from .gateway import gateway_provider
+
+        provider = provider.removeprefix('gateway/')
+        return gateway_provider(provider)
+    elif provider in ('google-vertex', 'google-gla'):
+        from .google import GoogleProvider
+
+        return GoogleProvider(vertexai=provider == 'google-vertex')
+    else:
+        provider_class = infer_provider_class(provider)
+        return provider_class()
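
How the reworked infer_provider resolves the three name shapes, as a hedged sketch (gateway and Google credentials assumed to be configured):

from pydantic_ai.providers import infer_provider

plain = infer_provider('groq')  # infer_provider_class('groq')() -> GroqProvider
google = infer_provider('google-vertex')  # GoogleProvider(vertexai=True)
via_gateway = infer_provider('gateway/anthropic')  # gateway_provider('anthropic')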
