Skip to content

Commit 18ff146

Browse files
committed
Add tests
1 parent 1278090 commit 18ff146

File tree

1 file changed

+57
-1
lines changed

1 file changed

+57
-1
lines changed

tests/test_cli.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
22
import types
33
from io import StringIO
4+
from pathlib import Path
45
from typing import Any, Callable
56

67
import pytest
@@ -22,7 +23,7 @@
2223
from prompt_toolkit.output import DummyOutput
2324
from prompt_toolkit.shortcuts import PromptSession
2425

25-
from pydantic_ai._cli import cli, cli_agent, handle_slash_command
26+
from pydantic_ai._cli import LAST_CONVERSATION_FILENAME, PYDANTIC_AI_HOME, cli, cli_agent, handle_slash_command
2627
from pydantic_ai.models.openai import OpenAIModel
2728

2829
pytestmark = pytest.mark.skipif(not imports_successful(), reason='install cli extras to run cli tests')
@@ -56,6 +57,16 @@ def _create_test_module(**namespace: Any) -> None:
5657
del sys.modules['test_module']
5758

5859

60+
@pytest.fixture
61+
def emtpy_last_conversation_path():
62+
path = PYDANTIC_AI_HOME / LAST_CONVERSATION_FILENAME
63+
64+
if path.exists():
65+
path.unlink()
66+
67+
return path
68+
69+
5970
def test_agent_flag(
6071
capfd: CaptureFixture[str],
6172
mocker: MockerFixture,
@@ -163,6 +174,51 @@ def test_cli_prompt(capfd: CaptureFixture[str], env: TestEnv):
163174
assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# result', '', 'py', 'x = 1', '/py'])
164175

165176

177+
@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']])
178+
def test_cli_continue_last_conversation(
179+
args: list[str],
180+
capfd: CaptureFixture[str],
181+
env: TestEnv,
182+
emtpy_last_conversation_path: Path,
183+
):
184+
env.set('OPENAI_API_KEY', 'test')
185+
with cli_agent.override(model=TestModel(custom_output_text='# world')):
186+
assert cli(args) == 0
187+
assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world'])
188+
assert emtpy_last_conversation_path.exists()
189+
content = emtpy_last_conversation_path.read_text()
190+
assert content
191+
192+
assert cli(args) == 0
193+
assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world'])
194+
assert emtpy_last_conversation_path.exists()
195+
# verity that new content is appended to the file
196+
assert len(emtpy_last_conversation_path.read_text()) > len(content)
197+
198+
199+
@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']])
200+
def test_cli_continue_last_conversation_corrupted_file(
201+
args: list[str],
202+
capfd: CaptureFixture[str],
203+
env: TestEnv,
204+
emtpy_last_conversation_path: Path,
205+
):
206+
env.set('OPENAI_API_KEY', 'test')
207+
emtpy_last_conversation_path.write_text('not a valid json')
208+
with cli_agent.override(model=TestModel(custom_output_text='# world')):
209+
assert cli(args) == 0
210+
assert capfd.readouterr().out.splitlines() == snapshot(
211+
[
212+
IsStr(),
213+
'Error loading last conversation, it is corrupted or invalid. Starting a new ',
214+
'conversation.',
215+
'# world',
216+
]
217+
)
218+
assert emtpy_last_conversation_path.exists()
219+
assert emtpy_last_conversation_path.read_text()
220+
221+
166222
def test_chat(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
167223
env.set('OPENAI_API_KEY', 'test')
168224
with create_pipe_input() as inp:

0 commit comments

Comments
 (0)