diff --git a/CHANGELOG.md b/CHANGELOG.md index 55638d52b..24d64d85a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Both `ui.Chat()` and `ui.MarkdownStream()` now support arbirary Shiny UI elements inside of messages. This allows for gathering input from the user (e.g., `ui.input_select()`), displaying of rich output (e.g., `render.DataGrid()`), and more. (#1868) +* Added a new `.message_stream_context()` method to `ui.Chat()`. This context manager is a useful alternative to `.append_message_stream()` when you want to: (1) Nest a stream within another and/or +(2) Overwrite/replace streaming content. (#1906) + ### Changes * Express mode's `app_opts()` requires all arguments to be keyword-only. If you are using positional arguments, you will need to update your code. (#1895) diff --git a/shiny/ui/_chat.py b/shiny/ui/_chat.py index c9cc6f960..b4024f4f3 100644 --- a/shiny/ui/_chat.py +++ b/shiny/ui/_chat.py @@ -1,6 +1,7 @@ from __future__ import annotations import inspect +from contextlib import asynccontextmanager from typing import ( Any, AsyncIterable, @@ -81,7 +82,12 @@ ChunkOption = Literal["start", "end", True, False] -PendingMessage = Tuple[Any, ChunkOption, Union[str, None]] +PendingMessage = Tuple[ + Any, + ChunkOption, + Literal["append", "replace"], + Union[str, None], +] @add_example(ex_dir="../templates/chat/starters/hello") @@ -197,10 +203,13 @@ def __init__( self.on_error = on_error # Chunked messages get accumulated (using this property) before changing state - self._current_stream_message = "" + self._current_stream_message: str = "" self._current_stream_id: str | None = None self._pending_messages: list[PendingMessage] = [] + # For tracking message stream state when entering/exiting nested streams + self._message_stream_checkpoint: str = "" + # If a user input message is transformed into a response, we need to cancel # the next user input submit handling 
self._suspend_input_handler: bool = False @@ -251,10 +260,7 @@ async def _on_user_input(): else: # A transformed value of None is a special signal to suspend input # handling (i.e., don't generate a response) - self._store_message( - TransformedMessage.from_chat_message(msg), - index=n_pre, - ) + self._store_message(msg, index=n_pre) await self._remove_loading_message() self._suspend_input_handler = True @@ -573,49 +579,165 @@ async def append_message( similar) is specified in model's completion method. ::: """ - await self._append_message(message, icon=icon) + # If we're in a stream, queue the message + if self._current_stream_id: + self._pending_messages.append((message, False, "append", None)) + return + + msg = normalize_message(message) + msg = await self._transform_message(msg) + if msg is None: + return + self._store_message(msg) + await self._send_append_message( + message=msg, + chunk=False, + icon=icon, + ) + + @asynccontextmanager + async def message_stream_context(self): + """ + Message stream context manager. + + A context manager for appending streaming messages into the chat. This context + manager can: + + 1. Be used in isolation to append a new streaming message to the chat. + * Compared to `.append_message_stream()` this method is more flexible but + isn't non-blocking by default (i.e., it doesn't launch an extended task). + 2. Be nested within itself + * Nesting is primarily useful for making checkpoints to `.replace()` back + to (see the example below). + 3. Be used from within a `.append_message_stream()` + * Useful for inserting additional content from another context into the + stream (e.g., see the note about tool calls below). + + Yields + ------ + : + A `MessageStream` class instance, which has a method for `.append()`ing + message content chunks to, as well as a way to `.replace()` everything streamed + since its initial state. Note that `.append()` supports the same message content + types as `.append_message()`. 
+ + Example + ------- + ```python + import asyncio + + from shiny import reactive + from shiny.express import ui + + chat = ui.Chat(id="my_chat") + chat.ui() + + @reactive.effect + async def _(): + async with chat.message_stream_context() as msg: + await msg.append("Starting stream...\n\nProgress:") + async with chat.message_stream_context() as progress: + for x in [0, 50, 100]: + await progress.append(f" {x}%") + await asyncio.sleep(1) + await progress.replace("") + await msg.replace("") + await msg.append("Completed stream") + ``` - async def _append_message( + Note + ---- + A useful pattern for displaying tool calls in a chatbot is for the tool to + display using `.message_stream_context()` while the response generation is + happening through `.append_message_stream()`. This allows the tool to display + things like progress updates (or other "ephemeral" content) and optionally + `.replace()` the stream's content back to its initial state when ready to display the + "final" content. + """ + # Checkpoint the current stream state so operation="replace" can return to it + old_checkpoint = self._message_stream_checkpoint + self._message_stream_checkpoint = self._current_stream_message + + # No stream currently exists, start one + stream_id = self._current_stream_id + is_root_stream = stream_id is None + if is_root_stream: + stream_id = _utils.private_random_id() + await self._append_message_chunk("", chunk="start", stream_id=stream_id) + + try: + yield MessageStream(self, stream_id) + finally: + # Restore the checkpoint + self._message_stream_checkpoint = old_checkpoint + + # If this was the root stream, end it + if is_root_stream: + await self._append_message_chunk( + "", + chunk="end", + stream_id=stream_id, + ) + + async def _append_message_chunk( self, message: Any, *, - chunk: ChunkOption = False, - stream_id: str | None = None, + chunk: Literal[True, "start", "end"] = True, + stream_id: str, + operation: Literal["append", "replace"] = "append", icon: HTML | Tag | TagList | 
None = None, ) -> None: - # If currently we're in a stream, handle other messages (outside the stream) later - if not self._can_append_message(stream_id): - self._pending_messages.append((message, chunk, stream_id)) + # If currently we're in a *different* stream, queue the message chunk + if self._current_stream_id and self._current_stream_id != stream_id: + self._pending_messages.append((message, chunk, operation, stream_id)) return - # Update current stream state self._current_stream_id = stream_id - if chunk == "end": - self._current_stream_id = None - if chunk is False: - msg = normalize_message(message) - chunk_content = None - else: - msg = normalize_message_chunk(message) - # Update the current stream message - chunk_content = msg.content - self._current_stream_message += chunk_content + # Normalize various message types into a ChatMessage() + msg = normalize_message_chunk(message) + + if operation == "replace": + self._current_stream_message = self._message_stream_checkpoint + msg.content msg.content = self._current_stream_message + else: + self._current_stream_message += msg.content + + try: + if self._needs_transform(msg): + # Transforming may change the meaning of msg.content to be a *replace* + # not *append*. So, update msg.content and the operation accordingly. 
+ chunk_content = msg.content + msg.content = self._current_stream_message + operation = "replace" + msg = await self._transform_message( + msg, chunk=chunk, chunk_content=chunk_content + ) + # Act like nothing happened if transformed to None + if msg is None: + return + if chunk == "end": + self._store_message(msg) + elif chunk == "end": + # When `operation="append"`, msg.content is just a chunk, but we must + # store the full message + self._store_message( + ChatMessage(content=self._current_stream_message, role=msg.role) + ) + + # Send the message to the client + await self._send_append_message( + message=msg, + chunk=chunk, + operation=operation, + icon=icon, + ) + finally: if chunk == "end": + self._current_stream_id = None self._current_stream_message = "" - - msg = await self._transform_message( - msg, chunk=chunk, chunk_content=chunk_content - ) - if msg is None: - return - self._store_message(msg, chunk=chunk) - await self._send_append_message( - msg, - chunk=chunk, - icon=icon, - ) + self._message_stream_checkpoint = "" async def append_message_stream( self, @@ -714,8 +836,8 @@ def latest_message_stream(self) -> reactive.ExtendedTask[[], str]: """ React to changes in the latest message stream. - Reactively reads for the :class:`~shiny.reactive.ExtendedTask` behind the - latest message stream. + Reactively reads for the :class:`~shiny.reactive.ExtendedTask` behind an + `.append_message_stream()`. 
From the return value (i.e., the extended task), you can then: @@ -746,37 +868,41 @@ async def _append_message_stream( id = _utils.private_random_id() empty = ChatMessageDict(content="", role="assistant") - await self._append_message(empty, chunk="start", stream_id=id, icon=icon) + await self._append_message_chunk(empty, chunk="start", stream_id=id, icon=icon) try: async for msg in message: - await self._append_message(msg, chunk=True, stream_id=id) + await self._append_message_chunk(msg, chunk=True, stream_id=id) return self._current_stream_message finally: - await self._append_message(empty, chunk="end", stream_id=id) + await self._append_message_chunk(empty, chunk="end", stream_id=id) await self._flush_pending_messages() async def _flush_pending_messages(self): - still_pending: list[PendingMessage] = [] - for msg, chunk, stream_id in self._pending_messages: - if self._can_append_message(stream_id): - await self._append_message(msg, chunk=chunk, stream_id=stream_id) + pending = self._pending_messages + self._pending_messages = [] + for msg, chunk, operation, stream_id in pending: + if chunk is False: + await self.append_message(msg) else: - still_pending.append((msg, chunk, stream_id)) - self._pending_messages = still_pending - - def _can_append_message(self, stream_id: str | None) -> bool: - if self._current_stream_id is None: - return True - return self._current_stream_id == stream_id + await self._append_message_chunk( + msg, + chunk=chunk, + operation=operation, + stream_id=cast(str, stream_id), + ) # Send a message to the UI async def _send_append_message( self, - message: TransformedMessage, + message: TransformedMessage | ChatMessage, chunk: ChunkOption = False, + operation: Literal["append", "replace"] = "append", icon: HTML | Tag | TagList | None = None, ): + if not isinstance(message, TransformedMessage): + message = TransformedMessage.from_chat_message(message) + if message.role == "system": # System messages are not displayed in the UI return @@ 
-801,6 +927,7 @@ async def _send_append_message( role=message.role, content_type=content_type, chunk_type=chunk_type, + operation=operation, ) if icon is not None: @@ -936,18 +1063,16 @@ async def _transform_message( self, message: ChatMessage, chunk: ChunkOption = False, - chunk_content: str | None = None, + chunk_content: str = "", ) -> TransformedMessage | None: res = TransformedMessage.from_chat_message(message) - key = res.transform_key if message.role == "user" and self._transform_user is not None: content = await self._transform_user(message.content) - elif message.role == "assistant" and self._transform_assistant is not None: content = await self._transform_assistant( message.content, - chunk_content or "", + chunk_content, chunk == "end" or chunk is False, ) else: @@ -956,20 +1081,25 @@ async def _transform_message( if content is None: return None - setattr(res, key, content) - + setattr(res, res.transform_key, content) return res + def _needs_transform(self, message: ChatMessage) -> bool: + if message.role == "user" and self._transform_user is not None: + return True + elif message.role == "assistant" and self._transform_assistant is not None: + return True + return False + # Just before storing, handle chunk msg type and calculate tokens def _store_message( self, - message: TransformedMessage, - chunk: ChunkOption = False, + message: TransformedMessage | ChatMessage, index: int | None = None, ) -> None: - # Don't actually store chunks until the end - if chunk is True or chunk == "start": - return None + + if not isinstance(message, TransformedMessage): + message = TransformedMessage.from_chat_message(message) with reactive.isolate(): messages = self._messages() @@ -1368,4 +1498,43 @@ def chat_ui( return res +class MessageStream: + """ + An object to yield from a `.message_stream_context()` context manager. 
+ """ + + def __init__(self, chat: Chat, stream_id: str): + self._chat = chat + self._stream_id = stream_id + + async def replace(self, message_chunk: Any): + """ + Replace the content of the stream with new content. + + Parameters + ----------- + message_chunk + The new content to replace the current content. + """ + await self._chat._append_message_chunk( + message_chunk, + operation="replace", + stream_id=self._stream_id, + ) + + async def append(self, message_chunk: Any): + """ + Append a message chunk to the stream. + + Parameters + ----------- + message_chunk + A message chunk to append to this stream + """ + await self._chat._append_message_chunk( + message_chunk, + stream_id=self._stream_id, + ) + + CHAT_INSTANCES: WeakValueDictionary[str, Chat] = WeakValueDictionary() diff --git a/shiny/ui/_chat_types.py b/shiny/ui/_chat_types.py index 55bab4106..f5e7185f1 100644 --- a/shiny/ui/_chat_types.py +++ b/shiny/ui/_chat_types.py @@ -74,5 +74,6 @@ class ClientMessage(TypedDict): role: Literal["assistant", "user"] content_type: Literal["markdown", "html"] chunk_type: Literal["message_start", "message_end"] | None + operation: Literal["append", "replace"] icon: NotRequired[str] html_deps: NotRequired[list[dict[str, str]]] diff --git a/tests/playwright/shiny/components/chat/message_stream_context/app.py b/tests/playwright/shiny/components/chat/message_stream_context/app.py new file mode 100644 index 000000000..e8e9e7569 --- /dev/null +++ b/tests/playwright/shiny/components/chat/message_stream_context/app.py @@ -0,0 +1,116 @@ +import asyncio + +from shiny import reactive +from shiny.express import input, render, ui + +SLEEP_TIME = 0.25 + +ui.page_opts(title="Hello chat message streams") + +with ui.sidebar(style="height:100%"): + ui.input_action_button("stream_1", "Stream 1") + ui.input_action_button("stream_2", "Stream 2") + ui.input_action_button("stream_3", "Stream 3") + ui.input_action_button("stream_4", "Stream 4") + ui.input_action_button("stream_5", "Stream 5") 
+ ui.input_action_button("stream_6", "Stream 6") + + ui.h6("Message state:", class_="mt-auto mb-0") + + @render.code + def message_state(): + return str(chat.messages()) + + +chat = ui.Chat(id="chat") +chat.ui() + + +@chat.on_user_submit +async def _(user_input: str): + await chat.append_message(f"You said: {user_input}") + + +@reactive.effect +@reactive.event(input.stream_1) +async def _(): + async with chat.message_stream_context() as msg: + await msg.append("Basic") + await asyncio.sleep(SLEEP_TIME) + await msg.append(" stream") + + +@reactive.effect +@reactive.event(input.stream_2) +async def _(): + async with chat.message_stream_context() as msg: + await msg.append("Basic") + await asyncio.sleep(SLEEP_TIME) + await msg.append(" stream") + await asyncio.sleep(SLEEP_TIME) + await msg.replace("Finished") + + +@reactive.effect +@reactive.event(input.stream_3) +async def _(): + async with chat.message_stream_context() as outer: + await outer.append("Outer start") + await asyncio.sleep(SLEEP_TIME) + async with chat.message_stream_context() as inner: + await inner.append("Inner start") + await asyncio.sleep(SLEEP_TIME) + await inner.append("Inner end") + await asyncio.sleep(SLEEP_TIME) + await outer.append("Outer end") + + +@reactive.effect +@reactive.event(input.stream_4) +async def _(): + async with chat.message_stream_context() as outer: + await outer.append("Outer start") + await asyncio.sleep(SLEEP_TIME) + async with chat.message_stream_context() as inner: + await inner.append("Inner start") + await asyncio.sleep(SLEEP_TIME) + await inner.replace("Inner end") + await asyncio.sleep(SLEEP_TIME) + await outer.append("Outer end") + + +@reactive.effect +@reactive.event(input.stream_5) +async def _(): + async with chat.message_stream_context() as outer: + await outer.append("Outer start") + await asyncio.sleep(SLEEP_TIME) + await outer.replace("") + async with chat.message_stream_context() as inner: + await inner.append("Inner start") + await asyncio.sleep(SLEEP_TIME) 
+ await inner.append("Inner end") + await asyncio.sleep(SLEEP_TIME) + await outer.append("Outer end") + + +@reactive.effect +@reactive.event(input.stream_6) +async def _(): + await chat.append_message_stream(outer_stream()) + + +async def outer_stream(): + yield "Outer start" + await asyncio.sleep(SLEEP_TIME) + await inner_stream() + await asyncio.sleep(SLEEP_TIME) + yield "Outer end" + + +async def inner_stream(): + async with chat.message_stream_context() as stream: + await stream.append("Inner start") + await asyncio.sleep(SLEEP_TIME) + await stream.append("Inner progress") + await stream.replace("Inner end") diff --git a/tests/playwright/shiny/components/chat/message_stream_context/test_chat_message_stream_context.py b/tests/playwright/shiny/components/chat/message_stream_context/test_chat_message_stream_context.py new file mode 100644 index 000000000..12f8f6193 --- /dev/null +++ b/tests/playwright/shiny/components/chat/message_stream_context/test_chat_message_stream_context.py @@ -0,0 +1,94 @@ +from playwright.sync_api import Page, expect +from utils.deploy_utils import skip_on_webkit + +from shiny.playwright import controller +from shiny.run import ShinyAppProc + + +@skip_on_webkit +def test_validate_chat_message_stream_context( + page: Page, local_app: ShinyAppProc +) -> None: + page.goto(local_app.url) + + TIMEOUT = 30 * 1000 + + chat = controller.Chat(page, "chat") + expect(chat.loc).to_be_visible(timeout=TIMEOUT) + + stream_1 = controller.InputActionButton(page, "stream_1") + expect(stream_1.loc).to_be_visible(timeout=TIMEOUT) + stream_1.click() + + chat.expect_latest_message("Basic stream", timeout=TIMEOUT) + + stream_2 = controller.InputActionButton(page, "stream_2") + expect(stream_2.loc).to_be_visible(timeout=TIMEOUT) + stream_2.click() + + chat.expect_latest_message("Finished", timeout=TIMEOUT) + + chat.set_user_input("Hello") + chat.send_user_input() + chat.expect_latest_message("You said: Hello", timeout=TIMEOUT) + + stream_3 = 
controller.InputActionButton(page, "stream_3") + expect(stream_3.loc).to_be_visible(timeout=TIMEOUT) + stream_3.click() + + chat.expect_latest_message( + "Outer startInner startInner endOuter end", + timeout=TIMEOUT, + ) + + stream_4 = controller.InputActionButton(page, "stream_4") + expect(stream_4.loc).to_be_visible(timeout=TIMEOUT) + stream_4.click() + + chat.expect_latest_message( + "Outer startInner endOuter end", + timeout=TIMEOUT, + ) + + stream_5 = controller.InputActionButton(page, "stream_5") + expect(stream_5.loc).to_be_visible(timeout=TIMEOUT) + stream_5.click() + + chat.expect_latest_message( + "Inner startInner endOuter end", + timeout=TIMEOUT, + ) + + stream_6 = controller.InputActionButton(page, "stream_6") + expect(stream_6.loc).to_be_visible(timeout=TIMEOUT) + stream_6.click() + + chat.expect_latest_message( + "Outer startInner endOuter end", + timeout=TIMEOUT, + ) + + chat.set_user_input("Goodbye") + chat.send_user_input() + chat.expect_latest_message("You said: Goodbye", timeout=TIMEOUT) + + # Test server-side message state + message_state = controller.OutputCode(page, "message_state") + message_state_expected = tuple( + [ + {"content": "Basic stream", "role": "assistant"}, + {"content": "Finished", "role": "assistant"}, + {"content": "Hello", "role": "user"}, + {"content": "You said: Hello", "role": "assistant"}, + { + "content": "Outer startInner startInner endOuter end", + "role": "assistant", + }, + {"content": "Outer startInner endOuter end", "role": "assistant"}, + {"content": "Inner startInner endOuter end", "role": "assistant"}, + {"content": "Outer startInner endOuter end", "role": "assistant"}, + {"content": "Goodbye", "role": "user"}, + {"content": "You said: Goodbye", "role": "assistant"}, + ] + ) + message_state.expect_value(str(message_state_expected))