1
1
import sys
2
+ import types
2
3
from io import StringIO
3
- from typing import Any
4
+ from typing import Any , Callable
4
5
5
6
import pytest
6
7
from dirty_equals import IsInstance , IsStr
16
17
from .conftest import TestEnv , try_import
17
18
18
19
with try_import () as imports_successful :
20
+ from openai import OpenAIError
19
21
from prompt_toolkit .input import create_pipe_input
20
22
from prompt_toolkit .output import DummyOutput
21
23
from prompt_toolkit .shortcuts import PromptSession
22
24
23
25
from pydantic_ai ._cli import cli , cli_agent , handle_slash_command
26
+ from pydantic_ai .models .openai import OpenAIModel
24
27
25
28
pytestmark = pytest .mark .skipif (not imports_successful (), reason = 'install cli extras to run cli tests' )
26
29
@@ -32,70 +35,90 @@ def test_cli_version(capfd: CaptureFixture[str]):
32
35
33
36
def test_invalid_model (capfd : CaptureFixture [str ]):
34
37
assert cli (['--model' , 'potato' ]) == 1
35
- assert capfd .readouterr ().out .splitlines () == snapshot (
36
- [IsStr (), 'Error initializing potato:' , 'Unknown model: potato' ]
37
- )
38
+ assert capfd .readouterr ().out .splitlines () == snapshot (['Error initializing potato:' , 'Unknown model: potato' ])
38
39
39
40
40
- def test_agent_flag (capfd : CaptureFixture [str ], mocker : MockerFixture , env : TestEnv ):
41
- env .set ('OPENAI_API_KEY' , 'test' )
41
+ @pytest .fixture
42
+ def create_test_module ():
43
+ def _create_test_module (** namespace : Any ) -> None :
44
+ assert 'test_module' not in sys .modules
42
45
43
- # Create a dynamic module using types.ModuleType
44
- import types
46
+ test_module = types .ModuleType ('test_module' )
47
+ for key , value in namespace .items ():
48
+ setattr (test_module , key , value )
45
49
46
- test_module = types . ModuleType ( 'test_module' )
50
+ sys . modules [ 'test_module' ] = test_module
47
51
48
- # Create and add agent to the module
49
- test_agent = Agent ()
50
- test_agent .model = TestModel (custom_output_text = 'Hello from custom agent' )
51
- setattr (test_module , 'custom_agent' , test_agent )
52
+ try :
53
+ yield _create_test_module
54
+ finally :
55
+ if 'test_module' in sys .modules :
56
+ del sys .modules ['test_module' ]
52
57
53
- # Register the module in sys.modules
54
- sys .modules ['test_module' ] = test_module
55
58
56
- try :
57
- # Mock ask_agent to avoid actual execution but capture the agent
58
- mock_ask = mocker .patch ('pydantic_ai._cli.ask_agent' )
59
+ def test_agent_flag (
60
+ capfd : CaptureFixture [str ],
61
+ mocker : MockerFixture ,
62
+ env : TestEnv ,
63
+ create_test_module : Callable [..., None ],
64
+ ):
65
+ env .remove ('OPENAI_API_KEY' )
59
66
60
- # Test CLI with custom agent
61
- assert cli ([ '--agent' , 'test_module: custom_agent' , 'hello' ]) == 0
67
+ test_agent = Agent ( TestModel ( custom_output_text = 'Hello from custom agent' ))
68
+ create_test_module ( custom_agent = test_agent )
62
69
63
- # Verify the output contains the custom agent message
64
- assert 'Using custom agent: test_module:custom_agent' in capfd . readouterr (). out
70
+ # Mock ask_agent to avoid actual execution but capture the agent
71
+ mock_ask = mocker . patch ( 'pydantic_ai._cli.ask_agent' )
65
72
66
- # Verify ask_agent was called with our custom agent
67
- mock_ask .assert_called_once ()
68
- assert mock_ask .call_args [0 ][0 ] is test_agent
73
+ # Test CLI with custom agent
74
+ assert cli (['--agent' , 'test_module:custom_agent' , 'hello' ]) == 0
69
75
70
- finally :
71
- # Clean up by removing the module from sys.modules
72
- if 'test_module' in sys .modules :
73
- del sys .modules ['test_module' ]
76
+ # Verify the output contains the custom agent message
77
+ assert 'using custom agent test_module:custom_agent' in capfd .readouterr ().out
74
78
79
+ # Verify ask_agent was called with our custom agent
80
+ mock_ask .assert_called_once ()
81
+ assert mock_ask .call_args [0 ][0 ] is test_agent
75
82
76
- def test_agent_flag_non_agent (capfd : CaptureFixture [str ], mocker : MockerFixture , env : TestEnv ):
77
- env .set ('OPENAI_API_KEY' , 'test' )
78
83
79
- # Create a dynamic module using types.ModuleType
80
- import types
84
+ def test_agent_flag_no_model (env : TestEnv , create_test_module : Callable [..., None ]):
85
+ env .remove ('OPENAI_API_KEY' )
86
+ test_agent = Agent ()
87
+ create_test_module (custom_agent = test_agent )
81
88
82
- test_module = types .ModuleType ('test_module' )
89
+ msg = 'The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable'
90
+ with pytest .raises (OpenAIError , match = msg ):
91
+ cli (['--agent' , 'test_module:custom_agent' , 'hello' ])
83
92
84
- # Create and add agent to the module
85
- test_agent = 'Not an Agent object'
86
- setattr (test_module , 'custom_agent' , test_agent )
87
93
88
- # Register the module in sys.modules
89
- sys .modules ['test_module' ] = test_module
94
+ def test_agent_flag_set_model (
95
+ capfd : CaptureFixture [str ],
96
+ mocker : MockerFixture ,
97
+ env : TestEnv ,
98
+ create_test_module : Callable [..., None ],
99
+ ):
100
+ env .set ('OPENAI_API_KEY' , 'xxx' )
90
101
91
- try :
92
- assert cli (['--agent' , 'test_module:custom_agent' , 'hello' ]) == 1
93
- assert 'is not an Agent' in capfd .readouterr ().out
102
+ custom_agent = Agent (TestModel (custom_output_text = 'Hello from custom agent' ))
103
+ create_test_module (custom_agent = custom_agent )
94
104
95
- finally :
96
- # Clean up by removing the module from sys.modules
97
- if 'test_module' in sys .modules :
98
- del sys .modules ['test_module' ]
105
+ mocker .patch ('pydantic_ai._cli.ask_agent' )
106
+
107
+ assert cli (['--agent' , 'test_module:custom_agent' , '--model' , 'gpt-4o' , 'hello' ]) == 0
108
+
109
+ assert 'using custom agent test_module:custom_agent with openai:gpt-4o' in capfd .readouterr ().out
110
+
111
+ assert isinstance (custom_agent .model , OpenAIModel )
112
+
113
+
114
+ def test_agent_flag_non_agent (
115
+ capfd : CaptureFixture [str ], mocker : MockerFixture , create_test_module : Callable [..., None ]
116
+ ):
117
+ test_agent = 'Not an Agent object'
118
+ create_test_module (custom_agent = test_agent )
119
+
120
+ assert cli (['--agent' , 'test_module:custom_agent' , 'hello' ]) == 1
121
+ assert 'is not an Agent' in capfd .readouterr ().out
99
122
100
123
101
124
def test_agent_flag_bad_module_variable_path (capfd : CaptureFixture [str ], mocker : MockerFixture , env : TestEnv ):
@@ -106,7 +129,7 @@ def test_agent_flag_bad_module_variable_path(capfd: CaptureFixture[str], mocker:
106
129
def test_list_models (capfd : CaptureFixture [str ]):
107
130
assert cli (['--list-models' ]) == 0
108
131
output = capfd .readouterr ().out .splitlines ()
109
- assert output [:2 ] == snapshot ([IsStr (regex = 'pai - PydanticAI CLI .* using openai:gpt-4o' ) , 'Available models:' ])
132
+ assert output [:3 ] == snapshot ([IsStr (regex = 'pai - PydanticAI CLI .*' ), '' , 'Available models:' ])
110
133
111
134
providers = (
112
135
'openai' ,
@@ -119,7 +142,7 @@ def test_list_models(capfd: CaptureFixture[str]):
119
142
'cohere' ,
120
143
'deepseek' ,
121
144
)
122
- models = {line .strip ().split (' ' )[0 ] for line in output [2 :]}
145
+ models = {line .strip ().split (' ' )[0 ] for line in output [3 :]}
123
146
for provider in providers :
124
147
models = models - {model for model in models if model .startswith (provider )}
125
148
assert models == set (), models
0 commit comments