|
9 | 9 | import base64 |
10 | 10 | import warnings |
11 | 11 | from abc import ABC, abstractmethod |
12 | | -from collections.abc import AsyncIterator, Iterator |
| 12 | +from collections.abc import AsyncIterator, Callable, Iterator |
13 | 13 | from contextlib import asynccontextmanager, contextmanager |
14 | 14 | from dataclasses import dataclass, field, replace |
15 | 15 | from datetime import datetime |
16 | 16 | from functools import cache, cached_property |
17 | | -from typing import Any, Generic, Literal, TypeVar, Callable, overload |
| 17 | +from typing import Any, Generic, Literal, TypeVar, overload |
18 | 18 |
|
19 | 19 | import httpx |
20 | 20 | from typing_extensions import TypeAliasType, TypedDict |
|
47 | 47 | ) |
48 | 48 | from ..output import OutputMode |
49 | 49 | from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec |
50 | | -from ..providers import infer_provider, Provider |
| 50 | +from ..providers import Provider, infer_provider |
51 | 51 | from ..settings import ModelSettings, merge_model_settings |
52 | 52 | from ..tools import ToolDefinition |
53 | 53 | from ..usage import RequestUsage |
@@ -677,7 +677,9 @@ def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]: |
677 | 677 | ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition] |
678 | 678 |
|
679 | 679 |
|
680 | | -def infer_model(model: Model | KnownModelName | str, provider_generator: Callable[[str], Provider[Any]] | None = None) -> Model: # noqa: C901 |
| 680 | +def infer_model( |
| 681 | + model: Model | KnownModelName | str, provider_generator: Callable[[str], Provider[Any]] | None = None |
| 682 | +) -> Model: # noqa: C901 |
681 | 683 | """Infer the model from the name. May optionally pass a callable that setup a custom provider for the model.""" |
682 | 684 | if isinstance(model, Model): |
683 | 685 | return model |
@@ -714,7 +716,7 @@ def infer_model(model: Model | KnownModelName | str, provider_generator: Callabl |
714 | 716 | provider_name = 'google-vertex' |
715 | 717 |
|
716 | 718 | if provider_generator is None: |
717 | | - provider_generator = infer_provider |
| 719 | + provider_generator = infer_provider |
718 | 720 | provider = provider_generator(provider_name) |
719 | 721 |
|
720 | 722 | model_kind = provider_name |
|
0 commit comments