Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion clai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Either way, running `clai` will start an interactive session where you can chat
## Help

```
usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [--no-stream] [--version] [prompt]
usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [-c [CONTINUE_]] [--store STORE] [--no-stream] [--version] [prompt]

Pydantic AI CLI v...

Expand All @@ -76,6 +76,9 @@ options:
-l, --list-models List all available models and exit
-t [CODE_THEME], --code-theme [CODE_THEME]
Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.
-c [CONTINUE_], --continue [CONTINUE_]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the readme, this should really say PATH rather than CONTINUE_

Continue last conversation, if any, instead of starting a new one.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs updating now that we support an optional path

--store STORE Store the last conversation to the specified path instead of the default location.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PATH would be better than STORE here as well

--no-stream Disable streaming from the model
--version Show version and exit
```
94 changes: 74 additions & 20 deletions pydantic_ai_slim/pydantic_ai/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from pathlib import Path
from typing import Any, cast

from pydantic import ValidationError
from typing_inspection.introspection import get_literal_values

from . import __version__
from ._run_context import AgentDepsT
from .agent import AbstractAgent, Agent
from .exceptions import UserError
from .messages import ModelMessage, ModelResponse
from .messages import ModelMessage, ModelMessagesTypeAdapter, ModelResponse
from .models import KnownModelName, infer_model
from .output import OutputDataT

