Skip to content
3 changes: 2 additions & 1 deletion clai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Either way, running `clai` will start an interactive session where you can chat
## Help

```
usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [--no-stream] [--version] [prompt]
usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [-c] [--no-stream] [--version] [prompt]

Pydantic AI CLI v...

Expand All @@ -76,6 +76,7 @@ options:
-l, --list-models List all available models and exit
-t [CODE_THEME], --code-theme [CODE_THEME]
Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.
-c, --continue Continue last conversation, if any, instead of starting a new one.
--no-stream Disable streaming from the model
--version Show version and exit
```
51 changes: 45 additions & 6 deletions pydantic_ai_slim/pydantic_ai/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from pathlib import Path
from typing import Any, cast

from pydantic import ValidationError
from typing_inspection.introspection import get_literal_values

from . import __version__
from ._run_context import AgentDepsT
from .agent import AbstractAgent, Agent
from .exceptions import UserError
from .messages import ModelMessage, ModelResponse
from .messages import ModelMessage, ModelMessagesTypeAdapter, ModelResponse
from .models import KnownModelName, infer_model
from .output import OutputDataT

Expand Down Expand Up @@ -54,6 +55,7 @@
"""

PROMPT_HISTORY_FILENAME = 'prompt-history.txt'
LAST_CONVERSATION_FILENAME = 'last-conversation.json'


class SimpleCodeBlock(CodeBlock):
Expand Down Expand Up @@ -148,6 +150,13 @@ def cli( # noqa: C901
help='Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.',
default='dark',
)
parser.add_argument(
'-c',
'--continue',
dest='continue_',
action='store_true',
help='Continue last conversation, if any, instead of starting a new one.',
)
parser.add_argument('--no-stream', action='store_true', help='Disable streaming from the model')
parser.add_argument('--version', action='store_true', help='Show version and exit')

Expand Down Expand Up @@ -207,19 +216,42 @@ def cli( # noqa: C901
else:
code_theme = args.code_theme # pragma: no cover

try:
history = load_last_conversation() if args.continue_ else None
except ValidationError:
console.print(
'[red]Error loading last conversation, it is corrupted or invalid.\nStarting a new conversation.[/red]'
)
history = None

if prompt := cast(str, args.prompt):
try:
asyncio.run(ask_agent(agent, prompt, stream, console, code_theme))
asyncio.run(ask_agent(agent, prompt, stream, console, code_theme, messages=history))
except KeyboardInterrupt:
pass
return 0

try:
return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name))
return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name, message_history=history))
except KeyboardInterrupt: # pragma: no cover
return 0


def store_last_conversation(messages: list[ModelMessage], config_dir: Path | None = None) -> None:
    """Persist `messages` as JSON so a later `--continue` run can resume the conversation."""
    target = (config_dir or PYDANTIC_AI_HOME) / LAST_CONVERSATION_FILENAME
    # Make sure the config directory exists before writing the dump.
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(ModelMessagesTypeAdapter.dump_json(messages))


def load_last_conversation(config_dir: Path | None = None) -> list[ModelMessage] | None:
    """Return the previously stored conversation, or `None` if no dump file exists.

    Propagates `pydantic.ValidationError` when the stored file cannot be parsed.
    """
    dump_path = (config_dir or PYDANTIC_AI_HOME) / LAST_CONVERSATION_FILENAME
    if dump_path.exists():
        return ModelMessagesTypeAdapter.validate_json(dump_path.read_bytes())
    return None


async def run_chat(
stream: bool,
agent: AbstractAgent[AgentDepsT, OutputDataT],
Expand Down Expand Up @@ -255,7 +287,7 @@ async def run_chat(
return exit_value
else:
try:
messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages)
messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages, config_dir)
except CancelledError: # pragma: no cover
console.print('[dim]Interrupted[/dim]')
except Exception as e: # pragma: no cover
Expand All @@ -273,6 +305,7 @@ async def ask_agent(
code_theme: str,
deps: AgentDepsT = None,
messages: Sequence[ModelMessage] | None = None,
config_dir: Path | None = None,
) -> list[ModelMessage]:
status = Status('[dim]Working on it…[/dim]', console=console)

Expand All @@ -281,7 +314,10 @@ async def ask_agent(
result = await agent.run(prompt, message_history=messages, deps=deps)
content = str(result.output)
console.print(Markdown(content, code_theme=code_theme))
return result.all_messages()
result_messages = result.all_messages()
store_last_conversation(result_messages, config_dir)

return result_messages

with status, ExitStack() as stack:
async with agent.iter(prompt, message_history=messages, deps=deps) as agent_run:
Expand All @@ -296,7 +332,10 @@ async def ask_agent(
live.update(Markdown(str(content), code_theme=code_theme))

assert agent_run.result is not None
return agent_run.result.all_messages()
result_messages = agent_run.result.all_messages()
store_last_conversation(result_messages, config_dir)

return result_messages


class CustomAutoSuggest(AutoSuggestFromHistory):
Expand Down
86 changes: 82 additions & 4 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import types
from collections.abc import Callable
from io import StringIO
from pathlib import Path
from typing import Any

import pytest
Expand All @@ -22,7 +23,7 @@
from prompt_toolkit.output import DummyOutput
from prompt_toolkit.shortcuts import PromptSession

from pydantic_ai._cli import cli, cli_agent, handle_slash_command
from pydantic_ai._cli import LAST_CONVERSATION_FILENAME, cli, cli_agent, handle_slash_command
from pydantic_ai.models.openai import OpenAIChatModel

pytestmark = pytest.mark.skipif(not imports_successful(), reason='install cli extras to run cli tests')
Expand Down Expand Up @@ -56,6 +57,17 @@ def _create_test_module(**namespace: Any) -> None:
del sys.modules['test_module']


@pytest.fixture
def empty_last_conversation_path(tmp_path: Path, mocker: MockerFixture) -> Path:
    """Redirect the CLI's config home to `tmp_path` and ensure no last-conversation file exists yet."""
    mocker.patch('pydantic_ai._cli.PYDANTIC_AI_HOME', tmp_path)

    path = tmp_path / LAST_CONVERSATION_FILENAME
    # missing_ok replaces the racy exists()/unlink() pair (Python 3.8+).
    path.unlink(missing_ok=True)

    return path


