|
1 | 1 | import sys |
2 | 2 | import types |
3 | 3 | from io import StringIO |
| 4 | +from pathlib import Path |
4 | 5 | from typing import Any, Callable |
5 | 6 |
|
6 | 7 | import pytest |
|
22 | 23 | from prompt_toolkit.output import DummyOutput |
23 | 24 | from prompt_toolkit.shortcuts import PromptSession |
24 | 25 |
|
25 | | - from pydantic_ai._cli import cli, cli_agent, handle_slash_command |
| 26 | + from pydantic_ai._cli import LAST_CONVERSATION_FILENAME, PYDANTIC_AI_HOME, cli, cli_agent, handle_slash_command |
26 | 27 | from pydantic_ai.models.openai import OpenAIModel |
27 | 28 |
|
28 | 29 | pytestmark = pytest.mark.skipif(not imports_successful(), reason='install cli extras to run cli tests') |
@@ -56,6 +57,16 @@ def _create_test_module(**namespace: Any) -> None: |
56 | 57 | del sys.modules['test_module'] |
57 | 58 |
|
58 | 59 |
|
| 60 | +@pytest.fixture |
| 61 | +def emtpy_last_conversation_path(): |
| 62 | + path = PYDANTIC_AI_HOME / LAST_CONVERSATION_FILENAME |
| 63 | + |
| 64 | + if path.exists(): |
| 65 | + path.unlink() |
| 66 | + |
| 67 | + return path |
| 68 | + |
| 69 | + |
59 | 70 | def test_agent_flag( |
60 | 71 | capfd: CaptureFixture[str], |
61 | 72 | mocker: MockerFixture, |
@@ -163,6 +174,51 @@ def test_cli_prompt(capfd: CaptureFixture[str], env: TestEnv): |
163 | 174 | assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# result', '', 'py', 'x = 1', '/py']) |
164 | 175 |
|
165 | 176 |
|
| 177 | +@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']]) |
| 178 | +def test_cli_continue_last_conversation( |
| 179 | + args: list[str], |
| 180 | + capfd: CaptureFixture[str], |
| 181 | + env: TestEnv, |
| 182 | + emtpy_last_conversation_path: Path, |
| 183 | +): |
| 184 | + env.set('OPENAI_API_KEY', 'test') |
| 185 | + with cli_agent.override(model=TestModel(custom_output_text='# world')): |
| 186 | + assert cli(args) == 0 |
| 187 | + assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world']) |
| 188 | + assert emtpy_last_conversation_path.exists() |
| 189 | + content = emtpy_last_conversation_path.read_text() |
| 190 | + assert content |
| 191 | + |
| 192 | + assert cli(args) == 0 |
| 193 | + assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world']) |
| 194 | + assert emtpy_last_conversation_path.exists() |
| 195 | + # verity that new content is appended to the file |
| 196 | + assert len(emtpy_last_conversation_path.read_text()) > len(content) |
| 197 | + |
| 198 | + |
| 199 | +@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']]) |
| 200 | +def test_cli_continue_last_conversation_corrupted_file( |
| 201 | + args: list[str], |
| 202 | + capfd: CaptureFixture[str], |
| 203 | + env: TestEnv, |
| 204 | + emtpy_last_conversation_path: Path, |
| 205 | +): |
| 206 | + env.set('OPENAI_API_KEY', 'test') |
| 207 | + emtpy_last_conversation_path.write_text('not a valid json') |
| 208 | + with cli_agent.override(model=TestModel(custom_output_text='# world')): |
| 209 | + assert cli(args) == 0 |
| 210 | + assert capfd.readouterr().out.splitlines() == snapshot( |
| 211 | + [ |
| 212 | + IsStr(), |
| 213 | + 'Error loading last conversation, it is corrupted or invalid. Starting a new ', |
| 214 | + 'conversation.', |
| 215 | + '# world', |
| 216 | + ] |
| 217 | + ) |
| 218 | + assert emtpy_last_conversation_path.exists() |
| 219 | + assert emtpy_last_conversation_path.read_text() |
| 220 | + |
| 221 | + |
166 | 222 | def test_chat(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv): |
167 | 223 | env.set('OPENAI_API_KEY', 'test') |
168 | 224 | with create_pipe_input() as inp: |
|
0 commit comments