|
17 | 17 | from groq import AsyncGroq |
18 | 18 | from openai import AsyncOpenAI |
19 | 19 |
|
| 20 | + from pydantic_ai.models import Model |
20 | 21 | from pydantic_ai.models.anthropic import AsyncAnthropicClient |
21 | 22 | from pydantic_ai.providers import Provider |
22 | 23 |
|
@@ -190,3 +191,33 @@ def _merge_url_path(base_url: str, path: str) -> str: |
190 | 191 | path: The path to merge. |
191 | 192 | """ |
192 | 193 | return base_url.rstrip('/') + '/' + path.lstrip('/') |
| 194 | + |
| 195 | + |
| 196 | +def infer_gateway_model(api_type: APIType | str, *, model_name: str) -> Model: |
| 197 | + """Infer the model class for a given API type.""" |
| 198 | + if api_type == 'chat': |
| 199 | + from pydantic_ai.models.openai import OpenAIChatModel |
| 200 | + |
| 201 | + return OpenAIChatModel(model_name=model_name, provider='gateway') |
| 202 | + elif api_type == 'groq': |
| 203 | + from pydantic_ai.models.groq import GroqModel |
| 204 | + |
| 205 | + return GroqModel(model_name=model_name, provider='gateway') |
| 206 | + elif api_type == 'responses': |
| 207 | + from pydantic_ai.models.openai import OpenAIResponsesModel |
| 208 | + |
| 209 | + return OpenAIResponsesModel(model_name=model_name, provider='gateway') |
| 210 | + elif api_type == 'gemini': |
| 211 | + from pydantic_ai.models.google import GoogleModel |
| 212 | + |
| 213 | + return GoogleModel(model_name=model_name, provider='gateway') |
| 214 | + elif api_type == 'converse': |
| 215 | + from pydantic_ai.models.bedrock import BedrockConverseModel |
| 216 | + |
| 217 | + return BedrockConverseModel(model_name=model_name, provider='gateway') |
| 218 | + elif api_type == 'anthropic': |
| 219 | + from pydantic_ai.models.anthropic import AnthropicModel |
| 220 | + |
| 221 | + return AnthropicModel(model_name=model_name, provider='gateway') |
| 222 | + else: |
| 223 | + raise ValueError(f'Unknown API type: {api_type}') |
0 commit comments