From 237c47cb2301dbda88bea98a534335aaed906b72 Mon Sep 17 00:00:00 2001
From: Will McGugan
Date: Tue, 5 Aug 2025 17:19:48 +0100
Subject: [PATCH 1/7] clai tui

---
 clai/clai/__init__.py    |   2 +-
 clai/clai/__main__.py    |   2 +-
 clai/clai/_cli.py        | 360 +++++++++++++++++++++++++++++++++++++++
 clai/clai/clai.tcss      |  67 ++++++++
 clai/clai/main_screen.py | 186 ++++++++++++++++++++
 clai/clai/tui.py         |  25 +++
 clai/pyproject.toml      |   1 +
 uv.lock                  |  71 +++++++-
 8 files changed, 707 insertions(+), 7 deletions(-)
 create mode 100644 clai/clai/_cli.py
 create mode 100644 clai/clai/clai.tcss
 create mode 100644 clai/clai/main_screen.py
 create mode 100644 clai/clai/tui.py

diff --git a/clai/clai/__init__.py b/clai/clai/__init__.py
index ec2c025ff..cab848885 100644
--- a/clai/clai/__init__.py
+++ b/clai/clai/__init__.py
@@ -1,6 +1,6 @@
 from importlib.metadata import version as _metadata_version
 
-from pydantic_ai import _cli
+from clai import _cli
 
 __all__ = '__version__', 'cli'
 __version__ = _metadata_version('clai')
diff --git a/clai/clai/__main__.py b/clai/clai/__main__.py
index 95e0dc6f4..d3e5fff9a 100644
--- a/clai/clai/__main__.py
+++ b/clai/clai/__main__.py
@@ -1,6 +1,6 @@
 """This means `python -m clai` should run the CLI."""
 
-from pydantic_ai import _cli
+from clai import _cli
 
 if __name__ == '__main__':
     _cli.cli_exit('clai')
