Skip to content

Commit fa270d9

Browse files
authored
clai should not require OPENAI_API_KEY (#1724)
1 parent 2ccac46 commit fa270d9

File tree

2 files changed

+100
-69
lines changed

2 files changed

+100
-69
lines changed

pydantic_ai_slim/pydantic_ai/_cli.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def cli_exit(prog_name: str = 'pai'): # pragma: no cover
102102
sys.exit(cli(prog_name=prog_name))
103103

104104

105-
def cli(args_list: Sequence[str] | None = None, *, prog_name: str = 'pai') -> int:
105+
def cli(args_list: Sequence[str] | None = None, *, prog_name: str = 'pai') -> int: # noqa: C901
106106
"""Run the CLI and return the exit code for the process."""
107107
parser = argparse.ArgumentParser(
108108
prog=prog_name,
@@ -122,7 +122,6 @@ def cli(args_list: Sequence[str] | None = None, *, prog_name: str = 'pai') -> in
122122
'--model',
123123
nargs='?',
124124
help='Model to use, in format "<provider>:<model>" e.g. "openai:gpt-4o" or "anthropic:claude-3-7-sonnet-latest". Defaults to "openai:gpt-4o".',
125-
default='openai:gpt-4o',
126125
)
127126
# we don't want to autocomplete or list models that don't include the provider,
128127
# e.g. we want to show `openai:gpt-4o` but not `gpt-4o`
@@ -153,40 +152,49 @@ def cli(args_list: Sequence[str] | None = None, *, prog_name: str = 'pai') -> in
153152
args = parser.parse_args(args_list)
154153

155154
console = Console()
156-
console.print(
157-
f'[green]{prog_name} - PydanticAI CLI v{__version__} using[/green] [magenta]{args.model}[/magenta]',
158-
highlight=False,
159-
)
155+
name_version = f'[green]{prog_name} - PydanticAI CLI v{__version__}[/green]'
160156
if args.version:
157+
console.print(name_version, highlight=False)
161158
return 0
162159
if args.list_models:
163-
console.print('Available models:', style='green bold')
160+
console.print(f'{name_version}\n\n[green]Available models:[/green]')
164161
for model in qualified_model_names:
165162
console.print(f' {model}', highlight=False)
166163
return 0
167164

168165
agent: Agent[None, str] = cli_agent
169166
if args.agent:
167+
sys.path.append(os.getcwd())
170168
try:
171-
current_path = os.getcwd()
172-
sys.path.append(current_path)
173-
174169
module_path, variable_name = args.agent.split(':')
175-
module = importlib.import_module(module_path)
176-
agent = getattr(module, variable_name)
177-
if not isinstance(agent, Agent):
178-
console.print(f'[red]Error: {args.agent} is not an Agent instance[/red]')
179-
return 1
180-
console.print(f'[green]Using custom agent:[/green] [magenta]{args.agent}[/magenta]', highlight=False)
181170
except ValueError:
182171
console.print('[red]Error: Agent must be specified in "module:variable" format[/red]')
183172
return 1
184173

185-
try:
186-
agent.model = infer_model(args.model)
187-
except UserError as e:
188-
console.print(f'Error initializing [magenta]{args.model}[/magenta]:\n[red]{e}[/red]')
189-
return 1
174+
module = importlib.import_module(module_path)
175+
agent = getattr(module, variable_name)
176+
if not isinstance(agent, Agent):
177+
console.print(f'[red]Error: {args.agent} is not an Agent instance[/red]')
178+
return 1
179+
180+
model_arg_set = args.model is not None
181+
if agent.model is None or model_arg_set:
182+
try:
183+
agent.model = infer_model(args.model or 'openai:gpt-4o')
184+
except UserError as e:
185+
console.print(f'Error initializing [magenta]{args.model}[/magenta]:\n[red]{e}[/red]')
186+
return 1
187+
188+
model_name = agent.model if isinstance(agent.model, str) else f'{agent.model.system}:{agent.model.model_name}'
189+
if args.agent and model_arg_set:
190+
console.print(
191+
f'{name_version} using custom agent [magenta]{args.agent}[/magenta] with [magenta]{model_name}[/magenta]',
192+
highlight=False,
193+
)
194+
elif args.agent:
195+
console.print(f'{name_version} using custom agent [magenta]{args.agent}[/magenta]', highlight=False)
196+
else:
197+
console.print(f'{name_version} with [magenta]{model_name}[/magenta]', highlight=False)
190198

191199
stream = not args.no_stream
192200
if args.code_theme == 'light':

tests/test_cli.py

Lines changed: 71 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
2+
import types
23
from io import StringIO
3-
from typing import Any
4+
from typing import Any, Callable
45

56
import pytest
67
from dirty_equals import IsInstance, IsStr
@@ -16,11 +17,13 @@
1617
from .conftest import TestEnv, try_import
1718

1819
with try_import() as imports_successful:
20+
from openai import OpenAIError
1921
from prompt_toolkit.input import create_pipe_input
2022
from prompt_toolkit.output import DummyOutput
2123
from prompt_toolkit.shortcuts import PromptSession
2224

2325
from pydantic_ai._cli import cli, cli_agent, handle_slash_command
26+
from pydantic_ai.models.openai import OpenAIModel
2427

2528
pytestmark = pytest.mark.skipif(not imports_successful(), reason='install cli extras to run cli tests')
2629

@@ -32,70 +35,90 @@ def test_cli_version(capfd: CaptureFixture[str]):
3235

3336
def test_invalid_model(capfd: CaptureFixture[str]):
3437
assert cli(['--model', 'potato']) == 1
35-
assert capfd.readouterr().out.splitlines() == snapshot(
36-
[IsStr(), 'Error initializing potato:', 'Unknown model: potato']
37-
)
38+
assert capfd.readouterr().out.splitlines() == snapshot(['Error initializing potato:', 'Unknown model: potato'])
3839

3940

40-
def test_agent_flag(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
41-
env.set('OPENAI_API_KEY', 'test')
41+
@pytest.fixture
42+
def create_test_module():
43+
def _create_test_module(**namespace: Any) -> None:
44+
assert 'test_module' not in sys.modules
4245

43-
# Create a dynamic module using types.ModuleType
44-
import types
46+
test_module = types.ModuleType('test_module')
47+
for key, value in namespace.items():
48+
setattr(test_module, key, value)
4549

46-
test_module = types.ModuleType('test_module')
50+
sys.modules['test_module'] = test_module
4751

48-
# Create and add agent to the module
49-
test_agent = Agent()
50-
test_agent.model = TestModel(custom_output_text='Hello from custom agent')
51-
setattr(test_module, 'custom_agent', test_agent)
52+
try:
53+
yield _create_test_module
54+
finally:
55+
if 'test_module' in sys.modules:
56+
del sys.modules['test_module']
5257

53-
# Register the module in sys.modules
54-
sys.modules['test_module'] = test_module
5558

56-
try:
57-
# Mock ask_agent to avoid actual execution but capture the agent
58-
mock_ask = mocker.patch('pydantic_ai._cli.ask_agent')
59+
def test_agent_flag(
60+
capfd: CaptureFixture[str],
61+
mocker: MockerFixture,
62+
env: TestEnv,
63+
create_test_module: Callable[..., None],
64+
):
65+
env.remove('OPENAI_API_KEY')
5966

60-
# Test CLI with custom agent
61-
assert cli(['--agent', 'test_module:custom_agent', 'hello']) == 0
67+
test_agent = Agent(TestModel(custom_output_text='Hello from custom agent'))
68+
create_test_module(custom_agent=test_agent)
6269

63-
# Verify the output contains the custom agent message
64-
assert 'Using custom agent: test_module:custom_agent' in capfd.readouterr().out
70+
# Mock ask_agent to avoid actual execution but capture the agent
71+
mock_ask = mocker.patch('pydantic_ai._cli.ask_agent')
6572

66-
# Verify ask_agent was called with our custom agent
67-
mock_ask.assert_called_once()
68-
assert mock_ask.call_args[0][0] is test_agent
73+
# Test CLI with custom agent
74+
assert cli(['--agent', 'test_module:custom_agent', 'hello']) == 0
6975

70-
finally:
71-
# Clean up by removing the module from sys.modules
72-
if 'test_module' in sys.modules:
73-
del sys.modules['test_module']
76+
# Verify the output contains the custom agent message
77+
assert 'using custom agent test_module:custom_agent' in capfd.readouterr().out
7478

79+
# Verify ask_agent was called with our custom agent
80+
mock_ask.assert_called_once()
81+
assert mock_ask.call_args[0][0] is test_agent
7582

76-
def test_agent_flag_non_agent(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
77-
env.set('OPENAI_API_KEY', 'test')
7883

79-
# Create a dynamic module using types.ModuleType
80-
import types
84+
def test_agent_flag_no_model(env: TestEnv, create_test_module: Callable[..., None]):
85+
env.remove('OPENAI_API_KEY')
86+
test_agent = Agent()
87+
create_test_module(custom_agent=test_agent)
8188

82-
test_module = types.ModuleType('test_module')
89+
msg = 'The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable'
90+
with pytest.raises(OpenAIError, match=msg):
91+
cli(['--agent', 'test_module:custom_agent', 'hello'])
8392

84-
# Create and add agent to the module
85-
test_agent = 'Not an Agent object'
86-
setattr(test_module, 'custom_agent', test_agent)
8793

88-
# Register the module in sys.modules
89-
sys.modules['test_module'] = test_module
94+
def test_agent_flag_set_model(
95+
capfd: CaptureFixture[str],
96+
mocker: MockerFixture,
97+
env: TestEnv,
98+
create_test_module: Callable[..., None],
99+
):
100+
env.set('OPENAI_API_KEY', 'xxx')
90101

91-
try:
92-
assert cli(['--agent', 'test_module:custom_agent', 'hello']) == 1
93-
assert 'is not an Agent' in capfd.readouterr().out
102+
custom_agent = Agent(TestModel(custom_output_text='Hello from custom agent'))
103+
create_test_module(custom_agent=custom_agent)
94104

95-
finally:
96-
# Clean up by removing the module from sys.modules
97-
if 'test_module' in sys.modules:
98-
del sys.modules['test_module']
105+
mocker.patch('pydantic_ai._cli.ask_agent')
106+
107+
assert cli(['--agent', 'test_module:custom_agent', '--model', 'gpt-4o', 'hello']) == 0
108+
109+
assert 'using custom agent test_module:custom_agent with openai:gpt-4o' in capfd.readouterr().out
110+
111+
assert isinstance(custom_agent.model, OpenAIModel)
112+
113+
114+
def test_agent_flag_non_agent(
115+
capfd: CaptureFixture[str], mocker: MockerFixture, create_test_module: Callable[..., None]
116+
):
117+
test_agent = 'Not an Agent object'
118+
create_test_module(custom_agent=test_agent)
119+
120+
assert cli(['--agent', 'test_module:custom_agent', 'hello']) == 1
121+
assert 'is not an Agent' in capfd.readouterr().out
99122

100123

101124
def test_agent_flag_bad_module_variable_path(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
@@ -106,7 +129,7 @@ def test_agent_flag_bad_module_variable_path(capfd: CaptureFixture[str], mocker:
106129
def test_list_models(capfd: CaptureFixture[str]):
107130
assert cli(['--list-models']) == 0
108131
output = capfd.readouterr().out.splitlines()
109-
assert output[:2] == snapshot([IsStr(regex='pai - PydanticAI CLI .* using openai:gpt-4o'), 'Available models:'])
132+
assert output[:3] == snapshot([IsStr(regex='pai - PydanticAI CLI .*'), '', 'Available models:'])
110133

111134
providers = (
112135
'openai',
@@ -119,7 +142,7 @@ def test_list_models(capfd: CaptureFixture[str]):
119142
'cohere',
120143
'deepseek',
121144
)
122-
models = {line.strip().split(' ')[0] for line in output[2:]}
145+
models = {line.strip().split(' ')[0] for line in output[3:]}
123146
for provider in providers:
124147
models = models - {model for model in models if model.startswith(provider)}
125148
assert models == set(), models

0 commit comments

Comments
 (0)