Expand Down Expand Up @@ -54,6 +55,7 @@
"""

PROMPT_HISTORY_FILENAME = 'prompt-history.txt'
LAST_CONVERSATION_FILENAME = 'last-conversation.json'


class SimpleCodeBlock(CodeBlock):
Expand Down Expand Up @@ -148,6 +150,20 @@ def cli( # noqa: C901
help='Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.',
default='dark',
)
parser.add_argument(
'-c',
'--continue',
nargs='?',
dest='continue_',
const=str(PYDANTIC_AI_HOME / LAST_CONVERSATION_FILENAME),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm so we're not respecting the config_dir anymore? Could we fix that

default=None,
help='Continue last conversation, if any, instead of starting a new one.',
)
parser.add_argument(
'--store',
help='Store the last conversation to the specified path instead of the default location.',
default=None,
)
parser.add_argument('--no-stream', action='store_true', help='Disable streaming from the model')
parser.add_argument('--version', action='store_true', help='Show version and exit')

Expand Down Expand Up @@ -207,19 +223,49 @@ def cli( # noqa: C901
else:
code_theme = args.code_theme # pragma: no cover

load_path: Path | None = None
if args.continue_:
load_path = Path(args.continue_)

store_path: Path = PYDANTIC_AI_HOME / LAST_CONVERSATION_FILENAME
if args.store:
store_path = Path(args.store)

try:
history = load_conversation(load_path) if load_path else None
except ValidationError:
console.print(
'[red]Error loading conversation, it is corrupted or invalid.\nStarting a new conversation.[/red]'
)
history = None

if prompt := cast(str, args.prompt):
try:
asyncio.run(ask_agent(agent, prompt, stream, console, code_theme))
asyncio.run(ask_agent(agent, prompt, stream, console, code_theme, messages=history, store_path=store_path))
except KeyboardInterrupt:
pass
return 0

try:
return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name))
return asyncio.run(
run_chat(stream, agent, console, code_theme, prog_name, message_history=history, store_path=store_path)
)
except KeyboardInterrupt: # pragma: no cover
return 0


def store_conversation(messages: list[ModelMessage], store_path: Path) -> None:
    """Serialize `messages` to JSON and store them at `store_path`.

    The write goes to a temporary sibling file which is then atomically
    renamed over the target, so an interrupted run cannot leave a
    half-written (corrupted) conversation file behind — corruption is
    exactly what `load_conversation` callers have to recover from.
    """
    store_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = store_path.with_name(store_path.name + '.tmp')
    tmp_path.write_bytes(ModelMessagesTypeAdapter.dump_json(messages))
    tmp_path.replace(store_path)  # atomic on POSIX; best-effort replace on Windows


def load_conversation(load_path: Path) -> list[ModelMessage] | None:
    """Read a previously stored conversation from `load_path`.

    Returns `None` when no stored conversation exists. May raise
    `pydantic.ValidationError` if the file content is not a valid
    serialized message list; callers handle that case.
    """
    if load_path.exists():
        return ModelMessagesTypeAdapter.validate_json(load_path.read_bytes())
    return None


async def run_chat(
stream: bool,
agent: AbstractAgent[AgentDepsT, OutputDataT],
Expand All @@ -229,6 +275,7 @@ async def run_chat(
config_dir: Path | None = None,
deps: AgentDepsT = None,
message_history: Sequence[ModelMessage] | None = None,
store_path: Path | None = None,
) -> int:
prompt_history_path = (config_dir or PYDANTIC_AI_HOME) / PROMPT_HISTORY_FILENAME
prompt_history_path.parent.mkdir(parents=True, exist_ok=True)
Expand All @@ -255,7 +302,7 @@ async def run_chat(
return exit_value
else:
try:
messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages)
messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages, store_path)
except CancelledError: # pragma: no cover
console.print('[dim]Interrupted[/dim]')
except Exception as e: # pragma: no cover
Expand All @@ -273,6 +320,7 @@ async def ask_agent(
code_theme: str,
deps: AgentDepsT = None,
messages: Sequence[ModelMessage] | None = None,
store_path: Path | None = None,
) -> list[ModelMessage]:
status = Status('[dim]Working on it…[/dim]', console=console)

Expand All @@ -281,22 +329,28 @@ async def ask_agent(
result = await agent.run(prompt, message_history=messages, deps=deps)
content = str(result.output)
console.print(Markdown(content, code_theme=code_theme))
return result.all_messages()

with status, ExitStack() as stack:
async with agent.iter(prompt, message_history=messages, deps=deps) as agent_run:
live = Live('', refresh_per_second=15, console=console, vertical_overflow='ellipsis')
async for node in agent_run:
if Agent.is_model_request_node(node):
async with node.stream(agent_run.ctx) as handle_stream:
status.stop() # stopping multiple times is idempotent
stack.enter_context(live) # entering multiple times is idempotent

async for content in handle_stream.stream_output(debounce_by=None):
live.update(Markdown(str(content), code_theme=code_theme))

assert agent_run.result is not None
return agent_run.result.all_messages()
else:
with status, ExitStack() as stack:
async with agent.iter(prompt, message_history=messages, deps=deps) as agent_run:
live = Live('', refresh_per_second=15, console=console, vertical_overflow='ellipsis')
async for node in agent_run:
if Agent.is_model_request_node(node):
async with node.stream(agent_run.ctx) as handle_stream:
status.stop() # stopping multiple times is idempotent
stack.enter_context(live) # entering multiple times is idempotent

async for content in handle_stream.stream_output(debounce_by=None):
live.update(Markdown(str(content), code_theme=code_theme))

assert agent_run.result is not None
result = agent_run.result

result_messages = result.all_messages()

if store_path:
store_conversation(result_messages, store_path)

return result_messages


class CustomAutoSuggest(AutoSuggestFromHistory):
Expand Down
84 changes: 80 additions & 4 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import types
from collections.abc import Callable
from io import StringIO
from pathlib import Path
from typing import Any

import pytest
Expand All @@ -22,7 +23,7 @@
from prompt_toolkit.output import DummyOutput
from prompt_toolkit.shortcuts import PromptSession

from pydantic_ai._cli import cli, cli_agent, handle_slash_command
from pydantic_ai._cli import LAST_CONVERSATION_FILENAME, cli, cli_agent, handle_slash_command
from pydantic_ai.models.openai import OpenAIChatModel

pytestmark = pytest.mark.skipif(not imports_successful(), reason='install cli extras to run cli tests')
Expand Down Expand Up @@ -56,6 +57,15 @@ def _create_test_module(**namespace: Any) -> None:
del sys.modules['test_module']


@pytest.fixture
def empty_last_conversation_path(tmp_path: Path, mocker: MockerFixture) -> Path:
    """Redirect PYDANTIC_AI_HOME to a temp dir and return the (absent) last-conversation path."""
    mocker.patch('pydantic_ai._cli.PYDANTIC_AI_HOME', tmp_path)
    conversation_path = tmp_path / LAST_CONVERSATION_FILENAME
    # tmp_path is fresh per test, but unlink defensively so the fixture
    # always hands out a non-existent file.
    conversation_path.unlink(missing_ok=True)
    return conversation_path


def test_agent_flag(
capfd: CaptureFixture[str],
mocker: MockerFixture,
Expand Down Expand Up @@ -164,6 +174,51 @@ def test_cli_prompt(capfd: CaptureFixture[str], env: TestEnv):
assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# result', '', 'py', 'x = 1', '/py'])


@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']])
def test_cli_continue_last_conversation(
    args: list[str],
    capfd: CaptureFixture[str],
    env: TestEnv,
    empty_last_conversation_path: Path,
):
    """`-c`/`--continue` stores the conversation and resumes it on the next run."""
    env.set('OPENAI_API_KEY', 'test')
    with cli_agent.override(model=TestModel(custom_output_text='# world')):
        # First run: no stored conversation yet, so this starts a new one
        # and persists it to the default last-conversation path.
        assert cli(args) == 0
        assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world'])
        assert empty_last_conversation_path.exists()
        content = empty_last_conversation_path.read_text()
        assert content

        # Second run: continues from the stored history.
        assert cli(args) == 0
        assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world'])
        assert empty_last_conversation_path.exists()
        # The file is rewritten with the full conversation (old + new
        # messages), so its serialized form must have grown.
        assert len(empty_last_conversation_path.read_text()) > len(content)


@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']])
def test_cli_continue_last_conversation_corrupted_file(
    args: list[str],
    capfd: CaptureFixture[str],
    env: TestEnv,
    empty_last_conversation_path: Path,
):
    """A corrupted stored conversation is reported, ignored, and overwritten."""
    env.set('OPENAI_API_KEY', 'test')
    empty_last_conversation_path.write_text('not a valid json')
    with cli_agent.override(model=TestModel(custom_output_text='# world')):
        assert cli(args) == 0
        # The expected text must match the message printed by cli()
        # ('Error loading conversation, ...'); the embedded '\n' in that
        # message yields two separate output lines here.
        assert capfd.readouterr().out.splitlines() == snapshot(
            [
                IsStr(),
                'Error loading conversation, it is corrupted or invalid.',
                'Starting a new conversation.',
                '# world',
            ]
        )
        assert empty_last_conversation_path.exists()
        # The corrupted content has been replaced by the new conversation.
        assert empty_last_conversation_path.read_text()


def test_chat(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
env.set('OPENAI_API_KEY', 'test')

Expand Down Expand Up @@ -264,21 +319,42 @@ def test_code_theme_unset(mocker: MockerFixture, env: TestEnv):
env.set('OPENAI_API_KEY', 'test')
mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')
cli([])
mock_run_chat.assert_awaited_once_with(True, IsInstance(Agent), IsInstance(Console), 'monokai', 'pai')
mock_run_chat.assert_awaited_once_with(
True,
IsInstance(Agent),
IsInstance(Console),
'monokai',
'pai',
message_history=None,
)


def test_code_theme_light(mocker: MockerFixture, env: TestEnv):
    """`--code-theme=light` maps to the pygments 'default' theme."""
    env.set('OPENAI_API_KEY', 'test')
    mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')
    cli(['--code-theme=light'])
    # cli() now always forwards a store_path (defaulting to the
    # last-conversation file under PYDANTIC_AI_HOME); the exact-call
    # assertion must include it or assert_awaited_once_with cannot match.
    mock_run_chat.assert_awaited_once_with(
        True,
        IsInstance(Agent),
        IsInstance(Console),
        'default',
        'pai',
        message_history=None,
        store_path=IsInstance(Path),
    )


def test_code_theme_dark(mocker: MockerFixture, env: TestEnv):
    """`--code-theme=dark` maps to the pygments 'monokai' theme."""
    env.set('OPENAI_API_KEY', 'test')
    mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')
    cli(['--code-theme=dark'])
    # cli() now always forwards a store_path (defaulting to the
    # last-conversation file under PYDANTIC_AI_HOME); the exact-call
    # assertion must include it or assert_awaited_once_with cannot match.
    mock_run_chat.assert_awaited_once_with(
        True,
        IsInstance(Agent),
        IsInstance(Console),
        'monokai',
        'pai',
        message_history=None,
        store_path=IsInstance(Path),
    )


def test_agent_to_cli_sync(mocker: MockerFixture, env: TestEnv):
Expand Down
Loading