diff --git a/clai/clai/_cli.py b/clai/clai/_cli.py
new file mode 100644
index 000000000..651de13f7
--- /dev/null
+++ b/clai/clai/_cli.py
@@ -0,0 +1,360 @@
+from __future__ import annotations as _annotations
+
+import argparse
+import asyncio
+import importlib
+import os
+import sys
+from asyncio import CancelledError
+from collections.abc import Sequence
+from contextlib import ExitStack
+from datetime import datetime, timezone
+from importlib.metadata import version as _metadata_version
+from pathlib import Path
+from typing import Any, cast
+
+from typing_inspection.introspection import get_literal_values
+
+from pydantic_ai._run_context import AgentDepsT
+from pydantic_ai.agent import Agent
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.messages import ModelMessage
+from pydantic_ai.models import KnownModelName, infer_model
+from pydantic_ai.output import OutputDataT
+
+__version__ = _metadata_version('clai')
+
+try:
+    import argcomplete
+    from prompt_toolkit import PromptSession
+    from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, Suggestion
+    from prompt_toolkit.buffer import Buffer
+    from prompt_toolkit.document import Document
+    from prompt_toolkit.history import FileHistory
+    from rich.console import Console, ConsoleOptions, RenderResult
+    from rich.live import Live
+    from rich.markdown import CodeBlock, Heading, Markdown
+    from rich.status import Status
+    from rich.style import Style
+    from rich.syntax import Syntax
+    from rich.text import Text
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `rich`, `prompt-toolkit` and `argcomplete` to use the Pydantic AI CLI, '
+        'you can use the `cli` optional group — `pip install "pydantic-ai-slim[cli]"`'
+    ) from _import_error
+
+
+__all__ = 'cli', 'cli_exit'
+
+
+PYDANTIC_AI_HOME = Path.home() / '.pydantic-ai'
+"""The home directory for Pydantic AI CLI.
+
+This folder is used to store the prompt history and configuration.
+"""
+
+PROMPT_HISTORY_FILENAME = 'prompt-history.txt'
+
+
+class SimpleCodeBlock(CodeBlock):
+    """Customized code blocks in markdown.
+
+    This avoids a background color which messes up copy-pasting and sets the language name as dim prefix and suffix.
+    """
+
+    def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
+        code = str(self.text).rstrip()
+        yield Text(self.lexer_name, style='dim')
+        yield Syntax(code, self.lexer_name, theme=self.theme, background_color='default', word_wrap=True)
+        yield Text(f'/{self.lexer_name}', style='dim')
+
+
+class LeftHeading(Heading):
+    """Customized headings in markdown to stop centering and prepend markdown style hashes."""
+
+    def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
+        # note we use `Style(bold=True)` not `self.style_name` here to disable underlining which is ugly IMHO
+        yield Text(f'{"#" * int(self.tag[1:])} {self.text.plain}', style=Style(bold=True))
+
+
+Markdown.elements.update(
+    fence=SimpleCodeBlock,
+    heading_open=LeftHeading,
+)
+
+
+cli_agent = Agent()
+
+
+@cli_agent.system_prompt
+def cli_system_prompt() -> str:
+    now_utc = datetime.now(timezone.utc)
+    tzinfo = now_utc.astimezone().tzinfo
+    tzname = tzinfo.tzname(now_utc) if tzinfo else ''
+    return f"""\
+Help the user by responding to their request, the output should be concise and always written in markdown.
+The current date and time is {datetime.now()} {tzname}.
+The user is running {sys.platform}."""
+
+
+def cli_exit(prog_name: str = 'pai'):  # pragma: no cover
+    """Run the CLI and exit."""
+    sys.exit(cli(prog_name=prog_name))
+
+
+def cli(  # noqa: C901
+    args_list: Sequence[str] | None = None, *, prog_name: str = 'pai', default_model: str = 'openai:gpt-4.1'
+) -> int:
+    """Run the CLI and return the exit code for the process."""
+    parser = argparse.ArgumentParser(
+        prog=prog_name,
+        description=f"""\
+Pydantic AI CLI v{__version__}\n\n
+
+Special prompts:
+* `/exit` - exit the interactive mode (ctrl-c and ctrl-d also work)
+* `/markdown` - show the last markdown output of the last question
+* `/multiline` - toggle multiline mode
+""",
+        formatter_class=argparse.RawTextHelpFormatter,
+    )
+    parser.add_argument('prompt', nargs='?', help='AI Prompt, if omitted fall into interactive mode')
+    arg = parser.add_argument(
+        '-m',
+        '--model',
+        nargs='?',
+        help=f'Model to use, in format "<provider>:<model>" e.g. "openai:gpt-4.1" or "anthropic:claude-sonnet-4-0". Defaults to "{default_model}".',
+    )
+    # we don't want to autocomplete or list models that don't include the provider,
+    # e.g. we want to show `openai:gpt-4o` but not `gpt-4o`
+    qualified_model_names = [n for n in get_literal_values(KnownModelName.__value__) if ':' in n]
+    arg.completer = argcomplete.ChoicesCompleter(qualified_model_names)  # type: ignore[reportPrivateUsage]
+    parser.add_argument(
+        '-a',
+        '--agent',
+        help='Custom Agent to use, in format "module:variable", e.g. "mymodule.submodule:my_agent"',
+    )
+    parser.add_argument(
+        '-l',
+        '--list-models',
+        action='store_true',
+        help='List all available models and exit',
+    )
+    parser.add_argument(
+        '-t',
+        '--code-theme',
+        nargs='?',
+        help='Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.',
+        default='dark',
+    )
+    parser.add_argument('--no-stream', action='store_true', help='Disable streaming from the model')
+    parser.add_argument('--version', action='store_true', help='Show version and exit')
+
+    argcomplete.autocomplete(parser)
+    args = parser.parse_args(args_list)
+
+    console = Console()
+    name_version = f'[green]{prog_name} - Pydantic AI CLI v{__version__}[/green]'
+    if args.version:
+        console.print(name_version, highlight=False)
+        return 0
+    if args.list_models:
+        console.print(f'{name_version}\n\n[green]Available models:[/green]')
+        for model in qualified_model_names:
+            console.print(f'  {model}', highlight=False)
+        return 0
+
+    agent: Agent[None, str] = cli_agent
+    if args.agent:
+        sys.path.append(os.getcwd())
+        try:
+            module_path, variable_name = args.agent.split(':')
+        except ValueError:
+            console.print('[red]Error: Agent must be specified in "module:variable" format[/red]')
+            return 1
+
+        module = importlib.import_module(module_path)
+        agent = getattr(module, variable_name)
+        if not isinstance(agent, Agent):
+            console.print(f'[red]Error: {args.agent} is not an Agent instance[/red]')
+            return 1
+
+    model_arg_set = args.model is not None
+    if agent.model is None or model_arg_set:
+        try:
+            agent.model = infer_model(args.model or default_model)
+        except UserError as e:
+            console.print(f'Error initializing [magenta]{args.model}[/magenta]:\n[red]{e}[/red]')
+            return 1
+
+    model_name = agent.model if isinstance(agent.model, str) else f'{agent.model.system}:{agent.model.model_name}'
+    if args.agent and model_arg_set:
+        console.print(
+            f'{name_version} using custom agent [magenta]{args.agent}[/magenta] with [magenta]{model_name}[/magenta]',
+            highlight=False,
+        )
+    elif args.agent:
+        console.print(f'{name_version} using custom agent [magenta]{args.agent}[/magenta]', highlight=False)
+    else:
+        console.print(f'{name_version} with [magenta]{model_name}[/magenta]', highlight=False)
+
+    stream = not args.no_stream
+    if args.code_theme == 'light':
+        code_theme = 'default'
+    elif args.code_theme == 'dark':
+        code_theme = 'monokai'
+    else:
+        code_theme = args.code_theme  # pragma: no cover
+
+    from clai.tui import CLAIApp
+
+    app = CLAIApp(agent, args.prompt)
+    app.run()
+    return 0
+
+    if prompt := cast(str, args.prompt):
+        try:
+            asyncio.run(ask_agent(agent, prompt, stream, console, code_theme))
+        except KeyboardInterrupt:
+            pass
+        return 0
+
+    try:
+        return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name))
+    except KeyboardInterrupt:  # pragma: no cover
+        return 0
+
+
+async def run_chat(
+    stream: bool,
+    agent: Agent[AgentDepsT, OutputDataT],
+    console: Console,
+    code_theme: str,
+    prog_name: str,
+    config_dir: Path | None = None,
+    deps: AgentDepsT = None,
+) -> int:
+    prompt_history_path = (config_dir or PYDANTIC_AI_HOME) / PROMPT_HISTORY_FILENAME
+    prompt_history_path.parent.mkdir(parents=True, exist_ok=True)
+    prompt_history_path.touch(exist_ok=True)
+    session: PromptSession[Any] = PromptSession(history=FileHistory(str(prompt_history_path)))
+
+    multiline = False
+    messages: list[ModelMessage] = []
+
+    while True:
+        try:
+            auto_suggest = CustomAutoSuggest(['/markdown', '/multiline', '/exit'])
+            text = await session.prompt_async(f'{prog_name} ➤ ', auto_suggest=auto_suggest, multiline=multiline)
+        except (KeyboardInterrupt, EOFError):  # pragma: no cover
+            return 0
+
+        if not text.strip():
+            continue
+
+        ident_prompt = text.lower().strip().replace(' ', '-')
+        if ident_prompt.startswith('/'):
+            exit_value, multiline = handle_slash_command(ident_prompt, messages, multiline, console, code_theme)
+            if exit_value is not None:
+                return exit_value
+        else:
+            try:
+                messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages)
+            except CancelledError:  # pragma: no cover
+                console.print('[dim]Interrupted[/dim]')
+            except Exception as e:  # pragma: no cover
+                cause = getattr(e, '__cause__', None)
+                console.print(f'\n[red]{type(e).__name__}:[/red] {e}')
+                if cause:
+                    console.print(f'[dim]Caused by: {cause}[/dim]')
+
+
+async def ask_agent(
+    agent: Agent[AgentDepsT, OutputDataT],
+    prompt: str,
+    stream: bool,
+    console: Console,
+    code_theme: str,
+    deps: AgentDepsT = None,
+    messages: list[ModelMessage] | None = None,
+) -> list[ModelMessage]:
+    status = Status('[dim]Working on it…[/dim]', console=console)
+
+    if not stream:
+        with status:
+            result = await agent.run(prompt, message_history=messages, deps=deps)
+        content = str(result.output)
+        console.print(Markdown(content, code_theme=code_theme))
+        return result.all_messages()
+
+    with status, ExitStack() as stack:
+        async with agent.iter(prompt, message_history=messages, deps=deps) as agent_run:
+            live = Live('', refresh_per_second=15, console=console, vertical_overflow='ellipsis')
+            async for node in agent_run:
+                if Agent.is_model_request_node(node):
+                    async with node.stream(agent_run.ctx) as handle_stream:
+                        status.stop()  # stopping multiple times is idempotent
+                        stack.enter_context(live)  # entering multiple times is idempotent
+
+                        async for content in handle_stream.stream_output(debounce_by=None):
+                            live.update(Markdown(str(content), code_theme=code_theme))
+
+    assert agent_run.result is not None
+    return agent_run.result.all_messages()
+
+
+class CustomAutoSuggest(AutoSuggestFromHistory):
+    def __init__(self, special_suggestions: list[str] | None = None):
+        super().__init__()
+        self.special_suggestions = special_suggestions or []
+
+    def get_suggestion(self, buffer: Buffer, document: Document) -> Suggestion | None:  # pragma: no cover
+        # Get the suggestion from history
+        suggestion = super().get_suggestion(buffer, document)
+
+        # Check for custom suggestions
+        text = document.text_before_cursor.strip()
+        for special in self.special_suggestions:
+            if special.startswith(text):
+                return Suggestion(special[len(text) :])
+        return suggestion
+
+
+def handle_slash_command(
+    ident_prompt: str, messages: list[ModelMessage], multiline: bool, console: Console, code_theme: str
+) -> tuple[int | None, bool]:
+    if ident_prompt == '/markdown':
+        try:
+            parts = messages[-1].parts
+        except IndexError:
+            console.print('[dim]No markdown output available.[/dim]')
+        else:
+            console.print('[dim]Markdown output of last question:[/dim]\n')
+            for part in parts:
+                if part.part_kind == 'text':
+                    console.print(
+                        Syntax(
+                            part.content,
+                            lexer='markdown',
+                            theme=code_theme,
+                            word_wrap=True,
+                            background_color='default',
+                        )
+                    )
+
+    elif ident_prompt == '/multiline':
+        multiline = not multiline
+        if multiline:
+            console.print(
+                'Enabling multiline mode. [dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]'
+            )
+        else:
+            console.print('Disabling multiline mode.')
+        return None, multiline
+    elif ident_prompt == '/exit':
+        console.print('[dim]Exiting…[/dim]')
+        return 0, multiline
+    else:
+        console.print(f'[red]Unknown command[/red] [magenta]`{ident_prompt}`[/magenta]')
+    return None, multiline
diff --git a/clai/clai/clai.tcss b/clai/clai/clai.tcss
new file mode 100644
index 000000000..9e98c226c
--- /dev/null
+++ b/clai/clai/clai.tcss
@@ -0,0 +1,67 @@
+Conversation { 
+    height: auto; 
+    max-height: 1fr; 
+    
+    #contents { 
+        layout: stream;
+        height: 1fr; 
+    }
+
+    #contents > * {
+        margin-bottom: 1; 
+    }
+
+    Prompt { 
+        height: auto; 
+        padding: 0 0 0 1; 
+        #prompt {
+            padding-left: 0;
+            color: $text-primary; 
+            text-style: bold;
+        }
+        Input {
+            background: transparent;
+            padding: 0 1;
+            border: none;
+            height: 1;
+        }
+    }
+
+    UserText {
+        background: black 10%;
+        padding: 1 0;
+        border-left: wide $success;
+        #prompt {
+            color: $text-muted;
+        }
+        #message {
+            color: $text-muted;
+            padding: 0 1; 
+        } 
+    }
+
+    Response { 
+        padding: 0 1 0 1; 
+        & > MarkdownBlock { 
+            padding: 0;
+            &:last-child {
+                margin-bottom:0;
+            }
+        }
+    }
+}
+
+Footer { 
+    background: black 10%;
+    .footer-key--key {
+        color: $text;
+        background: transparent;
+        text-style: bold;
+        padding: 0 1;
+    }
+    .footer-key--description {
+        padding: 0 1 0 0;
+        color: $text-muted;
+        background: $footer-description-background;
+    }
+}
\ No newline at end of file
diff --git a/clai/clai/main_screen.py b/clai/clai/main_screen.py
new file mode 100644
index 000000000..53a638088
--- /dev/null
+++ b/clai/clai/main_screen.py
@@ -0,0 +1,186 @@
+from __future__ import annotations
+
+import asyncio
+from asyncio import Queue
+from dataclasses import dataclass
+
+from textual import containers, getters, on, work
+from textual.app import ComposeResult
+from textual.binding import Binding
+from textual.message import Message
+from textual.screen import Screen
+from textual.suggester import SuggestFromList
+from textual.widget import Widget
+from textual.widgets import Footer, Input, Label, Markdown
+
+from pydantic_ai import __version__
+from pydantic_ai._run_context import AgentDepsT
+from pydantic_ai.agent import Agent
+from pydantic_ai.messages import ModelMessage
+from pydantic_ai.output import OutputDataT
+
+HELP = f"""\
+## Pydantic AI TUI **v{__version__}**
+
+
+| Prompt | Purpose |
+| --- | --- |
+| `/markdown` | Show markdown output of last question. |
+|`/multiline` | Enable multiline mode. |
+| `/exit` | Exit CLAI. |
+
+
+"""
+
+
+class Response(Markdown):
+    """Response from the agent."""
+
+
+class UserText(containers.HorizontalGroup):
+    """Copy of what the user prompted."""
+
+    def __init__(self, prompt: str) -> None:
+        self._prompt = prompt
+        super().__init__()
+
+    def compose(self) -> ComposeResult:
+        yield Label('clai ➤', id='prompt')
+        yield Label(self._prompt, id='message')
+
+
+class PromptInput(Input):
+    """Custom prompt to disable maximize."""
+
+    BINDING_GROUP_TITLE = 'Prompt'
+    ALLOW_MAXIMIZE = False
+
+
+class Prompt(containers.HorizontalGroup):
+    """Takes input from the user."""
+
+    def compose(self) -> ComposeResult:
+        yield Label('clai ➤', id='prompt')
+        yield PromptInput(
+            id='prompt-input',
+            suggester=SuggestFromList(
+                [
+                    '/markdown',
+                    '/multiline',
+                    '/exit',
+                ]
+            ),
+        )
+
+
+class Contents(containers.VerticalScroll):
+    BINDING_GROUP_TITLE = 'Conversation'
+
+    BINDINGS = [Binding('tab', 'screen.focus-next', 'Focus prompt')]
+
+
+class Conversation(containers.Vertical):
+    """The conversation with the AI."""
+
+    contents = getters.query_one('#contents', containers.VerticalScroll)
+
+    @dataclass
+    class Prompt(Message):
+        """A prompt from the user."""
+
+        prompt: str
+
+    def compose(self) -> ComposeResult:
+        with Contents(id='contents'):
+            pass
+
+        yield Prompt(id='prompt')
+
+    async def on_mount(self) -> None:
+        await self.post_help()
+
+    async def post_help(self) -> None:
+        await self.post(Response(HELP))
+
+    async def post(self, widget: Widget) -> None:
+        await self.contents.mount(widget)
+        self.contents.anchor()
+
+    async def post_prompt(self, prompt: str) -> None:
+        await self.post(UserText(prompt))
+
+    @on(Input.Submitted)
+    async def on_input_submitted(self, event: Input.Submitted) -> None:
+        event.stop()
+        self.post_message(self.Prompt(event.input.value))
+        event.input.clear()
+
+
+class MainScreen(Screen[None]):
+    """Main screen containing conversation."""
+
+    BINDING_GROUP_TITLE = 'Screen'
+    AUTO_FOCUS = 'Conversation Prompt Input'
+
+    conversation = getters.query_one(Conversation)
+
+    def __init__(self, agent: Agent[AgentDepsT, OutputDataT], prompt: str | None = None):
+        self.agent = agent
+        self.prompt = prompt
+        self.messages: list[ModelMessage] = []
+        super().__init__()
+
+    def compose(self) -> ComposeResult:
+        yield Conversation()
+        yield Footer()
+
+    async def on_mount(self) -> None:
+        """Runs when the widget is mounted."""
+        # Initialize the prompt queue
+        self.prompt_queue: Queue[str | None] = Queue(maxsize=10)
+        self.run_response_queue()
+        if self.prompt:
+            # Send initial prompt
+            await self.conversation.post_prompt(self.prompt)
+            await self.ask_agent(self.prompt)
+
+    async def on_unmount(self) -> None:
+        """Called when the app exits."""
+        # Tell the response queue task to finish up
+        await self.prompt_queue.put(None)
+
+    @on(Conversation.Prompt)
+    async def on_conversation_prompt(self, event: Conversation.Prompt) -> None:
+        """Called when the user submits a prompt."""
+        prompt = event.prompt
+        await self.conversation.post_prompt(prompt)
+        await self.ask_agent(prompt)
+
+    async def ask_agent(self, prompt: str) -> None:
+        """Send the prompt to the agent."""
+        await self.prompt_queue.put(prompt)
+
+    async def post_response(self) -> Response:
+        """Post a response, returns a callable to append markdown."""
+        response = Response()
+        response.display = False
+        await self.conversation.post(response)
+        return response
+
+    @work
+    async def run_response_queue(self) -> None:
+        """Listens to the prompt queue, posts prompts, and streams the response."""
+        while (prompt := await self.prompt_queue.get()) is not None:
+            response = await self.post_response()
+            markdown_stream = Markdown.get_stream(response)
+            try:
+                async with self.agent.iter(prompt, message_history=self.messages) as agent_run:
+                    async for node in agent_run:
+                        if Agent.is_model_request_node(node):
+                            async with node.stream(agent_run.ctx) as handle_stream:
+                                async for fragment in handle_stream.stream_text(delta=True, debounce_by=None):
+                                    await markdown_stream.write(fragment)
+                response.display = True
+                self.messages[:] = agent_run.result.all_messages()
+            finally:
+                await markdown_stream.stop()
diff --git a/clai/clai/tui.py b/clai/clai/tui.py
new file mode 100644
index 000000000..db250704b
--- /dev/null
+++ b/clai/clai/tui.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from textual.app import App
+
+from clai.main_screen import MainScreen
+from pydantic_ai._run_context import AgentDepsT
+from pydantic_ai.agent import Agent
+from pydantic_ai.output import OutputDataT
+
+
+class CLAIApp(App[None]):
+    """The CLAI TUI app."""
+
+    CSS_PATH = 'clai.tcss'
+
+    def __init__(self, agent: Agent[AgentDepsT, OutputDataT], prompt: str | None = None):
+        self._agent = agent
+        self._prompt = prompt
+        super().__init__()
+
+    def on_load(self) -> None:
+        self.theme = 'nord'
+
+    def get_default_screen(self) -> MainScreen:
+        return MainScreen(self._agent, self._prompt)
diff --git a/clai/pyproject.toml b/clai/pyproject.toml
index 411e616f0..5227818cd 100644
--- a/clai/pyproject.toml
+++ b/clai/pyproject.toml
@@ -49,6 +49,7 @@ requires-python = ">=3.9"
 [tool.hatch.metadata.hooks.uv-dynamic-versioning]
 dependencies = [
     "pydantic-ai=={{ version }}",
+    "textual>=5.2.0"
 ]
 
 [tool.hatch.metadata]
