|
9 | 9 | import base64 |
10 | 10 | import warnings |
11 | 11 | from abc import ABC, abstractmethod |
12 | | -from collections.abc import AsyncIterator, Iterator |
| 12 | +from collections.abc import AsyncIterator, Callable, Iterator |
13 | 13 | from contextlib import asynccontextmanager, contextmanager |
14 | 14 | from dataclasses import dataclass, field, replace |
15 | 15 | from datetime import datetime |
|
47 | 47 | ) |
48 | 48 | from ..output import OutputMode |
49 | 49 | from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec |
50 | | -from ..providers import infer_provider |
| 50 | +from ..providers import Provider, infer_provider |
51 | 51 | from ..settings import ModelSettings, merge_model_settings |
52 | 52 | from ..tools import ToolDefinition |
53 | 53 | from ..usage import RequestUsage |
@@ -724,8 +724,17 @@ def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]: |
724 | 724 | ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition] |
725 | 725 |
|
726 | 726 |
|
727 | | -def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 |
728 | | - """Infer the model from the name.""" |
| 727 | +def infer_model( # noqa: C901 |
| 728 | + model: Model | KnownModelName | str, provider_factory: Callable[[str], Provider[Any]] = infer_provider |
| 729 | +) -> Model: |
| 730 | + """Infer the model from the name. |
| 731 | +
|
| 732 | + Args: |
| 733 | + model: |
| 734 | + Model name to instantiate, in the format of `provider:model`. Use the string "test" to instantiate TestModel. |
| 735 | + provider_factory: |
| 736 | + Function that instantiates a provider object. The provider name is passed into the function parameter. Defaults to `provider.infer_provider`. |
| 737 | + """ |
729 | 738 | if isinstance(model, Model): |
730 | 739 | return model |
731 | 740 | elif model == 'test': |
@@ -760,7 +769,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 |
760 | 769 | ) |
761 | 770 | provider_name = 'google-vertex' |
762 | 771 |
|
763 | | - provider = infer_provider(provider_name) |
| 772 | + provider: Provider[Any] = provider_factory(provider_name) |
764 | 773 |
|
765 | 774 | model_kind = provider_name |
766 | 775 | if model_kind.startswith('gateway/'): |
|
0 commit comments