Skip to content

Commit 3256386

Browse files
slkoo-ccDouweM
andauthored
Allow custom provider factory to be passed into infer_model (#3341)
Co-authored-by: Douwe Maan <[email protected]>
1 parent a8f7067 commit 3256386

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import base64
1010
import warnings
1111
from abc import ABC, abstractmethod
12-
from collections.abc import AsyncIterator, Iterator
12+
from collections.abc import AsyncIterator, Callable, Iterator
1313
from contextlib import asynccontextmanager, contextmanager
1414
from dataclasses import dataclass, field, replace
1515
from datetime import datetime
@@ -47,7 +47,7 @@
4747
)
4848
from ..output import OutputMode
4949
from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
50-
from ..providers import infer_provider
50+
from ..providers import Provider, infer_provider
5151
from ..settings import ModelSettings, merge_model_settings
5252
from ..tools import ToolDefinition
5353
from ..usage import RequestUsage
@@ -724,8 +724,17 @@ def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]:
724724
ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition]
725725

726726

727-
def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
728-
"""Infer the model from the name."""
727+
def infer_model( # noqa: C901
728+
model: Model | KnownModelName | str, provider_factory: Callable[[str], Provider[Any]] = infer_provider
729+
) -> Model:
730+
"""Infer the model from the name.
731+
732+
Args:
733+
model:
734+
Model name to instantiate, in the format of `provider:model`. Use the string "test" to instantiate TestModel.
735+
provider_factory:
736+
Function that instantiates a provider object. The provider name is passed into the function parameter. Defaults to `provider.infer_provider`.
737+
"""
729738
if isinstance(model, Model):
730739
return model
731740
elif model == 'test':
@@ -760,7 +769,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
760769
)
761770
provider_name = 'google-vertex'
762771

763-
provider = infer_provider(provider_name)
772+
provider: Provider[Any] = provider_factory(provider_name)
764773

765774
model_kind = provider_name
766775
if model_kind.startswith('gateway/'):

tests/models/test_model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,17 @@ def test_infer_model(
242242
assert m2 is m
243243

244244

245+
def test_infer_model_with_provider():
246+
from pydantic_ai.providers import openai
247+
248+
provider_class = openai.OpenAIProvider(api_key='1234', base_url='http://test')
249+
m = infer_model('openai:gpt-5', lambda x: provider_class)
250+
251+
assert isinstance(m, OpenAIChatModel)
252+
assert m._provider is provider_class # type: ignore
253+
assert m._provider.base_url == 'http://test' # type: ignore
254+
255+
245256
def test_infer_str_unknown():
246257
with pytest.raises(UserError, match='Unknown model: foobar'):
247258
infer_model('foobar')

0 commit comments

Comments
 (0)