diff --git a/uv.lock b/uv.lock
index 93701f47e..7317433eb 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
 version = 1
-revision = 3
+revision = 2
 requires-python = ">=3.9"
 resolution-markers = [
     "python_full_version >= '3.13' and platform_python_implementation == 'PyPy'",
@@ -704,10 +704,14 @@ name = "clai"
 source = { editable = "clai" }
 dependencies = [
     { name = "pydantic-ai" },
+    { name = "textual" },
 ]
 
 [package.metadata]
-requires-dist = [{ name = "pydantic-ai", editable = "." }]
+requires-dist = [
+    { name = "pydantic-ai", editable = "." },
+    { name = "textual", specifier = ">=5.2.0" },
+]
 
 [[package]]
 name = "click"
@@ -1731,6 +1735,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/01/0e/b27cdbaccf30b890c40ed1da9fd4a3593a5cf94dae54fb34f8a4b74fcd3f/jsonschema_specifications-2025.4.1-py3-none-any.whl", hash = "sha256:4653bffbd6584f7de83a67e0d620ef16900b390ddc7939d56684d6c81e33f1af", size = 18437, upload-time = "2025-04-23T12:34:05.422Z" },
 ]
 
+[[package]]
+name = "linkify-it-py"
+version = "2.0.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "uc-micro-py" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/2a/ae/bb56c6828e4797ba5a4821eec7c43b8bf40f69cda4d4f5f8c8a2810ec96a/linkify-it-py-2.0.3.tar.gz", hash = "sha256:68cda27e162e9215c17d786649d1da0021a451bdc436ef9e0fa0ba5234b9b048", size = 27946, upload-time = "2024-02-04T14:48:04.179Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/04/1e/b832de447dee8b582cac175871d2f6c3d5077cc56d5575cadba1fd1cccfa/linkify_it_py-2.0.3-py3-none-any.whl", hash = "sha256:6bcbc417b0ac14323382aef5c5192c0075bf8a9d6b41820a2b66371eac6b6d79", size = 19820, upload-time = "2024-02-04T14:48:02.496Z" },
+]
+
 [[package]]
 name = "logfire"
 version = "4.0.0"
@@ -1902,6 +1918,14 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528, upload-time = "2023-06-03T06:41:11.019Z" },
 ]
 
+[package.optional-dependencies]
+linkify = [
+    { name = "linkify-it-py" },
+]
+plugins = [
+    { name = "mdit-py-plugins" },
+]
+
 [[package]]
 name = "markdownify"
 version = "0.14.1"
@@ -2036,6 +2060,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/f2/6f/94a7344f6d634fe3563bea8b33bccedee37f2726f7807e9a58440dc91627/mdformat-0.7.22-py3-none-any.whl", hash = "sha256:61122637c9e1d9be1329054f3fa216559f0d1f722b7919b060a8c2a4ae1850e5", size = 34447, upload-time = "2025-01-30T18:00:48.708Z" },
 ]
 
