diff --git a/clai/README.md b/clai/README.md index da0ed9bc27..4f7ebc438f 100644 --- a/clai/README.md +++ b/clai/README.md @@ -54,7 +54,7 @@ Either way, running `clai` will start an interactive session where you can chat ## Help ``` -usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [--no-stream] [--version] [prompt] +usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [-c [CONTINUE_]] [--store STORE] [--no-stream] [--version] [prompt] Pydantic AI CLI v... @@ -76,6 +76,9 @@ options: -l, --list-models List all available models and exit -t [CODE_THEME], --code-theme [CODE_THEME] Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals. + -c [CONTINUE_], --continue [CONTINUE_] + Continue last conversation, if any, instead of starting a new one. + --store STORE Store the last conversation to the specified path instead of the default location. --no-stream Disable streaming from the model --version Show version and exit ``` diff --git a/pydantic_ai_slim/pydantic_ai/_cli.py b/pydantic_ai_slim/pydantic_ai/_cli.py index 4e4889bdfc..3cc8d8b8fd 100644 --- a/pydantic_ai_slim/pydantic_ai/_cli.py +++ b/pydantic_ai_slim/pydantic_ai/_cli.py @@ -12,13 +12,14 @@ from pathlib import Path from typing import Any, cast +from pydantic import ValidationError from typing_inspection.introspection import get_literal_values from . import __version__ from ._run_context import AgentDepsT from .agent import AbstractAgent, Agent from .exceptions import UserError -from .messages import ModelMessage, ModelResponse +from .messages import ModelMessage, ModelMessagesTypeAdapter, ModelResponse from .models import KnownModelName, infer_model from .output import OutputDataT @@ -54,6 +55,7 @@ """ PROMPT_HISTORY_FILENAME = 'prompt-history.txt' +LAST_CONVERSATION_FILENAME = 'last-conversation.json' class SimpleCodeBlock(CodeBlock): @@ -148,6 +150,20 @@ def cli( # noqa: C901 help='Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.', default='dark', ) + parser.add_argument( + '-c', + '--continue', + nargs='?', + dest='continue_', + const=str(PYDANTIC_AI_HOME / LAST_CONVERSATION_FILENAME), + default=None, + help='Continue last conversation, if any, instead of starting a new one.', + ) + parser.add_argument( + '--store', + help='Store the last conversation to the specified path instead of the default location.', + default=None, + ) parser.add_argument('--no-stream', action='store_true', help='Disable streaming from the model') parser.add_argument('--version', action='store_true', help='Show version and exit') @@ -207,19 +223,49 @@ def cli( # noqa: C901 else: code_theme = args.code_theme # pragma: no cover + load_path: Path | None = None + if args.continue_: + load_path = Path(args.continue_) + + store_path: Path = PYDANTIC_AI_HOME / LAST_CONVERSATION_FILENAME + if args.store: + store_path = Path(args.store) + + try: + history = load_conversation(load_path) if load_path else None + except ValidationError: + console.print( + '[red]Error loading conversation, it is corrupted or invalid.\nStarting a new conversation.[/red]' + ) + history = None + if prompt := cast(str, args.prompt): try: - asyncio.run(ask_agent(agent, prompt, stream, console, code_theme)) + asyncio.run(ask_agent(agent, prompt, stream, console, code_theme, messages=history, store_path=store_path)) except KeyboardInterrupt: pass return 0 try: - return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name)) + return asyncio.run( + run_chat(stream, agent, console, code_theme, prog_name, message_history=history, store_path=store_path) + ) except KeyboardInterrupt: # pragma: no cover return 0 +def store_conversation(messages: list[ModelMessage], store_path: Path) -> None: + store_path.parent.mkdir(parents=True, exist_ok=True) + store_path.write_bytes(ModelMessagesTypeAdapter.dump_json(messages)) + + +def load_conversation(load_path: Path) -> list[ModelMessage] | None: + if not load_path.exists(): + return None + + return ModelMessagesTypeAdapter.validate_json(load_path.read_bytes()) + + async def run_chat( stream: bool, agent: AbstractAgent[AgentDepsT, OutputDataT], @@ -229,6 +275,7 @@ async def run_chat( config_dir: Path | None = None, deps: AgentDepsT = None, message_history: Sequence[ModelMessage] | None = None, + store_path: Path | None = None, ) -> int: prompt_history_path = (config_dir or PYDANTIC_AI_HOME) / PROMPT_HISTORY_FILENAME prompt_history_path.parent.mkdir(parents=True, exist_ok=True) @@ -255,7 +302,7 @@ async def run_chat( return exit_value else: try: - messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages) + messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages, store_path) except CancelledError: # pragma: no cover console.print('[dim]Interrupted[/dim]') except Exception as e: # pragma: no cover @@ -273,6 +320,7 @@ async def ask_agent( code_theme: str, deps: AgentDepsT = None, messages: Sequence[ModelMessage] | None = None, + store_path: Path | None = None, ) -> list[ModelMessage]: status = Status('[dim]Working on it…[/dim]', console=console) @@ -281,22 +329,28 @@ async def ask_agent( result = await agent.run(prompt, message_history=messages, deps=deps) content = str(result.output) console.print(Markdown(content, code_theme=code_theme)) - return result.all_messages() - - with status, ExitStack() as stack: - async with agent.iter(prompt, message_history=messages, deps=deps) as agent_run: - live = Live('', refresh_per_second=15, console=console, vertical_overflow='ellipsis') - async for node in agent_run: - if Agent.is_model_request_node(node): - async with node.stream(agent_run.ctx) as handle_stream: - status.stop() # stopping multiple times is idempotent - stack.enter_context(live) # entering multiple times is idempotent - - async for content in handle_stream.stream_output(debounce_by=None): - live.update(Markdown(str(content), code_theme=code_theme)) - - assert agent_run.result is not None - return agent_run.result.all_messages() + else: + with status, ExitStack() as stack: + async with agent.iter(prompt, message_history=messages, deps=deps) as agent_run: + live = Live('', refresh_per_second=15, console=console, vertical_overflow='ellipsis') + async for node in agent_run: + if Agent.is_model_request_node(node): + async with node.stream(agent_run.ctx) as handle_stream: + status.stop() # stopping multiple times is idempotent + stack.enter_context(live) # entering multiple times is idempotent + + async for content in handle_stream.stream_output(debounce_by=None): + live.update(Markdown(str(content), code_theme=code_theme)) + + assert agent_run.result is not None + result = agent_run.result + + result_messages = result.all_messages() + + if store_path: + store_conversation(result_messages, store_path) + + return result_messages class CustomAutoSuggest(AutoSuggestFromHistory): diff --git a/tests/test_cli.py b/tests/test_cli.py index e95ff09141..7fc393873d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,6 +2,7 @@ import types from collections.abc import Callable from io import StringIO +from pathlib import Path from typing import Any import pytest @@ -22,7 +23,7 @@ from prompt_toolkit.output import DummyOutput from prompt_toolkit.shortcuts import PromptSession - from pydantic_ai._cli import cli, cli_agent, handle_slash_command + from pydantic_ai._cli import LAST_CONVERSATION_FILENAME, cli, cli_agent, handle_slash_command from pydantic_ai.models.openai import OpenAIChatModel pytestmark = pytest.mark.skipif(not imports_successful(), reason='install cli extras to run cli tests') @@ -56,6 +57,15 @@ def _create_test_module(**namespace: Any) -> None: del sys.modules['test_module'] +@pytest.fixture +def empty_last_conversation_path(tmp_path: Path, mocker: MockerFixture) -> Path: + path = tmp_path / LAST_CONVERSATION_FILENAME + mocker.patch('pydantic_ai._cli.PYDANTIC_AI_HOME', tmp_path) + path.unlink(missing_ok=True) + + return path + + def test_agent_flag( capfd: CaptureFixture[str], mocker: MockerFixture, @@ -164,6 +174,51 @@ def test_cli_prompt(capfd: CaptureFixture[str], env: TestEnv): assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# result', '', 'py', 'x = 1', '/py']) +@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']]) +def test_cli_continue_last_conversation( + args: list[str], + capfd: CaptureFixture[str], + env: TestEnv, + empty_last_conversation_path: Path, +): + env.set('OPENAI_API_KEY', 'test') + with cli_agent.override(model=TestModel(custom_output_text='# world')): + assert cli(args) == 0 + assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world']) + assert empty_last_conversation_path.exists() + content = empty_last_conversation_path.read_text() + assert content + + assert cli(args) == 0 + assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world']) + assert empty_last_conversation_path.exists() + # verify that new content is appended to the file + assert len(empty_last_conversation_path.read_text()) > len(content) + + +@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']]) +def test_cli_continue_last_conversation_corrupted_file( + args: list[str], + capfd: CaptureFixture[str], + env: TestEnv, + empty_last_conversation_path: Path, +): + env.set('OPENAI_API_KEY', 'test') + empty_last_conversation_path.write_text('not a valid json') + with cli_agent.override(model=TestModel(custom_output_text='# world')): + assert cli(args) == 0 + assert capfd.readouterr().out.splitlines() == snapshot( + [ + IsStr(), + 'Error loading last conversation, it is corrupted or invalid.', + 'Starting a new conversation.', + '# world', + ] + ) + assert empty_last_conversation_path.exists() + assert empty_last_conversation_path.read_text() + + def test_chat(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv): env.set('OPENAI_API_KEY', 'test') @@ -264,21 +319,42 @@ def test_code_theme_unset(mocker: MockerFixture, env: TestEnv): env.set('OPENAI_API_KEY', 'test') mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat') cli([]) - mock_run_chat.assert_awaited_once_with(True, IsInstance(Agent), IsInstance(Console), 'monokai', 'pai') + mock_run_chat.assert_awaited_once_with( + True, + IsInstance(Agent), + IsInstance(Console), + 'monokai', + 'pai', + message_history=None, + ) def test_code_theme_light(mocker: MockerFixture, env: TestEnv): env.set('OPENAI_API_KEY', 'test') mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat') cli(['--code-theme=light']) - mock_run_chat.assert_awaited_once_with(True, IsInstance(Agent), IsInstance(Console), 'default', 'pai') + mock_run_chat.assert_awaited_once_with( + True, + IsInstance(Agent), + IsInstance(Console), + 'default', + 'pai', + message_history=None, + ) def test_code_theme_dark(mocker: MockerFixture, env: TestEnv): env.set('OPENAI_API_KEY', 'test') mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat') cli(['--code-theme=dark']) - mock_run_chat.assert_awaited_once_with(True, IsInstance(Agent), IsInstance(Console), 'monokai', 'pai') + mock_run_chat.assert_awaited_once_with( + True, + IsInstance(Agent), + IsInstance(Console), + 'monokai', + 'pai', + message_history=None, + ) def test_agent_to_cli_sync(mocker: MockerFixture, env: TestEnv):