Skip to content

Commit e328fc3

Browse files
authored
Use provider on models inference (#1234)
1 parent f17f276 commit e328fc3

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from dataclasses import dataclass, field
1313
from datetime import datetime
1414
from functools import cache
15-
from typing import TYPE_CHECKING
15+
from typing import TYPE_CHECKING, cast
1616

1717
import httpx
1818
from typing_extensions import Literal
@@ -383,6 +383,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
383383

384384
try:
385385
provider, model_name = model.split(':', maxsplit=1)
386+
provider = cast(str, provider)
386387
except ValueError:
387388
model_name = model
388389
# TODO(Marcelo): We should deprecate this way.
@@ -414,22 +415,19 @@ def infer_model(model: Model | KnownModelName) -> Model:
414415
elif provider == 'groq':
415416
from .groq import GroqModel
416417

417-
# TODO(Marcelo): Missing provider API.
418-
return GroqModel(model_name)
418+
return GroqModel(model_name, provider=provider)
419419
elif provider == 'mistral':
420420
from .mistral import MistralModel
421421

422-
# TODO(Marcelo): Missing provider API.
423-
return MistralModel(model_name)
422+
return MistralModel(model_name, provider=provider)
424423
elif provider == 'anthropic':
425424
from .anthropic import AnthropicModel
426425

427-
# TODO(Marcelo): Missing provider API.
428-
return AnthropicModel(model_name)
426+
return AnthropicModel(model_name, provider=provider)
429427
elif provider == 'bedrock':
430428
from .bedrock import BedrockConverseModel
431429

432-
return BedrockConverseModel(model_name)
430+
return BedrockConverseModel(model_name, provider=provider)
433431
else:
434432
raise UserError(f'Unknown model: {model}')
435433

tests/test_examples.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ def test_docs_examples( # noqa: C901
8686
env.set('GEMINI_API_KEY', 'testing')
8787
env.set('GROQ_API_KEY', 'testing')
8888
env.set('CO_API_KEY', 'testing')
89+
env.set('MISTRAL_API_KEY', 'testing')
90+
env.set('ANTHROPIC_API_KEY', 'testing')
8991

9092
sys.path.append('tests/example_modules')
9193

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)