+[[package]]
+name = "mdit-py-plugins"
+version = "0.4.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "markdown-it-py" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/19/03/a2ecab526543b152300717cf232bb4bb8605b6edb946c845016fa9c9c9fd/mdit_py_plugins-0.4.2.tar.gz", hash = "sha256:5f2cd1fdb606ddf152d37ec30e46101a60512bc0e5fa1a7002c36647b09e26b5", size = 43542, upload-time = "2024-09-09T20:27:49.564Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/a7/f7/7782a043553ee469c1ff49cfa1cdace2d6bf99a1f333cf38676b3ddf30da/mdit_py_plugins-0.4.2-py3-none-any.whl", hash = "sha256:0c673c3f889399a33b95e88d2f0d111b4447bdfea7f237dab2d488f459835636", size = 55316, upload-time = "2024-09-09T20:27:48.397Z" },
+]
+
 [[package]]
 name = "mdurl"
 version = "0.1.2"
@@ -3716,11 +3752,11 @@
 
 [[package]]
 name = "pygments"
-version = "2.19.1"
+version = "2.19.2"
 source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/7c/2d/c3338d48ea6cc0feb8446d8e6937e1408088a72a39937982cc6111d17f84/pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f", size = 4968581, upload-time = "2025-01-06T17:26:30.443Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293, upload-time = "2025-01-06T17:26:25.553Z" },
+    { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
 ]
@@ -4429,6 +4465,22 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/d2/3f/8ba87d9e287b9d385a02a7114ddcef61b26f86411e121c9003eb509a1773/tenacity-8.5.0-py3-none-any.whl", hash = "sha256:b594c2a5945830c267ce6b79a166228323ed52718f30302c1359836112346687", size = 28165, upload-time = "2024-07-05T07:25:29.591Z" },
 ]
 
+[[package]]
+name = "textual"
+version = "5.2.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "markdown-it-py", extra = ["linkify", "plugins"] },
+    { name = "platformdirs" },
+    { name = "pygments" },
+    { name = "rich" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/41/55/2f7ede2f7a2aa852f062019ab96204dd311dda9103f0b73980e8aeb4e7cc/textual-5.2.0.tar.gz", hash = "sha256:6d4a77fddde7aa659acc25c3e14ed27ff75e125c6730c6f57263cf91b60e0d50", size = 1556452, upload-time = "2025-08-01T18:45:33.262Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/41/10/81b3c0cac7d5c6b0b32f38b663fdae3773fb2843623c1e0ff7626dfef7c9/textual-5.2.0-py3-none-any.whl", hash = "sha256:5df1ef970a88ac9a266cbcace098f595f22706990f2506f6e545d9bb55983f1d", size = 701823, upload-time = "2025-08-01T18:45:31.17Z" },
+]
+
 [[package]]
 name = "tiktoken"
 version = "0.9.0"
@@ -4738,6 +4790,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/0f/dd/84f10e23edd882c6f968c21c2434fe67bd4a528967067515feca9e611e5e/tzdata-2025.1-py2.py3-none-any.whl", hash = "sha256:7e127113816800496f027041c570f50bcd464a020098a3b6b199517772303639", size = 346762, upload-time = "2025-01-21T19:49:37.187Z" },
 ]
 
+[[package]]
+name = "uc-micro-py"
+version = "1.0.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/91/7a/146a99696aee0609e3712f2b44c6274566bc368dfe8375191278045186b8/uc-micro-py-1.0.3.tar.gz", hash = "sha256:d321b92cff673ec58027c04015fcaa8bb1e005478643ff4a500882eaab88c48a", size = 6043, upload-time = "2024-02-09T16:52:01.654Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/37/87/1f677586e8ac487e29672e4b17455758fce261de06a0d086167bb760361a/uc_micro_py-1.0.3-py3-none-any.whl", hash = "sha256:db1dffff340817673d7b466ec86114a9dc0e9d4d9b5ba229d9d60e5c12600cd5", size = 6229, upload-time = "2024-02-09T16:52:00.371Z" },
+]
+
 [[package]]
 name = "urllib3"
 version = "1.26.20"

From bd6d949d3645a35d3e6aa994e1379fbd0a05f6f4 Mon Sep 17 00:00:00 2001
From: Will McGugan
Date: Thu, 7 Aug 2025 10:00:00 +0100
Subject: [PATCH 2/7] multiline, tidy

---
 clai/clai/_cli.py        | 230 ++++---------------------------------
 clai/clai/clai.tcss      |  10 ++
 clai/clai/main_screen.py | 238 +++++++++++++++++++++++++++++++------
 clai/clai/tui.py         |  23 +++-
 4 files changed, 254 insertions(+), 247 deletions(-)

diff --git a/clai/clai/_cli.py b/clai/clai/_cli.py
index 651de13f7..96a323d27 100644
--- a/clai/clai/_cli.py
+++ b/clai/clai/_cli.py
@@ -1,43 +1,25 @@
 from __future__ import annotations as _annotations
 
 import argparse
-import asyncio
 import importlib
 import os
 import sys
-from asyncio import CancelledError
 from collections.abc import Sequence
-from contextlib import ExitStack
 from datetime import datetime, timezone
 from importlib.metadata import version as _metadata_version
 from pathlib import Path
-from typing import Any, cast
 
 from typing_inspection.introspection import get_literal_values
 
-from pydantic_ai._run_context import AgentDepsT
 from pydantic_ai.agent import Agent
 from pydantic_ai.exceptions import UserError
-from pydantic_ai.messages import ModelMessage
 from pydantic_ai.models import KnownModelName, infer_model
-from pydantic_ai.output import OutputDataT
 
 __version__ = _metadata_version('clai')
 
 try:
     import argcomplete
