Skip to content

Commit d29d5e0

Browse files
authored
Support VertexAI models in infer_model (#89)
1 parent 3899c32 commit d29d5e0

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

pydantic_ai/models/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import httpx
1717

18+
from ..exceptions import UserError
1819
from ..messages import Message, ModelAnyResponse, ModelStructuredResponse
1920

2021
if TYPE_CHECKING:
@@ -46,6 +47,8 @@
4647
'groq:gemma-7b-it',
4748
'gemini-1.5-flash',
4849
'gemini-1.5-pro',
50+
'vertexai:gemini-1.5-flash',
51+
'vertexai:gemini-1.5-pro',
4952
'test',
5053
]
5154
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -245,9 +248,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
245248
from .groq import GroqModel
246249

247250
return GroqModel(model[5:]) # pyright: ignore[reportArgumentType]
248-
else:
249-
from ..exceptions import UserError
251+
elif model.startswith('vertexai:'):
252+
from .vertexai import VertexAIModel
250253

254+
return VertexAIModel(model[9:]) # pyright: ignore[reportArgumentType]
255+
else:
251256
raise UserError(f'Unknown model: {model}')
252257

253258

pydantic_ai_examples/pydantic_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class MyModel(BaseModel):
2424

2525

2626
model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'openai:gpt-4o'))
27+
print(f'Using model: {model}')
2728
agent = Agent(model, result_type=MyModel)
2829

2930
if __name__ == '__main__':

tests/models/test_model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import TYPE_CHECKING
2+
13
import pytest
24

35
from pydantic_ai import UserError
@@ -6,6 +8,19 @@
68
from pydantic_ai.models.openai import OpenAIModel
79
from tests.conftest import TestEnv
810

11+
if TYPE_CHECKING:
12+
from pydantic_ai.models.vertexai import VertexAIModel
13+
14+
google_auth_installed = True
15+
16+
else:
17+
try:
18+
from pydantic_ai.models.vertexai import VertexAIModel
19+
except ImportError:
20+
google_auth_installed = False
21+
else:
22+
google_auth_installed = True
23+
924

1025
def test_infer_str_openai(env: TestEnv):
1126
env.set('OPENAI_API_KEY', 'via-env-var')
@@ -24,6 +39,13 @@ def test_infer_str_gemini(env: TestEnv):
2439
assert m.name() == 'gemini-1.5-flash'
2540

2641

42+
@pytest.mark.skipif(not google_auth_installed, reason='google-auth not installed')
43+
def test_infer_vertexai(env: TestEnv):
44+
m = infer_model('vertexai:gemini-1.5-flash')
45+
assert isinstance(m, VertexAIModel)
46+
assert m.name() == 'vertexai:gemini-1.5-flash'
47+
48+
2749
def test_infer_str_unknown():
2850
with pytest.raises(UserError, match='Unknown model: foobar'):
2951
infer_model('foobar') # pyright: ignore[reportArgumentType]

0 commit comments

Comments
 (0)