Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion clai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Either way, running `clai` will start an interactive session where you can chat
## Help

```
usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [--no-stream] [--version] [prompt]
usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [-c] [--no-stream] [--version] [prompt]

Pydantic AI CLI v...

Expand All @@ -76,6 +76,7 @@ options:
-l, --list-models List all available models and exit
-t [CODE_THEME], --code-theme [CODE_THEME]
Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.
-c, --continue Continue last conversation, if any, instead of starting a new one.
--no-stream Disable streaming from the model
--version Show version and exit
```
51 changes: 45 additions & 6 deletions pydantic_ai_slim/pydantic_ai/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from pathlib import Path
from typing import Any, cast

from pydantic import ValidationError
from typing_inspection.introspection import get_literal_values

from . import __version__
from ._run_context import AgentDepsT
from .agent import AbstractAgent, Agent
from .exceptions import UserError
from .messages import ModelMessage, ModelResponse
from .messages import ModelMessage, ModelMessagesTypeAdapter, ModelResponse
from .models import KnownModelName, infer_model
from .output import OutputDataT

Expand Down Expand Up @@ -54,6 +55,7 @@
"""

PROMPT_HISTORY_FILENAME = 'prompt-history.txt'
LAST_CONVERSATION_FILENAME = 'last-conversation.json'


class SimpleCodeBlock(CodeBlock):
Expand Down Expand Up @@ -148,6 +150,13 @@ def cli( # noqa: C901
help='Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.',
default='dark',
)
parser.add_argument(
'-c',
'--continue',
dest='continue_',
action='store_true',
help='Continue last conversation, if any, instead of starting a new one.',
)
parser.add_argument('--no-stream', action='store_true', help='Disable streaming from the model')
parser.add_argument('--version', action='store_true', help='Show version and exit')

Expand Down Expand Up @@ -207,19 +216,42 @@ def cli( # noqa: C901
else:
code_theme = args.code_theme # pragma: no cover

try:
history = load_last_conversation() if args.continue_ else None
except ValidationError:
console.print(
'[red]Error loading last conversation, it is corrupted or invalid.\nStarting a new conversation.[/red]'
)
history = None

if prompt := cast(str, args.prompt):
try:
asyncio.run(ask_agent(agent, prompt, stream, console, code_theme))
asyncio.run(ask_agent(agent, prompt, stream, console, code_theme, messages=history))
except KeyboardInterrupt:
pass
return 0

try:
return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name))
return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name, message_history=history))
except KeyboardInterrupt: # pragma: no cover
return 0


def store_last_conversation(messages: list[ModelMessage], config_dir: Path | None = None) -> None:
last_conversation_path = (config_dir or PYDANTIC_AI_HOME) / LAST_CONVERSATION_FILENAME
last_conversation_path.parent.mkdir(parents=True, exist_ok=True)
last_conversation_path.write_bytes(ModelMessagesTypeAdapter.dump_json(messages))


def load_last_conversation(config_dir: Path | None = None) -> list[ModelMessage] | None:
last_conversation_path = (config_dir or PYDANTIC_AI_HOME) / LAST_CONVERSATION_FILENAME

if not last_conversation_path.exists():
return None

return ModelMessagesTypeAdapter.validate_json(last_conversation_path.read_bytes())


async def run_chat(
stream: bool,
agent: AbstractAgent[AgentDepsT, OutputDataT],
Expand Down Expand Up @@ -255,7 +287,7 @@ async def run_chat(
return exit_value
else:
try:
messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages)
messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages, config_dir)
except CancelledError: # pragma: no cover
console.print('[dim]Interrupted[/dim]')
except Exception as e: # pragma: no cover
Expand All @@ -273,6 +305,7 @@ async def ask_agent(
code_theme: str,
deps: AgentDepsT = None,
messages: Sequence[ModelMessage] | None = None,
config_dir: Path | None = None,
) -> list[ModelMessage]:
status = Status('[dim]Working on it…[/dim]', console=console)

Expand All @@ -281,7 +314,10 @@ async def ask_agent(
result = await agent.run(prompt, message_history=messages, deps=deps)
content = str(result.output)
console.print(Markdown(content, code_theme=code_theme))
return result.all_messages()
result_messages = result.all_messages()
store_last_conversation(result_messages, config_dir)

return result_messages

with status, ExitStack() as stack:
async with agent.iter(prompt, message_history=messages, deps=deps) as agent_run:
Expand All @@ -296,7 +332,10 @@ async def ask_agent(
live.update(Markdown(str(content), code_theme=code_theme))

assert agent_run.result is not None
return agent_run.result.all_messages()
result_messages = agent_run.result.all_messages()
store_last_conversation(result_messages, config_dir)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we restructure this a bit to not repeat these 2 lines?

Copy link
Author

@uriyyo uriyyo Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, tried to fix it


return result_messages


class CustomAutoSuggest(AutoSuggestFromHistory):
Expand Down
84 changes: 80 additions & 4 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import types
from collections.abc import Callable
from io import StringIO
from pathlib import Path
from typing import Any

import pytest
Expand All @@ -22,7 +23,7 @@
from prompt_toolkit.output import DummyOutput
from prompt_toolkit.shortcuts import PromptSession

from pydantic_ai._cli import cli, cli_agent, handle_slash_command
from pydantic_ai._cli import LAST_CONVERSATION_FILENAME, cli, cli_agent, handle_slash_command
from pydantic_ai.models.openai import OpenAIChatModel

pytestmark = pytest.mark.skipif(not imports_successful(), reason='install cli extras to run cli tests')
Expand Down Expand Up @@ -56,6 +57,15 @@ def _create_test_module(**namespace: Any) -> None:
del sys.modules['test_module']


@pytest.fixture
def empty_last_conversation_path(tmp_path: Path, mocker: MockerFixture) -> Path:
path = tmp_path / LAST_CONVERSATION_FILENAME
mocker.patch('pydantic_ai._cli.PYDANTIC_AI_HOME', tmp_path)
path.unlink(missing_ok=True)

return path


def test_agent_flag(
capfd: CaptureFixture[str],
mocker: MockerFixture,
Expand Down Expand Up @@ -164,6 +174,51 @@ def test_cli_prompt(capfd: CaptureFixture[str], env: TestEnv):
assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# result', '', 'py', 'x = 1', '/py'])


@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']])
def test_cli_continue_last_conversation(
args: list[str],
capfd: CaptureFixture[str],
env: TestEnv,
empty_last_conversation_path: Path,
):
env.set('OPENAI_API_KEY', 'test')
with cli_agent.override(model=TestModel(custom_output_text='# world')):
assert cli(args) == 0
assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world'])
assert empty_last_conversation_path.exists()
content = empty_last_conversation_path.read_text()
assert content

assert cli(args) == 0
assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world'])
assert empty_last_conversation_path.exists()
# verify that new content is appended to the file
assert len(empty_last_conversation_path.read_text()) > len(content)


@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']])
def test_cli_continue_last_conversation_corrupted_file(
args: list[str],
capfd: CaptureFixture[str],
env: TestEnv,
empty_last_conversation_path: Path,
):
env.set('OPENAI_API_KEY', 'test')
empty_last_conversation_path.write_text('not a valid json')
with cli_agent.override(model=TestModel(custom_output_text='# world')):
assert cli(args) == 0
assert capfd.readouterr().out.splitlines() == snapshot(
[
IsStr(),
'Error loading last conversation, it is corrupted or invalid.',
'Starting a new conversation.',
'# world',
]
)
assert empty_last_conversation_path.exists()
assert empty_last_conversation_path.read_text()


def test_chat(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
env.set('OPENAI_API_KEY', 'test')

Expand Down Expand Up @@ -264,21 +319,42 @@ def test_code_theme_unset(mocker: MockerFixture, env: TestEnv):
env.set('OPENAI_API_KEY', 'test')
mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')
cli([])
mock_run_chat.assert_awaited_once_with(True, IsInstance(Agent), IsInstance(Console), 'monokai', 'pai')
mock_run_chat.assert_awaited_once_with(
True,
IsInstance(Agent),
IsInstance(Console),
'monokai',
'pai',
message_history=None,
)


def test_code_theme_light(mocker: MockerFixture, env: TestEnv):
env.set('OPENAI_API_KEY', 'test')
mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')
cli(['--code-theme=light'])
mock_run_chat.assert_awaited_once_with(True, IsInstance(Agent), IsInstance(Console), 'default', 'pai')
mock_run_chat.assert_awaited_once_with(
True,
IsInstance(Agent),
IsInstance(Console),
'default',
'pai',
message_history=None,
)


def test_code_theme_dark(mocker: MockerFixture, env: TestEnv):
env.set('OPENAI_API_KEY', 'test')
mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')
cli(['--code-theme=dark'])
mock_run_chat.assert_awaited_once_with(True, IsInstance(Agent), IsInstance(Console), 'monokai', 'pai')
mock_run_chat.assert_awaited_once_with(
True,
IsInstance(Agent),
IsInstance(Console),
'monokai',
'pai',
message_history=None,
)


def test_agent_to_cli_sync(mocker: MockerFixture, env: TestEnv):
Expand Down