-    from prompt_toolkit import PromptSession
-    from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, Suggestion
-    from prompt_toolkit.buffer import Buffer
-    from prompt_toolkit.document import Document
-    from prompt_toolkit.history import FileHistory
-    from rich.console import Console, ConsoleOptions, RenderResult
-    from rich.live import Live
-    from rich.markdown import CodeBlock, Heading, Markdown
-    from rich.status import Status
-    from rich.style import Style
-    from rich.syntax import Syntax
-    from rich.text import Text
+    from rich.console import Console
 except ImportError as _import_error:
     raise ImportError(
         'Please install `rich`, `prompt-toolkit` and `argcomplete` to use the Pydantic AI CLI, '
@@ -57,33 +39,6 @@
 PROMPT_HISTORY_FILENAME = 'prompt-history.txt'
 
 
-class SimpleCodeBlock(CodeBlock):
-    """Customized code blocks in markdown.
-
-    This avoids a background color which messes up copy-pasting and sets the language name as dim prefix and suffix.
-    """
-
-    def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
-        code = str(self.text).rstrip()
-        yield Text(self.lexer_name, style='dim')
-        yield Syntax(code, self.lexer_name, theme=self.theme, background_color='default', word_wrap=True)
-        yield Text(f'/{self.lexer_name}', style='dim')
-
-
-class LeftHeading(Heading):
-    """Customized headings in markdown to stop centering and prepend markdown style hashes."""
-
-    def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
-        # note we use `Style(bold=True)` not `self.style_name` here to disable underlining which is ugly IMHO
-        yield Text(f'{"#" * int(self.tag[1:])} {self.text.plain}', style=Style(bold=True))
-
-
-Markdown.elements.update(
-    fence=SimpleCodeBlock,
-    heading_open=LeftHeading,
-)
-
-
 cli_agent = Agent()
 
 
@@ -104,7 +59,10 @@ def cli_exit(prog_name: str = 'pai'):  # pragma: no cover
 
 
 def cli(  # noqa: C901
-    args_list: Sequence[str] | None = None, *, prog_name: str = 'pai', default_model: str = 'openai:gpt-4.1'
+    args_list: Sequence[str] | None = None,
+    *,
+    prog_name: str = 'pai',
+    default_model: str = 'openai:gpt-4.1',
 ) -> int:
     """Run the CLI and return the exit code for the process."""
     parser = argparse.ArgumentParser(
@@ -119,7 +77,9 @@ def cli(  # noqa: C901
 """,
         formatter_class=argparse.RawTextHelpFormatter,
     )
-    parser.add_argument('prompt', nargs='?', help='AI Prompt, if omitted fall into interactive mode')
+    parser.add_argument(
+        'prompt', nargs='?', help='AI Prompt, if omitted fall into interactive mode'
+    )
     arg = parser.add_argument(
         '-m',
         '--model',
@@ -188,173 +148,25 @@ def cli(  # noqa: C901
             console.print(f'Error initializing [magenta]{args.model}[/magenta]:\n[red]{e}[/red]')
             return 1
 
-    model_name = agent.model if isinstance(agent.model, str) else f'{agent.model.system}:{agent.model.model_name}'
+    model_name = (
+        agent.model
+        if isinstance(agent.model, str)
+        else f'{agent.model.system}:{agent.model.model_name}'
+    )
+    title = name_version = f'{prog_name} - Pydantic AI CLI v{__version__}'
     if args.agent and model_arg_set:
-        console.print(
-            f'{name_version} using custom agent [magenta]{args.agent}[/magenta] with [magenta]{model_name}[/magenta]',
-            highlight=False,
-        )
+        title = f'{name_version} using custom agent **{args.agent}** with `{model_name}`'
+
     elif args.agent:
-        console.print(f'{name_version} using custom agent [magenta]{args.agent}[/magenta]', highlight=False)
-    else:
-        console.print(f'{name_version} with [magenta]{model_name}[/magenta]', highlight=False)
+        title = f'{name_version} using custom agent **{args.agent}**'
 
-    stream = not args.no_stream
-    if args.code_theme == 'light':
-        code_theme = 'default'
-    elif args.code_theme == 'dark':
-        code_theme = 'monokai'
     else:
-        code_theme = args.code_theme  # pragma: no cover
+        title = f'{name_version} with **{model_name}**'
 
     from clai.tui import CLAIApp
 
-    app = CLAIApp(agent, args.prompt)
+    app = CLAIApp(
+        agent, PYDANTIC_AI_HOME / PROMPT_HISTORY_FILENAME, prompt=args.prompt, title=title
+    )
     app.run()
     return 0
-
-    if prompt := cast(str, args.prompt):
-        try:
-            asyncio.run(ask_agent(agent, prompt, stream, console, code_theme))
-        except KeyboardInterrupt:
-            pass
-        return 0
-
-    try:
-        return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name))
-    except KeyboardInterrupt:  # pragma: no cover
-        return 0
-
-
-async def run_chat(
-    stream: bool,
-    agent: Agent[AgentDepsT, OutputDataT],
-    console: Console,
-    code_theme: str,
-    prog_name: str,
-    config_dir: Path | None = None,
-    deps: AgentDepsT = None,
-) -> int:
-    prompt_history_path = (config_dir or PYDANTIC_AI_HOME) / PROMPT_HISTORY_FILENAME
-    prompt_history_path.parent.mkdir(parents=True, exist_ok=True)
-    prompt_history_path.touch(exist_ok=True)
-    session: PromptSession[Any] = PromptSession(history=FileHistory(str(prompt_history_path)))
-
-    multiline = False
-    messages: list[ModelMessage] = []
-
-    while True:
-        try:
-            auto_suggest = CustomAutoSuggest(['/markdown', '/multiline', '/exit'])
-            text = await session.prompt_async(f'{prog_name} ➤ ', auto_suggest=auto_suggest, multiline=multiline)
-        except (KeyboardInterrupt, EOFError):  # pragma: no cover
-            return 0
-
-        if not text.strip():
-            continue
-
-        ident_prompt = text.lower().strip().replace(' ', '-')
-        if ident_prompt.startswith('/'):
-            exit_value, multiline = handle_slash_command(ident_prompt, messages, multiline, console, code_theme)
-            if exit_value is not None:
-                return exit_value
-        else:
-            try:
-                messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages)
-            except CancelledError:  # pragma: no cover
-                console.print('[dim]Interrupted[/dim]')
-            except Exception as e:  # pragma: no cover
-                cause = getattr(e, '__cause__', None)
-                console.print(f'\n[red]{type(e).__name__}:[/red] {e}')
-                if cause:
-                    console.print(f'[dim]Caused by: {cause}[/dim]')
-
-
-async def ask_agent(
-    agent: Agent[AgentDepsT, OutputDataT],
-    prompt: str,
-    stream: bool,
-    console: Console,
-    code_theme: str,
-    deps: AgentDepsT = None,
-    messages: list[ModelMessage] | None = None,
-) -> list[ModelMessage]:
-    status = Status('[dim]Working on it…[/dim]', console=console)
-
-    if not stream:
-        with status:
-            result = await agent.run(prompt, message_history=messages, deps=deps)
-        content = str(result.output)
-        console.print(Markdown(content, code_theme=code_theme))
-        return result.all_messages()
-
-    with status, ExitStack() as stack:
-        async with agent.iter(prompt, message_history=messages, deps=deps) as agent_run:
-            live = Live('', refresh_per_second=15, console=console, vertical_overflow='ellipsis')
-            async for node in agent_run:
-                if Agent.is_model_request_node(node):
-                    async with node.stream(agent_run.ctx) as handle_stream:
-                        status.stop()  # stopping multiple times is idempotent
-                        stack.enter_context(live)  # entering multiple times is idempotent
-
-                        async for content in handle_stream.stream_output(debounce_by=None):
-                            live.update(Markdown(str(content), code_theme=code_theme))
-
-    assert agent_run.result is not None
-    return agent_run.result.all_messages()
-
-
-class CustomAutoSuggest(AutoSuggestFromHistory):
-    def __init__(self, special_suggestions: list[str] | None = None):
-        super().__init__()
-        self.special_suggestions = special_suggestions or []
-
-    def get_suggestion(self, buffer: Buffer, document: Document) -> Suggestion | None:  # pragma: no cover
-        # Get the suggestion from history
-        suggestion = super().get_suggestion(buffer, document)
-
-        # Check for custom suggestions
-        text = document.text_before_cursor.strip()
-        for special in self.special_suggestions:
-            if special.startswith(text):
-                return Suggestion(special[len(text) :])
-        return suggestion
-
-
-def handle_slash_command(
-    ident_prompt: str, messages: list[ModelMessage], multiline: bool, console: Console, code_theme: str
-) -> tuple[int | None, bool]:
-    if ident_prompt == '/markdown':
-        try:
-            parts = messages[-1].parts
-        except IndexError:
-            console.print('[dim]No markdown output available.[/dim]')
-        else:
-            console.print('[dim]Markdown output of last question:[/dim]\n')
-            for part in parts:
-                if part.part_kind == 'text':
-                    console.print(
-                        Syntax(
-                            part.content,
-                            lexer='markdown',
-                            theme=code_theme,
-                            word_wrap=True,
-                            background_color='default',
-                        )
-                    )
-
-    elif ident_prompt == '/multiline':
-        multiline = not multiline
-        if multiline:
-            console.print(
-                'Enabling multiline mode. [dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]'
-            )
-        else:
-            console.print('Disabling multiline mode.')
-        return None, multiline
-    elif ident_prompt == '/exit':
-        console.print('[dim]Exiting…[/dim]')
-        return 0, multiline
-    else:
-        console.print(f'[red]Unknown command[/red] [magenta]`{ident_prompt}`[/magenta]')
-    return None, multiline
diff --git a/clai/clai/clai.tcss b/clai/clai/clai.tcss
index 9e98c226c..64eeb4c43 100644
--- a/clai/clai/clai.tcss
+++ b/clai/clai/clai.tcss
@@ -25,6 +25,11 @@ Conversation {
             border: none;
             height: 1;
         }
+        TextArea {
+            background: transparent;
+            height: auto;
+            min-height: 3;
+        }
     }
 
     UserText {
@@ -49,6 +54,11 @@ Conversation {
             }
         }
     }
+
+    ErrorMessage {
+        background: $error 10%;
+        color: $text-error;
+    }
 }
 
 Footer {
diff --git a/clai/clai/main_screen.py b/clai/clai/main_screen.py
index 53a638088..0b5144242 100644
--- a/clai/clai/main_screen.py
+++ b/clai/clai/main_screen.py
@@ -1,36 +1,44 @@
 from __future__ import annotations
 
-import asyncio
 from asyncio import Queue
 from dataclasses import dataclass
+from pathlib import Path
+from string import Template
 
+from prompt_toolkit.history import FileHistory
 from textual import containers, getters, on, work
 from textual.app import ComposeResult
 from textual.binding import Binding
 from textual.message import Message
+from textual.reactive import reactive, var
 from textual.screen import Screen
 from textual.suggester import SuggestFromList
 from textual.widget import Widget
-from textual.widgets import Footer, Input, Label, Markdown
+from textual.widgets import Footer, Input, Label, Markdown, Static, TextArea
+from textual.widgets.input import Selection
 
-from pydantic_ai import __version__
 from pydantic_ai._run_context import AgentDepsT
 from pydantic_ai.agent import Agent
 from pydantic_ai.messages import ModelMessage
 from pydantic_ai.output import OutputDataT
 
-HELP = f"""\
-## Pydantic AI TUI **v{__version__}**
-
-
-| Prompt | Purpose |
+HELP = Template("""\
+## $title
+ 
+- **Powered by Pydantic AI**
+ 
+  The Python agent framework designed to make it less painful to build production grade applications with Generative AI.
+
+| Command | Purpose |
 | --- | --- |
 | `/markdown` | Show markdown output of last question. |
 |`/multiline` | Enable multiline mode. |
 | `/exit` | Exit CLAI. |
-
-
-"""
+""")
+
+
+class ErrorMessage(Static):
+    """An error message for the user."""
 
 
 class Response(Markdown):
@@ -59,10 +67,41 @@ class PromptInput(Input):
 class Prompt(containers.HorizontalGroup):
     """Takes input from the user."""
 
+    BINDINGS = [
+        Binding('shift+up', 'history(-1)', 'History up', priority=True),
+        Binding('shift+down', 'history(+1)', 'History down', priority=True),
+        Binding('ctrl+j', 'submit', 'Submit prompt', key_display='shift+⏎', priority=True),
+        Binding('escape', 'escape', 'Exit multiline'),
+    ]
+
+    history_position = var(0, bindings=True)
+    multiline = reactive(False)
+    input = getters.query_one('#prompt-input', Input)
+    text_area = getters.query_one('#prompt-textarea', TextArea)
+
+    @dataclass
+    class Submitted(Message):
+        """Prompt text was submitted."""
+
+        value: str
+
+    @dataclass
+    class SetMultiline(Message):
+        """Go back to single line."""
+
+        multiline: bool
+
+    def __init__(self, history: FileHistory, id: str | None = None) -> None:
+        self.history = history
+        self.history_strings: list[str] = []
+        self.edit_prompt = ''
+        super().__init__(id=id)
+
     def compose(self) -> ComposeResult:
         yield Label('clai ➤', id='prompt')
         yield PromptInput(
             id='prompt-input',
+            placeholder='Ask me anything',
             suggester=SuggestFromList(
                 [
                     '/markdown',
@@ -71,11 +110,108 @@ def compose(self) -> ComposeResult:
                 ]
             ),
         )
+        yield TextArea(
+            id='prompt-textarea',
+            language='markdown',
+            highlight_cursor_line=False,
+        )
+
+    def watch_multiline(self, multiline: bool) -> None:
+        if multiline:
+            self.input.display = False
+            self.text_area.display = True
+            self.text_area.load_text(self.input.value)
+        else:
+            self.input.display = True
+            self.text_area.display = False
+            self.input.value = self.text_area.text.partition('\n')[0]
+
+    @property
+    def value(self) -> str:
+        """Value of prompt."""
+        if self.multiline:
+            return self.text_area.text
+        else:
+            return self.input.value
+
+    @value.setter
+    def value(self, value: str) -> None:
+        multiline = '\n' in value
+        self.post_message(self.SetMultiline(multiline))
+        if multiline:
+            self.text_area.load_text(value)
+        else:
+            self.input.value = value
+            self.input.selection = Selection.cursor(len(value))
+
+    def set_prompt(self, prompt: str) -> None:
+        """Set a new prompt value."""
+        self.value = prompt
+
+    def clear(self) -> None:
+        self.input.clear()
+        self.text_area.load_text('')
+
+    async def action_history(self, direction: int) -> None:
+        if self.history_position == 0:
+            self.history_strings.clear()
+            async for prompt in self.history.load():
+                self.history_strings.append(prompt)
+            self.history_strings.reverse()
+        self.history_position = self.history_position + direction
+
+    def action_submit(self) -> None:
+        self.post_message(self.Submitted(self.text_area.text))
+        self.text_area.load_text('')
+        self.action_escape()
+
+    def action_escape(self) -> None:
+        self.post_message(self.SetMultiline(False))
+
+    def check_action(self, action: str, parameters: tuple[object, ...]) -> bool | None:
+        if action == 'history':
+            if parameters[0] == +1 and self.history_position == 0:
+                return None
+            if (
+                parameters[0] == -1
+                and self.history_strings
+                and self.history_position == -len(self.history_strings)
+            ):
+                return None
+        if action in ('submit', 'escape'):
+            return self.multiline
+        return True
+
+    def validate_history_position(self, history_position: int) -> int:
+        if history_position > 0:
+            return 0
+        if -history_position > len(self.history_strings):
+            return -len(self.history_strings)
+        return history_position
+
+    async def watch_history_position(self, previous_position: int, position: int) -> None:
+        if previous_position == 0:
+            self.edit_prompt = self.query_one(PromptInput).value
+        if position == 0:
+            self.set_prompt(self.edit_prompt)
+        elif position < 0:
+            prompt = self.history_strings[position]
+            self.set_prompt(prompt)
+
+    @on(Input.Submitted)
+    def on_input_submitted(self, event: Input.Submitted) -> None:
+        self.post_message(self.Submitted(event.value))
+        self.clear()
+
+    @on(Submitted)
+    async def on_submitted(self, event: Submitted) -> None:
+        self.history_position = 0
 
 
 class Contents(containers.VerticalScroll):
-    BINDING_GROUP_TITLE = 'Conversation'
+    """The conversation contents."""
 
+    BINDING_GROUP_TITLE = 'Conversation'
     BINDINGS = [Binding('tab', 'screen.focus-next', 'Focus prompt')]
 
 
@@ -83,24 +219,27 @@ class Conversation(containers.Vertical):
     """The conversation with the AI."""
 
     contents = getters.query_one('#contents', containers.VerticalScroll)
+    prompt = getters.query_one(Prompt)
+    multiline = reactive(False)
 
-    @dataclass
-    class Prompt(Message):
-        """A prompt from the user."""
-
-        prompt: str
+    def __init__(self, history: FileHistory, title: str) -> None:
+        self.history = history
+        self.title = title
+        super().__init__()
 
     def compose(self) -> ComposeResult:
-        with Contents(id='contents'):
-            pass
-
-        yield Prompt(id='prompt')
+        yield Contents(id='contents')
+        yield Prompt(self.history, id='prompt').data_bind(Conversation.multiline)
+
+    def get_last_markdown_source(self) -> str | None:
+        """Get the source of the last markdown response, or `None` if there is no markdown response."""
+        for child in reversed(self.contents.children):
+            if isinstance(child, Markdown):
+                return child.source
+        return None
 
     async def on_mount(self) -> None:
-        await self.post_help()
-
-    async def post_help(self) -> None:
-        await self.post(Response(HELP))
+        await self.post(Response(HELP.safe_substitute(title=self.title)))
 
     async def post(self, widget: Widget) -> None:
         await self.contents.mount(widget)
@@ -109,11 +248,9 @@ async def post(self, widget: Widget) -> None:
     async def post_prompt(self, prompt: str) -> None:
         await self.post(UserText(prompt))
 
-    @on(Input.Submitted)
-    async def on_input_submitted(self, event: Input.Submitted) -> None:
-        event.stop()
-        self.post_message(self.Prompt(event.input.value))
-        event.input.clear()
+    @on(Prompt.SetMultiline)
+    def on_escape_multiline(self, event: Prompt.SetMultiline) -> None:
+        self.multiline = event.multiline
 
 
 class MainScreen(Screen[None]):
@@ -124,14 +261,23 @@ class MainScreen(Screen[None]):
 
     conversation = getters.query_one(Conversation)
 
-    def __init__(self, agent: Agent[AgentDepsT, OutputDataT], prompt: str | None = None):
+    def __init__(
+        self,
+        agent: Agent[AgentDepsT, OutputDataT],
+        history_path: Path,
+        title: str,
+        *,
+        prompt: str | None = None,
+    ):
         self.agent = agent
         self.prompt = prompt
         self.messages: list[ModelMessage] = []
+        self.history = FileHistory(history_path)
         super().__init__()
+        self.title = title
 
     def compose(self) -> ComposeResult:
-        yield Conversation()
+        yield Conversation(self.history, self.title or 'Pydantic CLAI')
         yield Footer()
 
     async def on_mount(self) -> None:
@@ -149,12 +295,34 @@ async def on_unmount(self) -> None:
         # Tell the response queue task to finish up
         await self.prompt_queue.put(None)
 
-    @on(Conversation.Prompt)
-    async def on_conversation_prompt(self, event: Conversation.Prompt) -> None:
+    @on(Prompt.Submitted)
+    async def on_conversation_prompt(self, event: Prompt.Submitted) -> None:
         """Called when the user submits a prompt."""
-        prompt = event.prompt
-        await self.conversation.post_prompt(prompt)
-        await self.ask_agent(prompt)
+        prompt = event.value.strip()
+        if not prompt:
+            self.app.bell()
+            return
+        if prompt.startswith('/'):
+            await self.process_slash(prompt)
+        else:
+            self.history.append_string(prompt)
+            await self.conversation.post_prompt(prompt)
+            await self.ask_agent(prompt)
+
+    async def process_slash(self, prompt: str) -> None:
+        prompt = prompt.strip()
+        if prompt == '/markdown':
+            markdown = self.conversation.get_last_markdown_source()
+            if not markdown:
+                await self.conversation.post(ErrorMessage('No markdown to display'))
+            else:
+                await self.conversation.post(Static(markdown))
+        elif prompt == '/multiline':
+            self.conversation.multiline = not self.conversation.multiline
+        elif prompt == '/exit':
+            self.app.exit()
+        else:
+            await self.conversation.post(ErrorMessage(f'Unknown command: {prompt!r}'))
 
     async def ask_agent(self, prompt: str) -> None:
         """Send the prompt to the agent."""
@@ -178,7 +346,9 @@ async def run_response_queue(self) -> None:
                 async for node in agent_run:
                     if Agent.is_model_request_node(node):
                         async with node.stream(agent_run.ctx) as handle_stream:
-                            async for fragment in handle_stream.stream_text(delta=True, debounce_by=None):
+                            async for fragment in handle_stream.stream_text(
+                                delta=True, debounce_by=None
+                            ):
                                 await markdown_stream.write(fragment)
                 response.display = True
                 self.messages[:] = agent_run.result.all_messages()
diff --git a/clai/clai/tui.py b/clai/clai/tui.py
index db250704b..fcdf6c4d7 100644
--- a/clai/clai/tui.py
+++ b/clai/clai/tui.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from pathlib import Path
+
 from textual.app import App
 
 from clai.main_screen import MainScreen
@@ -7,19 +9,32 @@
 from pydantic_ai.agent import Agent
 from pydantic_ai.output import OutputDataT
 
+DEFAULT_THEME = 'nord'
+
 
 class CLAIApp(App[None]):
     """The CLAI TUI app."""
 
+    BINDING_GROUP_TITLE = 'App'
     CSS_PATH = 'clai.tcss'
 
-    def __init__(self, agent: Agent[AgentDepsT, OutputDataT], prompt: str | None = None):
+    def __init__(
+        self,
+        agent: Agent[AgentDepsT, OutputDataT],
+        history_path: Path,
+        prompt: str | None = None,
+        title: str | None = None,
+    ):
+        super().__init__()
         self._agent = agent
+        self.history_path = history_path
+        self.title = title or 'Pydantic CLAI'
         self._prompt = prompt
-        super().__init__()
 
     def on_load(self) -> None:
-        self.theme = 'nord'
+        """Called before application mode."""
+        # Set the default theme here to avoid flash of different theme
+        self.theme = DEFAULT_THEME
 
     def get_default_screen(self) -> MainScreen:
-        return MainScreen(self._agent, self._prompt)
+        return MainScreen(self._agent, self.history_path, self.title, prompt=self._prompt)

From b2f8c3d50651a25763fef8e0a22284fef928139b Mon Sep 17 00:00:00 2001
From: Will McGugan
Date: Thu, 7 Aug 2025 10:29:26 +0100
Subject: [PATCH 3/7] fix multiline

---
 clai/clai/main_screen.py | 57 ++++++++++++++--------------------------
 1 file changed, 20 insertions(+), 37 deletions(-)

diff --git a/clai/clai/main_screen.py b/clai/clai/main_screen.py
index 0b5144242..6388ea398 100644
--- a/clai/clai/main_screen.py
+++ b/clai/clai/main_screen.py
@@ -9,6 +9,7 @@
 from textual import containers, getters, on, work
 from textual.app import ComposeResult
 from textual.binding import Binding
+from textual.geometry import clamp
 from textual.message import Message
 from textual.reactive import reactive, var
 from textual.screen import Screen
@@ -32,7 +33,7 @@
 | Command | Purpose |
 | --- | --- |
 | `/markdown` | Show markdown output of last question. |
-|`/multiline` | Enable multiline mode. |
+| `/multiline` | Enable multiline mode. |
 | `/exit` | Exit CLAI. |
 """)
 
@@ -85,12 +86,6 @@ class Submitted(Message):
 
         value: str
 
-    @dataclass
-    class SetMultiline(Message):
-        """Go back to single line."""
-
-        multiline: bool
-
     def __init__(self, history: FileHistory, id: str | None = None) -> None:
         self.history = history
         self.history_strings: list[str] = []
@@ -137,36 +132,37 @@ def value(self) -> str:
     @value.setter
     def value(self, value: str) -> None:
         multiline = '\n' in value
-        self.post_message(self.SetMultiline(multiline))
+        self.multiline = multiline
         if multiline:
             self.text_area.load_text(value)
         else:
             self.input.value = value
             self.input.selection = Selection.cursor(len(value))
 
-    def set_prompt(self, prompt: str) -> None:
-        """Set a new prompt value."""
-        self.value = prompt
-
     def clear(self) -> None:
-        self.input.clear()
-        self.text_area.load_text('')
+        with self.prevent(Input.Changed):
+            self.input.clear()
+        with self.prevent(TextArea.Changed):
+            self.text_area.load_text('')
 
     async def action_history(self, direction: int) -> None:
         if self.history_position == 0:
             self.history_strings.clear()
             async for prompt in self.history.load():
-                self.history_strings.append(prompt)
+                if prompt.strip():
+                    self.history_strings.append(prompt)
             self.history_strings.reverse()
         self.history_position = self.history_position + direction
 
     def action_submit(self) -> None:
         self.post_message(self.Submitted(self.text_area.text))
-        self.text_area.load_text('')
+        self.clear()
         self.action_escape()
+        self.history_position = 0
 
     def action_escape(self) -> None:
-        self.post_message(self.SetMultiline(False))
+        self.history_position = 0
+        self.multiline = False
 
     def check_action(self, action: str, parameters: tuple[object, ...]) -> bool | None:
         if action == 'history':
@@ -183,28 +179,20 @@ def check_action(self, action: str, parameters: tuple[object, ...]) -> bool | None:
         return True
 
     def validate_history_position(self, history_position: int) -> int:
-        if history_position > 0:
-            return 0
-        if -history_position > len(self.history_strings):
-            return -len(self.history_strings)
-        return history_position
+        return clamp(history_position, -len(self.history_strings), 0)
 
     async def watch_history_position(self, previous_position: int, position: int) -> None:
         if previous_position == 0:
-            self.edit_prompt = self.query_one(PromptInput).value
+            self.edit_prompt = self.value
         if position == 0:
-            self.set_prompt(self.edit_prompt)
+            self.value = self.edit_prompt
         elif position < 0:
-            prompt = self.history_strings[position]
-            self.set_prompt(prompt)
+            self.value = self.history_strings[position]
 
     @on(Input.Submitted)
     def on_input_submitted(self, event: Input.Submitted) -> None:
         self.post_message(self.Submitted(event.value))
         self.clear()
-
-    @on(Submitted)
-    async def on_submitted(self, event: Submitted) -> None:
         self.history_position = 0
 
 
@@ -220,7 +208,6 @@ class Conversation(containers.Vertical):
 
     contents = getters.query_one('#contents', containers.VerticalScroll)
     prompt = getters.query_one(Prompt)
-    multiline = reactive(False)
 
     def __init__(self, history: FileHistory, title: str) -> None:
         self.history = history
@@ -229,7 +216,7 @@ def __init__(self, history: FileHistory, title: str) -> None:
 
     def compose(self) -> ComposeResult:
         yield Contents(id='contents')
-        yield Prompt(self.history, id='prompt').data_bind(Conversation.multiline)
+        yield Prompt(self.history, id='prompt')
 
     def get_last_markdown_source(self) -> str | None:
         """Get the source of the last markdown response, or `None` if there is no markdown response."""
@@ -248,10 +235,6 @@ async def post(self, widget: Widget) -> None:
     async def post_prompt(self, prompt: str) -> None:
         await self.post(UserText(prompt))
 
-    @on(Prompt.SetMultiline)
-    def on_escape_multiline(self, event: Prompt.SetMultiline) -> None:
-        self.multiline = event.multiline
-
 
 class MainScreen(Screen[None]):
     """Main screen containing conversation."""
@@ -302,10 +285,10 @@ async def on_conversation_prompt(self, event: Prompt.Submitted) -> None:
         if not prompt:
             self.app.bell()
             return
+        self.history.append_string(prompt)
         if prompt.startswith('/'):
             await self.process_slash(prompt)
         else:
-            self.history.append_string(prompt)
             await self.conversation.post_prompt(prompt)
             await self.ask_agent(prompt)
 
@@ -318,7 +301,7 @@ async def process_slash(self, prompt: str) -> None:
             else:
                 await self.conversation.post(Static(markdown))
         elif prompt == '/multiline':
-            self.conversation.multiline = not self.conversation.multiline
+            self.conversation.prompt.multiline = not self.conversation.prompt.multiline
         elif prompt == '/exit':
             self.app.exit()
         else:

From 9a5cca600db9a1d1856aeacda20e04a92a78b5a9 Mon Sep 17 00:00:00 2001
From: Will McGugan
Date: Thu, 7 Aug 2025 13:41:48 +0100
Subject: [PATCH 4/7] prompt title

---
 clai/clai/main_screen.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/clai/clai/main_screen.py b/clai/clai/main_screen.py
index 6388ea398..22aa09d1f 100644
--- a/clai/clai/main_screen.py
+++ b/clai/clai/main_screen.py
@@ -65,7 +65,13 @@ class PromptInput(Input):
     ALLOW_MAXIMIZE = False
 
 
-class Prompt(containers.HorizontalGroup):
+class PromptTextArea(TextArea):
+    """A custom textarea."""
+
+    BINDING_GROUP_TITLE = 'Prompt'
+
+
+class Prompt(containers.HorizontalGroup, can_focus=False):
     """Takes input from the user."""
 
     BINDINGS = [
@@ -105,7 +111,7 @@ def compose(self) -> ComposeResult:
                 ]
             ),
         )
-        yield TextArea(
+        yield PromptTextArea(
             id='prompt-textarea',
             language='markdown',
             highlight_cursor_line=False,
@@ -116,10 +122,12 @@ def watch_multiline(self, multiline: bool) -> None:
             self.input.display = False
             self.text_area.display = True
             self.text_area.load_text(self.input.value)
+            self.text_area.focus()
         else:
             self.input.display = True
             self.text_area.display = False
             self.input.value = self.text_area.text.partition('\n')[0]
+            self.input.focus()
 
     @property
     def value(self) -> str:

From c144c7ab0019a4fb6fb45e46fbd5e8c73ac76471 Mon Sep 17 00:00:00 2001
From: Will McGugan
Date: Thu, 7 Aug 2025 13:49:36 +0100
Subject: [PATCH 5/7] bump textual

---
 clai/pyproject.toml | 2 +-
 uv.lock             | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/clai/pyproject.toml b/clai/pyproject.toml
index 5227818cd..a7b87f872 100644
--- a/clai/pyproject.toml
+++ b/clai/pyproject.toml
@@ -49,7 +49,7 @@ requires-python = ">=3.9"
 [tool.hatch.metadata.hooks.uv-dynamic-versioning]
 dependencies = [
     "pydantic-ai=={{ version }}",
-    "textual>=5.2.0"
+    "textual>=5.3.0"
 ]
 
 [tool.hatch.metadata]
diff --git a/uv.lock b/uv.lock
index 7317433eb..bd38514e4 100644
--- a/uv.lock
+++ b/uv.lock
@@ -710,7 +710,7 @@ dependencies = [
 [package.metadata]
 requires-dist = [
     { name = "pydantic-ai", editable = "." },
-    { name = "textual", specifier = ">=5.2.0" },
+    { name = "textual", specifier = ">=5.3.0" },
 ]
 
 [[package]]
@@ -4467,7 +4467,7 @@ wheels = [
 
 [[package]]
 name = "textual"
-version = "5.2.0"
+version = "5.3.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "markdown-it-py", extra = ["linkify", "plugins"] },
@@ -4476,9 +4476,9 @@ dependencies = [
     { name = "rich" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/41/55/2f7ede2f7a2aa852f062019ab96204dd311dda9103f0b73980e8aeb4e7cc/textual-5.2.0.tar.gz", hash = "sha256:6d4a77fddde7aa659acc25c3e14ed27ff75e125c6730c6f57263cf91b60e0d50", size = 1556452, upload-time = "2025-08-01T18:45:33.262Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/ba/ce/f0f938d33d9bebbf8629e0020be00c560ddfa90a23ebe727c2e5aa3f30cf/textual-5.3.0.tar.gz", hash = "sha256:1b6128b339adef2e298cc23ab4777180443240ece5c232f29b22960efd658d4d", size = 1557651, upload-time = "2025-08-07T12:36:50.342Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/41/10/81b3c0cac7d5c6b0b32f38b663fdae3773fb2843623c1e0ff7626dfef7c9/textual-5.2.0-py3-none-any.whl", hash = "sha256:5df1ef970a88ac9a266cbcace098f595f22706990f2506f6e545d9bb55983f1d", size = 701823, upload-time = "2025-08-01T18:45:31.17Z" },
+    { url = "https://files.pythonhosted.org/packages/00/2f/f7c8a533bee50fbf5bb37ffc1621e7b2cdd8c9a6301fc51faa35fa50b09d/textual-5.3.0-py3-none-any.whl", hash = "sha256:02a6abc065514c4e21f94e79aaecea1f78a28a85d11d7bfc64abf3392d399890", size = 702671, upload-time = "2025-08-07T12:36:48.272Z" },
 ]

From 0130c1a910158d051fff48e274b619804d1c4ced Mon Sep 17 00:00:00 2001
From: Will McGugan
Date: Thu, 7 Aug 2025 14:01:35 +0100
Subject: [PATCH 6/7] precommit changes

---
 clai/clai/clai.tcss      | 36 ++++++++++++++++++------------------
 clai/clai/main_screen.py |  4 ++--
 2 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/clai/clai/clai.tcss b/clai/clai/clai.tcss
index 64eeb4c43..8e204e7b6 100644
--- a/clai/clai/clai.tcss
+++ b/clai/clai/clai.tcss
@@ -1,22 +1,22 @@
-Conversation { 
-    height: auto; 
-    max-height: 1fr; 
-    
-    #contents { 
+Conversation {
+    height: auto;
+    max-height: 1fr;
+
+    #contents {
         layout: stream;
-        height: 1fr; 
+        height: 1fr;
     }
 
     #contents > * {
-        margin-bottom: 1; 
+        margin-bottom: 1;
     }
 
-    Prompt { 
-        height: auto; 
-        padding: 0 0 0 1; 
+    Prompt {
+        height: auto;
+        padding: 0 0 0 1;
         #prompt {
             padding-left: 0;
-            color: $text-primary; 
+            color: $text-primary;
             text-style: bold;
         }
         Input {
@@ -41,13 +41,13 @@ Conversation {
         }
         #message {
             color: $text-muted;
-            padding: 0 1; 
-        } 
+            padding: 0 1;
+        }
     }
 
-    Response { 
-        padding: 0 1 0 1; 
-        & > MarkdownBlock { 
+    Response {
+        padding: 0 1 0 1;
+        & > MarkdownBlock {
             padding: 0;
             &:last-child {
                 margin-bottom:0;
@@ -61,7 +61,7 @@ Conversation {
     }
 }
 
-Footer { 
+Footer {
     background: black 10%;
     .footer-key--key {
         color: $text;
@@ -74,4 +74,4 @@ Footer {
         color: $text-muted;
         background: $footer-description-background;
     }
-}
\ No newline at end of file
+}
diff --git a/clai/clai/main_screen.py b/clai/clai/main_screen.py
index 22aa09d1f..68ca92c5c 100644
--- a/clai/clai/main_screen.py
+++ b/clai/clai/main_screen.py
@@ -25,9 +25,9 @@
 
 HELP = Template("""\
 ## $title
- 
+
 - **Powered by Pydantic AI**
- 
+
   The Python agent framework designed to make it less painful to build production grade applications with Generative AI.
 
 | Command | Purpose |

From 07bb7fd5fc689524143f8190f67c9fa8ce57757f Mon Sep 17 00:00:00 2001
From: Marcelo Trylesinski
Date: Mon, 11 Aug 2025 09:30:51 +0200
Subject: [PATCH 7/7] fix linter

---
 clai/clai/_cli.py        | 16 ++++------------
 clai/clai/main_screen.py | 10 ++--------
 2 files changed, 6 insertions(+), 20 deletions(-)

diff --git a/clai/clai/_cli.py b/clai/clai/_cli.py
index 96a323d27..84cbe89e8 100644
--- a/clai/clai/_cli.py
+++ b/clai/clai/_cli.py
@@ -58,7 +58,7 @@ def cli_exit(prog_name: str = 'pai'):  # pragma: no cover
     sys.exit(cli(prog_name=prog_name))
 
 
-def cli(  # noqa: C901
+def cli(
     args_list: Sequence[str] | None = None,
     *,
     prog_name: str = 'pai',
@@ -77,9 +77,7 @@ def cli(
 """,
         formatter_class=argparse.RawTextHelpFormatter,
     )
-    parser.add_argument(
-        'prompt', nargs='?', help='AI Prompt, if omitted fall into interactive mode'
-    )
+    parser.add_argument('prompt', nargs='?', help='AI Prompt, if omitted fall into interactive mode')
     arg = parser.add_argument(
         '-m',
         '--model',
@@ -146,11 +144,7 @@ def cli(
             console.print(f'Error initializing [magenta]{args.model}[/magenta]:\n[red]{e}[/red]')
             return 1
 
-    model_name = (
-        agent.model
-        if isinstance(agent.model, str)
-        else f'{agent.model.system}:{agent.model.model_name}'
-    )
+    model_name = agent.model if isinstance(agent.model, str) else f'{agent.model.system}:{agent.model.model_name}'
     title = name_version = f'{prog_name} - Pydantic AI CLI v{__version__}'
     if args.agent and model_arg_set:
         title = f'{name_version} using custom agent **{args.agent}** with `{model_name}`'
@@ -165,8 +159,6 @@ def cli(
 
     from clai.tui import CLAIApp
 
-    app = CLAIApp(
-        agent, PYDANTIC_AI_HOME / PROMPT_HISTORY_FILENAME, prompt=args.prompt, title=title
-    )
+    app = CLAIApp(agent, PYDANTIC_AI_HOME / PROMPT_HISTORY_FILENAME, prompt=args.prompt, title=title)
     app.run()
     return 0
diff --git a/clai/clai/main_screen.py b/clai/clai/main_screen.py
index 68ca92c5c..f1afef211 100644
--- a/clai/clai/main_screen.py
+++ b/clai/clai/main_screen.py
@@ -176,11 +176,7 @@ def check_action(self, action: str, parameters: tuple[object, ...]) -> bool | None:
         if action == 'history':
             if parameters[0] == +1 and self.history_position == 0:
                 return None
-            if (
-                parameters[0] == -1
-                and self.history_strings
-                and self.history_position == -len(self.history_strings)
-            ):
+            if parameters[0] == -1 and self.history_strings and self.history_position == -len(self.history_strings):
                 return None
         if action in ('submit', 'escape'):
             return self.multiline
@@ -337,9 +333,7 @@ async def run_response_queue(self) -> None:
                 async for node in agent_run:
                     if Agent.is_model_request_node(node):
                         async with node.stream(agent_run.ctx) as handle_stream:
-                            async for fragment in handle_stream.stream_text(
-                                delta=True, debounce_by=None
-                            ):
+                            async for fragment in handle_stream.stream_text(delta=True, debounce_by=None):
                                 await markdown_stream.write(fragment)
                 response.display = True
                 self.messages[:] = agent_run.result.all_messages()