|
9 | 9 | import base64 |
10 | 10 | import warnings |
11 | 11 | from abc import ABC, abstractmethod |
12 | | -from collections.abc import AsyncIterator, Iterator |
| 12 | +from collections.abc import AsyncIterator, Callable, Iterator |
13 | 13 | from contextlib import asynccontextmanager, contextmanager |
14 | 14 | from dataclasses import dataclass, field, replace |
15 | 15 | from datetime import datetime |
|
47 | 47 | ) |
48 | 48 | from ..output import OutputMode |
49 | 49 | from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec |
50 | | -from ..providers import infer_provider |
| 50 | +from ..providers import Provider, infer_provider |
51 | 51 | from ..settings import ModelSettings, merge_model_settings |
52 | 52 | from ..tools import ToolDefinition |
53 | 53 | from ..usage import RequestUsage |
|
126 | 126 | 'cerebras:gpt-oss-120b', |
127 | 127 | 'cerebras:llama3.1-8b', |
128 | 128 | 'cerebras:llama-3.3-70b', |
129 | | - 'cerebras:llama-4-scout-17b-16e-instruct', |
130 | | - 'cerebras:llama-4-maverick-17b-128e-instruct', |
131 | 129 | 'cerebras:qwen-3-235b-a22b-instruct-2507', |
132 | 130 | 'cerebras:qwen-3-32b', |
133 | | - 'cerebras:qwen-3-coder-480b', |
134 | 131 | 'cerebras:qwen-3-235b-a22b-thinking-2507', |
135 | 132 | 'cohere:c4ai-aya-expanse-32b', |
136 | 133 | 'cohere:c4ai-aya-expanse-8b', |
137 | 134 | 'cohere:command-nightly', |
138 | 135 | 'cohere:command-r-08-2024', |
139 | 136 | 'cohere:command-r-plus-08-2024', |
140 | 137 | 'cohere:command-r7b-12-2024', |
| 138 | + 'cerebras:zai-glm-4.6', |
141 | 139 | 'deepseek:deepseek-chat', |
142 | 140 | 'deepseek:deepseek-reasoner', |
143 | 141 | 'google-gla:gemini-2.0-flash', |
|
189 | 187 | 'groq:llama-3.2-3b-preview', |
190 | 188 | 'groq:llama-3.2-11b-vision-preview', |
191 | 189 | 'groq:llama-3.2-90b-vision-preview', |
| 190 | + 'heroku:amazon-rerank-1-0', |
192 | 191 | 'heroku:claude-3-5-haiku', |
193 | 192 | 'heroku:claude-3-5-sonnet-latest', |
194 | 193 | 'heroku:claude-3-7-sonnet', |
195 | | - 'heroku:claude-4-sonnet', |
196 | 194 | 'heroku:claude-3-haiku', |
| 195 | + 'heroku:claude-4-5-haiku', |
| 196 | + 'heroku:claude-4-5-sonnet', |
| 197 | + 'heroku:claude-4-sonnet', |
| 198 | + 'heroku:cohere-rerank-3-5', |
197 | 199 | 'heroku:gpt-oss-120b', |
198 | 200 | 'heroku:nova-lite', |
199 | 201 | 'heroku:nova-pro', |
@@ -722,8 +724,17 @@ def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]: |
722 | 724 | ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition] |
723 | 725 |
|
724 | 726 |
|
725 | | -def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 |
726 | | - """Infer the model from the name.""" |
| 727 | +def infer_model( # noqa: C901 |
| 728 | + model: Model | KnownModelName | str, provider_factory: Callable[[str], Provider[Any]] = infer_provider |
| 729 | +) -> Model: |
| 730 | + """Infer the model from the name. |
| 731 | +
|
| 732 | + Args: |
| 733 | + model: |
| 734 | + Model name to instantiate, in the format of `provider:model`. Use the string "test" to instantiate TestModel. |
| 735 | + provider_factory: |
| 736 | + Function that instantiates a provider object. It is called with the provider name as its only argument. Defaults to `providers.infer_provider`. |
| 737 | + """ |
727 | 738 | if isinstance(model, Model): |
728 | 739 | return model |
729 | 740 | elif model == 'test': |
@@ -758,11 +769,13 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 |
758 | 769 | ) |
759 | 770 | provider_name = 'google-vertex' |
760 | 771 |
|
761 | | - provider = infer_provider(provider_name) |
| 772 | + provider: Provider[Any] = provider_factory(provider_name) |
762 | 773 |
|
763 | 774 | model_kind = provider_name |
764 | 775 | if model_kind.startswith('gateway/'): |
765 | | - model_kind = provider_name.removeprefix('gateway/') |
| 776 | + from ..providers.gateway import infer_gateway_model |
| 777 | + |
| 778 | + return infer_gateway_model(model_kind.removeprefix('gateway/'), model_name=model_name) |
766 | 779 | if model_kind in ( |
767 | 780 | 'openai', |
768 | 781 | 'azure', |
|
0 commit comments