-
Notifications
You must be signed in to change notification settings - Fork 1.4k
clai: Add ability to continue last conversation #2257
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
bf5d828
1278090
18ff146
f37042a
b4fa954
cf3b691
c487d58
c680761
526069f
a08b5ed
8ae1e62
1e40988
6dbe144
6db9293
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -54,7 +54,7 @@ Either way, running `clai` will start an interactive session where you can chat | |
| ## Help | ||
|
|
||
| ``` | ||
| usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [--no-stream] [--version] [prompt] | ||
| usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [-c [CONTINUE_]] [--store STORE] [--no-stream] [--version] [prompt] | ||
|
|
||
| Pydantic AI CLI v... | ||
|
|
||
|
|
@@ -76,6 +76,9 @@ options: | |
| -l, --list-models List all available models and exit | ||
| -t [CODE_THEME], --code-theme [CODE_THEME] | ||
| Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals. | ||
| -c [CONTINUE_], --continue [CONTINUE_] | ||
| Continue last conversation, if any, instead of starting a new one. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Needs updating now that we support an optional path. |
||
| --store STORE Store the last conversation to the specified path instead of the default location. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| --no-stream Disable streaming from the model | ||
| --version Show version and exit | ||
| ``` | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,13 +12,14 @@ | |
| from pathlib import Path | ||
| from typing import Any, cast | ||
|
|
||
| from pydantic import ValidationError | ||
| from typing_inspection.introspection import get_literal_values | ||
|
|
||
| from . import __version__ | ||
| from ._run_context import AgentDepsT | ||
| from .agent import AbstractAgent, Agent | ||
| from .exceptions import UserError | ||
| from .messages import ModelMessage, ModelResponse | ||
| from .messages import ModelMessage, ModelMessagesTypeAdapter, ModelResponse | ||
| from .models import KnownModelName, infer_model | ||
| from .output import OutputDataT | ||
|
|
||
|
|
@@ -54,6 +55,7 @@ | |
| """ | ||
|
|
||
| PROMPT_HISTORY_FILENAME = 'prompt-history.txt' | ||
| LAST_CONVERSATION_FILENAME = 'last-conversation.json' | ||
|
|
||
|
|
||
| class SimpleCodeBlock(CodeBlock): | ||
|
|
@@ -148,6 +150,20 @@ def cli( # noqa: C901 | |
| help='Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.', | ||
| default='dark', | ||
| ) | ||
| parser.add_argument( | ||
| '-c', | ||
| '--continue', | ||
DouweM marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| nargs='?', | ||
| dest='continue_', | ||
| const=str(PYDANTIC_AI_HOME / LAST_CONVERSATION_FILENAME), | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm so we're not respecting the |
||
| default=None, | ||
| help='Continue last conversation, if any, instead of starting a new one.', | ||
| ) | ||
| parser.add_argument( | ||
| '--store', | ||
| help='Store the last conversation to the specified path instead of the default location.', | ||
| default=None, | ||
| ) | ||
| parser.add_argument('--no-stream', action='store_true', help='Disable streaming from the model') | ||
| parser.add_argument('--version', action='store_true', help='Show version and exit') | ||
|
|
||
|
|
@@ -207,19 +223,49 @@ def cli( # noqa: C901 | |
| else: | ||
| code_theme = args.code_theme # pragma: no cover | ||
|
|
||
| load_path: Path | None = None | ||
| if args.continue_: | ||
| load_path = Path(args.continue_) | ||
|
|
||
| store_path: Path = PYDANTIC_AI_HOME / LAST_CONVERSATION_FILENAME | ||
| if args.store: | ||
| store_path = Path(args.store) | ||
|
|
||
| try: | ||
| history = load_conversation(load_path) if load_path else None | ||
| except ValidationError: | ||
| console.print( | ||
| '[red]Error loading conversation, it is corrupted or invalid.\nStarting a new conversation.[/red]' | ||
| ) | ||
| history = None | ||
|
|
||
| if prompt := cast(str, args.prompt): | ||
| try: | ||
| asyncio.run(ask_agent(agent, prompt, stream, console, code_theme)) | ||
| asyncio.run(ask_agent(agent, prompt, stream, console, code_theme, messages=history, store_path=store_path)) | ||
| except KeyboardInterrupt: | ||
| pass | ||
| return 0 | ||
|
|
||
| try: | ||
| return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name)) | ||
| return asyncio.run( | ||
| run_chat(stream, agent, console, code_theme, prog_name, message_history=history, store_path=store_path) | ||
| ) | ||
| except KeyboardInterrupt: # pragma: no cover | ||
| return 0 | ||
|
|
||
|
|
||
| def store_conversation(messages: list[ModelMessage], store_path: Path) -> None: | ||
| store_path.parent.mkdir(parents=True, exist_ok=True) | ||
| store_path.write_bytes(ModelMessagesTypeAdapter.dump_json(messages)) | ||
|
|
||
|
|
||
| def load_conversation(load_path: Path) -> list[ModelMessage] | None: | ||
| if not load_path.exists(): | ||
| return None | ||
|
|
||
| return ModelMessagesTypeAdapter.validate_json(load_path.read_bytes()) | ||
|
|
||
|
|
||
| async def run_chat( | ||
| stream: bool, | ||
| agent: AbstractAgent[AgentDepsT, OutputDataT], | ||
|
|
@@ -229,6 +275,7 @@ async def run_chat( | |
| config_dir: Path | None = None, | ||
| deps: AgentDepsT = None, | ||
| message_history: Sequence[ModelMessage] | None = None, | ||
| store_path: Path | None = None, | ||
| ) -> int: | ||
| prompt_history_path = (config_dir or PYDANTIC_AI_HOME) / PROMPT_HISTORY_FILENAME | ||
| prompt_history_path.parent.mkdir(parents=True, exist_ok=True) | ||
|
|
@@ -255,7 +302,7 @@ async def run_chat( | |
| return exit_value | ||
| else: | ||
| try: | ||
| messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages) | ||
| messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages, store_path) | ||
| except CancelledError: # pragma: no cover | ||
| console.print('[dim]Interrupted[/dim]') | ||
| except Exception as e: # pragma: no cover | ||
|
|
@@ -273,6 +320,7 @@ async def ask_agent( | |
| code_theme: str, | ||
| deps: AgentDepsT = None, | ||
| messages: Sequence[ModelMessage] | None = None, | ||
| store_path: Path | None = None, | ||
| ) -> list[ModelMessage]: | ||
| status = Status('[dim]Working on it…[/dim]', console=console) | ||
|
|
||
|
|
@@ -281,22 +329,28 @@ async def ask_agent( | |
| result = await agent.run(prompt, message_history=messages, deps=deps) | ||
| content = str(result.output) | ||
| console.print(Markdown(content, code_theme=code_theme)) | ||
| return result.all_messages() | ||
|
|
||
| with status, ExitStack() as stack: | ||
| async with agent.iter(prompt, message_history=messages, deps=deps) as agent_run: | ||
| live = Live('', refresh_per_second=15, console=console, vertical_overflow='ellipsis') | ||
| async for node in agent_run: | ||
| if Agent.is_model_request_node(node): | ||
| async with node.stream(agent_run.ctx) as handle_stream: | ||
| status.stop() # stopping multiple times is idempotent | ||
| stack.enter_context(live) # entering multiple times is idempotent | ||
|
|
||
| async for content in handle_stream.stream_output(debounce_by=None): | ||
| live.update(Markdown(str(content), code_theme=code_theme)) | ||
|
|
||
| assert agent_run.result is not None | ||
| return agent_run.result.all_messages() | ||
| else: | ||
| with status, ExitStack() as stack: | ||
| async with agent.iter(prompt, message_history=messages, deps=deps) as agent_run: | ||
| live = Live('', refresh_per_second=15, console=console, vertical_overflow='ellipsis') | ||
| async for node in agent_run: | ||
| if Agent.is_model_request_node(node): | ||
| async with node.stream(agent_run.ctx) as handle_stream: | ||
| status.stop() # stopping multiple times is idempotent | ||
| stack.enter_context(live) # entering multiple times is idempotent | ||
|
|
||
| async for content in handle_stream.stream_output(debounce_by=None): | ||
| live.update(Markdown(str(content), code_theme=code_theme)) | ||
|
|
||
| assert agent_run.result is not None | ||
| result = agent_run.result | ||
|
|
||
| result_messages = result.all_messages() | ||
|
|
||
| if store_path: | ||
| store_conversation(result_messages, store_path) | ||
|
|
||
| return result_messages | ||
|
|
||
|
|
||
| class CustomAutoSuggest(AutoSuggestFromHistory): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the readme, this should really say
`PATH` rather than `CONTINUE_`.