Skip to content

Commit 46c4441

Browse files
committed
fix tests
1 parent 68682ef commit 46c4441

File tree

6 files changed

+46
-13
lines changed

6 files changed

+46
-13
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,9 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
764764

765765
model_kind = provider_name
766766
if model_kind.startswith('gateway/'):
767-
model_kind = provider_name.removeprefix('gateway/')
767+
from ..providers.gateway import infer_gateway_model
768+
769+
return infer_gateway_model(model_kind.removeprefix('gateway/'), model_name=model_name)
768770
if model_kind in (
769771
'openai',
770772
'azure',

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ def infer_provider(provider: str) -> Provider[Any]:
158158
if provider.startswith('gateway/'):
159159
from .gateway import gateway_provider
160160

161-
provider = provider.removeprefix('gateway/')
162-
return gateway_provider(provider)
161+
api_type = provider.removeprefix('gateway/')
162+
return gateway_provider(api_type)
163163
elif provider in ('google-vertex', 'google-gla'):
164164
from .google import GoogleProvider
165165

pydantic_ai_slim/pydantic_ai/providers/gateway.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from groq import AsyncGroq
1818
from openai import AsyncOpenAI
1919

20+
from pydantic_ai.models import Model
2021
from pydantic_ai.models.anthropic import AsyncAnthropicClient
2122
from pydantic_ai.providers import Provider
2223

@@ -190,3 +191,33 @@ def _merge_url_path(base_url: str, path: str) -> str:
190191
path: The path to merge.
191192
"""
192193
return base_url.rstrip('/') + '/' + path.lstrip('/')
194+
195+
196+
def infer_gateway_model(api_type: APIType | str, *, model_name: str) -> Model:
197+
"""Infer the model class for a given API type."""
198+
if api_type == 'chat':
199+
from pydantic_ai.models.openai import OpenAIChatModel
200+
201+
return OpenAIChatModel(model_name=model_name, provider='gateway')
202+
elif api_type == 'groq':
203+
from pydantic_ai.models.groq import GroqModel
204+
205+
return GroqModel(model_name=model_name, provider='gateway')
206+
elif api_type == 'responses':
207+
from pydantic_ai.models.openai import OpenAIResponsesModel
208+
209+
return OpenAIResponsesModel(model_name=model_name, provider='gateway')
210+
elif api_type == 'gemini':
211+
from pydantic_ai.models.google import GoogleModel
212+
213+
return GoogleModel(model_name=model_name, provider='gateway')
214+
elif api_type == 'converse':
215+
from pydantic_ai.models.bedrock import BedrockConverseModel
216+
217+
return BedrockConverseModel(model_name=model_name, provider='gateway')
218+
elif api_type == 'anthropic':
219+
from pydantic_ai.models.anthropic import AnthropicModel
220+
221+
return AnthropicModel(model_name=model_name, provider='gateway')
222+
else:
223+
raise ValueError(f'Unknown API type: {api_type}')

tests/models/cassettes/test_model_names/test_known_model_names.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,6 @@ interactions:
133133
- max-age=3600; includeSubDomains
134134
parsed_body:
135135
data:
136-
- created: 0
137-
id: gpt-oss-120b
138-
object: model
139-
owned_by: Cerebras
140136
- created: 0
141137
id: qwen-3-235b-a22b-thinking-2507
142138
object: model
@@ -149,16 +145,20 @@ interactions:
149145
id: qwen-3-235b-a22b-instruct-2507
150146
object: model
151147
owned_by: Cerebras
148+
- created: 0
149+
id: qwen-3-32b
150+
object: model
151+
owned_by: Cerebras
152152
- created: 0
153153
id: zai-glm-4.6
154154
object: model
155155
owned_by: Cerebras
156156
- created: 0
157-
id: llama3.1-8b
157+
id: gpt-oss-120b
158158
object: model
159159
owned_by: Cerebras
160160
- created: 0
161-
id: qwen-3-32b
161+
id: llama3.1-8b
162162
object: model
163163
owned_by: Cerebras
164164
object: list

tests/models/test_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,12 @@
7474
),
7575
pytest.param(
7676
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
77-
'gateway/bedrock:amazon.nova-micro-v1:0',
77+
'gateway/converse:amazon.nova-micro-v1:0',
7878
'amazon.nova-micro-v1:0',
7979
'bedrock',
8080
'bedrock',
8181
BedrockConverseModel,
82-
id='gateway/bedrock:amazon.nova-micro-v1:0',
82+
id='gateway/converse:amazon.nova-micro-v1:0',
8383
),
8484
pytest.param(
8585
{'OPENAI_API_KEY': 'openai-api-key'},

tests/providers/test_provider_names.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@
6262
('ovhcloud', OVHcloudProvider, 'OVHCLOUD_API_KEY'),
6363
('gateway/openai', OpenAIProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'),
6464
('gateway/groq', GroqProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'),
65-
('gateway/google-vertex', GoogleProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'),
65+
('gateway/gemini', GoogleProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'),
6666
('gateway/anthropic', AnthropicProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'),
67-
('gateway/bedrock', BedrockProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'),
67+
('gateway/converse', BedrockProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'),
6868
('outlines', OutlinesProvider, None),
6969
]
7070

0 commit comments

Comments
 (0)