File tree Expand file tree Collapse file tree 3 files changed +30
-2
lines changed Expand file tree Collapse file tree 3 files changed +30
-2
lines changed Original file line number Diff line number Diff line change 15
15
16
16
import httpx
17
17
18
+ from ..exceptions import UserError
18
19
from ..messages import Message , ModelAnyResponse , ModelStructuredResponse
19
20
20
21
if TYPE_CHECKING :
46
47
'groq:gemma-7b-it' ,
47
48
'gemini-1.5-flash' ,
48
49
'gemini-1.5-pro' ,
50
+ 'vertexai:gemini-1.5-flash' ,
51
+ 'vertexai:gemini-1.5-pro' ,
49
52
'test' ,
50
53
]
51
54
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -245,9 +248,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
245
248
from .groq import GroqModel
246
249
247
250
return GroqModel (model [5 :]) # pyright: ignore[reportArgumentType]
248
- else :
249
- from .. exceptions import UserError
251
+ elif model . startswith ( 'vertexai:' ) :
252
+ from .vertexai import VertexAIModel
250
253
254
+ return VertexAIModel (model [9 :]) # pyright: ignore[reportArgumentType]
255
+ else :
251
256
raise UserError (f'Unknown model: { model } ' )
252
257
253
258
Original file line number Diff line number Diff line change @@ -24,6 +24,7 @@ class MyModel(BaseModel):
24
24
25
25
26
26
model = cast (KnownModelName , os .getenv ('PYDANTIC_AI_MODEL' , 'openai:gpt-4o' ))
27
+ print (f'Using model: { model } ' )
27
28
agent = Agent (model , result_type = MyModel )
28
29
29
30
if __name__ == '__main__' :
Original file line number Diff line number Diff line change
1
+ from typing import TYPE_CHECKING
2
+
1
3
import pytest
2
4
3
5
from pydantic_ai import UserError
6
8
from pydantic_ai .models .openai import OpenAIModel
7
9
from tests .conftest import TestEnv
8
10
11
+ if TYPE_CHECKING :
12
+ from pydantic_ai .models .vertexai import VertexAIModel
13
+
14
+ google_auth_installed = True
15
+
16
+ else :
17
+ try :
18
+ from pydantic_ai .models .vertexai import VertexAIModel
19
+ except ImportError :
20
+ google_auth_installed = False
21
+ else :
22
+ google_auth_installed = True
23
+
9
24
10
25
def test_infer_str_openai (env : TestEnv ):
11
26
env .set ('OPENAI_API_KEY' , 'via-env-var' )
@@ -24,6 +39,13 @@ def test_infer_str_gemini(env: TestEnv):
24
39
assert m .name () == 'gemini-1.5-flash'
25
40
26
41
42
+ @pytest .mark .skipif (not google_auth_installed , reason = 'google-auth not installed' )
43
+ def test_infer_vertexai (env : TestEnv ):
44
+ m = infer_model ('vertexai:gemini-1.5-flash' )
45
+ assert isinstance (m , VertexAIModel )
46
+ assert m .name () == 'vertexai:gemini-1.5-flash'
47
+
48
+
27
49
def test_infer_str_unknown ():
28
50
with pytest .raises (UserError , match = 'Unknown model: foobar' ):
29
51
infer_model ('foobar' ) # pyright: ignore[reportArgumentType]
You can’t perform that action at this time.
0 commit comments