1717 from groq import AsyncGroq
1818 from openai import AsyncOpenAI
1919
20+ from pydantic_ai .models import Model
2021 from pydantic_ai .models .anthropic import AsyncAnthropicClient
2122 from pydantic_ai .providers import Provider
2223
2526
2627@overload
2728def gateway_provider (
28- upstream_provider : Literal ['openai' , 'openai-chat' , 'openai-responses' ],
29+ api_type : Literal ['chat' , 'responses' ],
30+ / ,
2931 * ,
32+ routing_group : str | None = None ,
33+ profile : str | None = None ,
3034 api_key : str | None = None ,
3135 base_url : str | None = None ,
3236 http_client : httpx .AsyncClient | None = None ,
@@ -35,8 +39,11 @@ def gateway_provider(
3539
3640@overload
3741def gateway_provider (
38- upstream_provider : Literal ['groq' ],
42+ api_type : Literal ['groq' ],
43+ / ,
3944 * ,
45+ routing_group : str | None = None ,
46+ profile : str | None = None ,
4047 api_key : str | None = None ,
4148 base_url : str | None = None ,
4249 http_client : httpx .AsyncClient | None = None ,
@@ -45,56 +52,77 @@ def gateway_provider(
4552
4653@overload
4754def gateway_provider (
48- upstream_provider : Literal ['google-vertex' ],
55+ api_type : Literal ['anthropic' ],
56+ / ,
4957 * ,
58+ routing_group : str | None = None ,
59+ profile : str | None = None ,
5060 api_key : str | None = None ,
5161 base_url : str | None = None ,
52- ) -> Provider [GoogleClient ]: ...
62+ http_client : httpx .AsyncClient | None = None ,
63+ ) -> Provider [AsyncAnthropicClient ]: ...
5364
5465
5566@overload
5667def gateway_provider (
57- upstream_provider : Literal ['anthropic' ],
68+ api_type : Literal ['converse' ],
69+ / ,
5870 * ,
71+ routing_group : str | None = None ,
72+ profile : str | None = None ,
5973 api_key : str | None = None ,
6074 base_url : str | None = None ,
61- ) -> Provider [AsyncAnthropicClient ]: ...
75+ ) -> Provider [BaseClient ]: ...
6276
6377
6478@overload
6579def gateway_provider (
66- upstream_provider : Literal ['bedrock' ],
80+ api_type : Literal ['gemini' ],
81+ / ,
6782 * ,
83+ routing_group : str | None = None ,
84+ profile : str | None = None ,
6885 api_key : str | None = None ,
6986 base_url : str | None = None ,
70- ) -> Provider [BaseClient ]: ...
87+ http_client : httpx .AsyncClient | None = None ,
88+ ) -> Provider [GoogleClient ]: ...
7189
7290
7391@overload
7492def gateway_provider (
75- upstream_provider : str ,
93+ api_type : str ,
94+ / ,
7695 * ,
96+ routing_group : str | None = None ,
97+ profile : str | None = None ,
7798 api_key : str | None = None ,
7899 base_url : str | None = None ,
79100) -> Provider [Any ]: ...
80101
81102
# The API types that `gateway_provider` and `infer_gateway_model` branch on.
# Values must match the string comparisons in those functions exactly
# (no surrounding whitespace), otherwise a literal would never be matched.
APIType = Literal['chat', 'responses', 'gemini', 'converse', 'anthropic', 'groq']
83104
84105
85106def gateway_provider (
86- upstream_provider : UpstreamProvider | str ,
107+ api_type : APIType | str ,
108+ / ,
87109 * ,
88110 # Every provider
111+ routing_group : str | None = None ,
112+ profile : str | None = None ,
89113 api_key : str | None = None ,
90114 base_url : str | None = None ,
91- # OpenAI, Groq & Anthropic
115+ # OpenAI, Groq, Anthropic & Gemini - Only Bedrock doesn't have an HTTPX client.
92116 http_client : httpx .AsyncClient | None = None ,
93117) -> Provider [Any ]:
94118 """Create a new Gateway provider.
95119
96120 Args:
97- upstream_provider: The upstream provider to use.
121+ api_type: Determines the API type to use.
122+ routing_group: The group of APIs that support the same models - the idea is that you can route the requests to
123+ any provider in a routing group. The `pydantic-ai-gateway-routing-group` header will be added.
124+ profile: A provider may have a profile, which is a unique identifier for the provider.
125+ The `pydantic-ai-gateway-profile` header will be added.
98126 api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
99127 environment variable will be used if available.
100128 base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
@@ -109,18 +137,24 @@ def gateway_provider(
109137 )
110138
111139 base_url = base_url or os .getenv ('PYDANTIC_AI_GATEWAY_BASE_URL' , GATEWAY_BASE_URL )
112- http_client = http_client or cached_async_http_client (provider = f'gateway/{ upstream_provider } ' )
140+ http_client = http_client or cached_async_http_client (provider = f'gateway/{ api_type } ' )
113141 http_client .event_hooks = {'request' : [_request_hook (api_key )]}
114142
115- if upstream_provider in ('openai' , 'openai-chat' , 'openai-responses' ):
143+ if profile is not None :
144+ http_client .headers .setdefault ('pydantic-ai-gateway-profile' , profile )
145+
146+ if routing_group is not None :
147+ http_client .headers .setdefault ('pydantic-ai-gateway-routing-group' , routing_group )
148+
149+ if api_type in ('chat' , 'responses' ):
116150 from .openai import OpenAIProvider
117151
118- return OpenAIProvider (api_key = api_key , base_url = _merge_url_path (base_url , 'openai' ), http_client = http_client )
119- elif upstream_provider == 'groq' :
152+ return OpenAIProvider (api_key = api_key , base_url = _merge_url_path (base_url , api_type ), http_client = http_client )
153+ elif api_type == 'groq' :
120154 from .groq import GroqProvider
121155
122156 return GroqProvider (api_key = api_key , base_url = _merge_url_path (base_url , 'groq' ), http_client = http_client )
123- elif upstream_provider == 'anthropic' :
157+ elif api_type == 'anthropic' :
124158 from anthropic import AsyncAnthropic
125159
126160 from .anthropic import AnthropicProvider
@@ -132,25 +166,25 @@ def gateway_provider(
132166 http_client = http_client ,
133167 )
134168 )
135- elif upstream_provider == 'bedrock ' :
169+ elif api_type == 'converse ' :
136170 from .bedrock import BedrockProvider
137171
138172 return BedrockProvider (
139173 api_key = api_key ,
140- base_url = _merge_url_path (base_url , 'bedrock' ),
174+ base_url = _merge_url_path (base_url , api_type ),
141175 region_name = 'pydantic-ai-gateway' , # Fake region name to avoid NoRegionError
142176 )
143- elif upstream_provider == 'google-vertex ' :
177+ elif api_type == 'gemini ' :
144178 from .google import GoogleProvider
145179
146180 return GoogleProvider (
147181 vertexai = True ,
148182 api_key = api_key ,
149- base_url = _merge_url_path (base_url , 'google-vertex ' ),
183+ base_url = _merge_url_path (base_url , 'gemini ' ),
150184 http_client = http_client ,
151185 )
152186 else :
153- raise UserError (f'Unknown upstream provider : { upstream_provider } ' )
187+ raise UserError (f'Unknown API type : { api_type } ' )
154188
155189
156190def _request_hook (api_key : str ) -> Callable [[httpx .Request ], Awaitable [httpx .Request ]]:
@@ -182,3 +216,33 @@ def _merge_url_path(base_url: str, path: str) -> str:
182216 path: The path to merge.
183217 """
184218 return base_url .rstrip ('/' ) + '/' + path .lstrip ('/' )
219+
220+
def infer_gateway_model(api_type: APIType | str, *, model_name: str) -> Model:
    """Build the model instance matching an API type, bound to the `'gateway'` provider.

    Each model class is imported inside the branch that uses it, so only the
    module for the selected API type is imported by a given call.

    Args:
        api_type: The API type to build a model for: `'chat'`, `'responses'`, `'gemini'`,
            `'converse'`, `'anthropic'` or `'groq'`.
        model_name: The model name passed through to the model class.

    Returns:
        An instance of the model class for `api_type`, constructed with `provider='gateway'`.

    Raises:
        ValueError: If `api_type` is not one of the supported values.
    """
    # OpenAI-compatible APIs: both classes live in the same module but are imported separately.
    if api_type == 'chat':
        from pydantic_ai.models.openai import OpenAIChatModel

        return OpenAIChatModel(model_name=model_name, provider='gateway')

    if api_type == 'responses':
        from pydantic_ai.models.openai import OpenAIResponsesModel

        return OpenAIResponsesModel(model_name=model_name, provider='gateway')

    if api_type == 'gemini':
        from pydantic_ai.models.google import GoogleModel

        return GoogleModel(model_name=model_name, provider='gateway')

    if api_type == 'converse':
        from pydantic_ai.models.bedrock import BedrockConverseModel

        return BedrockConverseModel(model_name=model_name, provider='gateway')

    if api_type == 'anthropic':
        from pydantic_ai.models.anthropic import AnthropicModel

        return AnthropicModel(model_name=model_name, provider='gateway')

    if api_type == 'groq':
        from pydantic_ai.models.groq import GroqModel

        return GroqModel(model_name=model_name, provider='gateway')

    # Nothing matched: reject the value instead of returning None.
    raise ValueError(f'Unknown API type: {api_type}')  # pragma: no cover
0 commit comments