def test_agent_flag(
capfd: CaptureFixture[str],
mocker: MockerFixture,
Expand Down Expand Up @@ -164,6 +176,51 @@ def test_cli_prompt(capfd: CaptureFixture[str], env: TestEnv):
assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# result', '', 'py', 'x = 1', '/py'])


@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']])
def test_cli_continue_last_conversation(
    args: list[str],
    capfd: CaptureFixture[str],
    env: TestEnv,
    empty_last_conversation_path: Path,
):
    """Running with -c/--continue stores the conversation and resumes it on the next run."""
    env.set('OPENAI_API_KEY', 'test')
    with cli_agent.override(model=TestModel(custom_output_text='# world')):
        assert cli(args) == 0
        assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world'])
        assert empty_last_conversation_path.exists()
        first_dump = empty_last_conversation_path.read_text()
        assert first_dump

        assert cli(args) == 0
        assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world'])
        assert empty_last_conversation_path.exists()
        # The resumed run rewrites the file with the longer message history.
        assert len(empty_last_conversation_path.read_text()) > len(first_dump)


@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']])
def test_cli_continue_last_conversation_corrupted_file(
    args: list[str],
    capfd: CaptureFixture[str],
    env: TestEnv,
    empty_last_conversation_path: Path,
):
    """A corrupted last-conversation file is reported and a fresh conversation is started."""
    empty_last_conversation_path.write_text('not a valid json')
    env.set('OPENAI_API_KEY', 'test')

    with cli_agent.override(model=TestModel(custom_output_text='# world')):
        assert cli(args) == 0
        output_lines = capfd.readouterr().out.splitlines()
        assert output_lines == snapshot(
            [
                IsStr(),
                'Error loading last conversation, it is corrupted or invalid.',
                'Starting a new conversation.',
                '# world',
            ]
        )
        # The new conversation's dump replaced the corrupted content.
        assert empty_last_conversation_path.exists()
        assert empty_last_conversation_path.read_text()


def test_chat(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
env.set('OPENAI_API_KEY', 'test')

Expand Down Expand Up @@ -264,21 +321,42 @@ def test_code_theme_unset(mocker: MockerFixture, env: TestEnv):
env.set('OPENAI_API_KEY', 'test')
mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')
cli([])
mock_run_chat.assert_awaited_once_with(True, IsInstance(Agent), IsInstance(Console), 'monokai', 'pai')
mock_run_chat.assert_awaited_once_with(
True,
IsInstance(Agent),
IsInstance(Console),
'monokai',
'pai',
message_history=None,
)


def test_code_theme_light(mocker: MockerFixture, env: TestEnv):
    """`--code-theme=light` is translated to pygments' 'default' theme."""
    env.set('OPENAI_API_KEY', 'test')
    mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')

    cli(['--code-theme=light'])

    expected_positional = (True, IsInstance(Agent), IsInstance(Console), 'default', 'pai')
    mock_run_chat.assert_awaited_once_with(*expected_positional, message_history=None)


def test_code_theme_dark(mocker: MockerFixture, env: TestEnv):
    """`--code-theme=dark` is translated to pygments' 'monokai' theme."""
    env.set('OPENAI_API_KEY', 'test')
    mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')

    cli(['--code-theme=dark'])

    expected_positional = (True, IsInstance(Agent), IsInstance(Console), 'monokai', 'pai')
    mock_run_chat.assert_awaited_once_with(*expected_positional, message_history=None)


def test_agent_to_cli_sync(mocker: MockerFixture, env: TestEnv):
Expand Down
Loading