Merged

Commits (showing changes from 19 of 26 commits)
916d2c0  feat(Chat): Add .start_message_stream(), .end_message_stream(), and .… (cpsievert, Mar 12, 2025)
4c32a0d  Go back to a more minimal change (cpsievert, Mar 12, 2025)
81a99d2  Cleanup (cpsievert, Mar 12, 2025)
08c104a  Polish API and clarify behavior (cpsievert, Mar 12, 2025)
edc30f0  Fix test (cpsievert, Mar 12, 2025)
28f148f  Merge branch 'main' into chat-inject-stream (cpsievert, Mar 12, 2025)
648d9cc  wip first pass at properly nested streams (cpsievert, Mar 13, 2025)
3786b79  Support nested streams and simplify logic (cpsievert, Mar 13, 2025)
260f902  .append_message() should also queue when a stream is active (cpsievert, Mar 14, 2025)
bd6d40f  Fix/simplify transform logic (cpsievert, Mar 14, 2025)
0c2e9ec  Reduce diff (cpsievert, Mar 14, 2025)
0cb99aa  Merge branch 'main' into chat-inject-stream (cpsievert, Mar 14, 2025)
a05b447  Update test (cpsievert, Mar 14, 2025)
bb171f7  Yield a MessageStream() instance with a .append() and .restore() method (cpsievert, Mar 14, 2025)
5e822fc  Cut public .append_message_chunk() method; rename context manager method (cpsievert, Mar 17, 2025)
24f05e5  Improve docstring (cpsievert, Mar 17, 2025)
4995ade  More complete playwright test (cpsievert, Mar 17, 2025)
3c0ca35  Merge branch 'main' into chat-inject-stream (cpsievert, Mar 17, 2025)
7ac2ba1  Rename append_message_context() -> message_stream_context() (cpsievert, Mar 18, 2025)
f8ff425  Rename .restore() -> .clear() (cpsievert, Mar 19, 2025)
f46aaf8  Merge branch 'main' into chat-inject-stream (cpsievert, Mar 19, 2025)
0ffdf21  Update changelog (cpsievert, Mar 19, 2025)
35837bc  Merge branch 'main' into chat-inject-stream (cpsievert, Mar 19, 2025)
c9fad38  Drop .clear() method in favor of replace (cpsievert, Mar 20, 2025)
bd5e610  Rename test files (cpsievert, Mar 20, 2025)
2a6d934  Include the operation when queueing message chunks (cpsievert, Mar 20, 2025)
280 changes: 218 additions & 62 deletions shiny/ui/_chat.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import inspect
from contextlib import asynccontextmanager
from typing import (
Any,
AsyncIterable,
@@ -197,10 +198,13 @@ def __init__(
self.on_error = on_error

# Chunked messages get accumulated (using this property) before changing state
self._current_stream_message = ""
self._current_stream_message: str = ""
self._current_stream_id: str | None = None
self._pending_messages: list[PendingMessage] = []

# For tracking message stream state when entering/exiting nested streams
self._message_stream_checkpoint: str = ""

# If a user input message is transformed into a response, we need to cancel
# the next user input submit handling
self._suspend_input_handler: bool = False
@@ -251,10 +255,7 @@ async def _on_user_input():
else:
# A transformed value of None is a special signal to suspend input
# handling (i.e., don't generate a response)
self._store_message(
TransformedMessage.from_chat_message(msg),
index=n_pre,
)
self._store_message(msg, index=n_pre)
await self._remove_loading_message()
self._suspend_input_handler = True

@@ -573,49 +574,165 @@ async def append_message(
similar) is specified in the model's completion method.
:::
"""
await self._append_message(message, icon=icon)
# If we're in a stream, queue the message
if self._current_stream_id:
self._pending_messages.append((message, False, None))
return

msg = normalize_message(message)
msg = await self._transform_message(msg)
if msg is None:
return
self._store_message(msg)
await self._send_append_message(
message=msg,
chunk=False,
icon=icon,
)

async def _append_message(
@asynccontextmanager
async def message_stream_context(self):
"""
Message stream context manager.

A context manager for appending streaming messages into the chat. This context
manager can:

1. Be used in isolation to append a new streaming message to the chat.
* Compared to `.append_message_stream()`, this method is more flexible, but
it blocks by default (i.e., it doesn't launch an extended task).
2. Be nested within itself.
* Nesting is primarily useful for creating checkpoints that `.restore()` can
return to (see the example below).
3. Be used from within an `.append_message_stream()`.
* This is useful for inserting additional content from another context into
the stream (e.g., see the note about tool calls below).

Yields
------
:
A `MessageStream` instance, which has an `.append()` method for adding
message content chunks and a `.restore()` method for returning the stream to
its initial state. Note that `.append()` supports the same message content
types as `.append_message()`.

Example
-------
```python
import asyncio

from shiny import reactive
from shiny.express import ui

chat = ui.Chat(id="my_chat")
chat.ui()

@reactive.effect
async def _():
async with chat.message_stream_context() as msg:
await msg.append("Starting stream...\n\nProgress:")
async with chat.message_stream_context() as progress:
for x in [0, 50, 100]:
await progress.append(f" {x}%")
await asyncio.sleep(1)
await progress.restore()
await msg.restore()
Review comment (Collaborator): If you're not going to use the term "checkpoint" then I think this should be msg.clear(). What do you think?

Reply (Collaborator, Author): Yea, that is better, thanks 👍

await msg.append("Completed stream")
```

Note
----
A useful pattern for displaying tool calls in a chatbot is for the tool to
display content via `.message_stream_context()` while the response is being
generated through `.append_message_stream()`. This lets the tool display
things like progress updates (or other "ephemeral" content) and optionally
`.restore()` the stream back to its initial state when it's ready to display
the "final" content, as sketched below.
"""
# Checkpoint the current stream state so operation="replace" can return to it
old_checkpoint = self._message_stream_checkpoint
self._message_stream_checkpoint = self._current_stream_message

# No stream currently exists, start one
stream_id = self._current_stream_id
is_root_stream = stream_id is None
if is_root_stream:
stream_id = _utils.private_random_id()
await self._append_message_chunk("", chunk="start", stream_id=stream_id)

try:
yield MessageStream(self, stream_id)
finally:
# Restore the checkpoint
self._message_stream_checkpoint = old_checkpoint

# If this was the root stream, end it
if is_root_stream:
await self._append_message_chunk(
"",
chunk="end",
stream_id=stream_id,
)

async def _append_message_chunk(
self,
message: Any,
*,
chunk: ChunkOption = False,
stream_id: str | None = None,
chunk: Literal[True, "start", "end"] = True,
stream_id: str,
operation: Literal["append", "replace"] = "append",
icon: HTML | Tag | TagList | None = None,
) -> None:
# If currently we're in a stream, handle other messages (outside the stream) later
if not self._can_append_message(stream_id):
# If we're currently in a *different* stream, queue the message chunk
if self._current_stream_id and self._current_stream_id != stream_id:
self._pending_messages.append((message, chunk, stream_id))
return

# Update current stream state
self._current_stream_id = stream_id
if chunk == "end":
self._current_stream_id = None

if chunk is False:
msg = normalize_message(message)
chunk_content = None
else:
msg = normalize_message_chunk(message)
# Update the current stream message
chunk_content = msg.content
self._current_stream_message += chunk_content
# Normalize various message types into a ChatMessage()
msg = normalize_message_chunk(message)

if operation == "replace":
self._current_stream_message = self._message_stream_checkpoint + msg.content
msg.content = self._current_stream_message
else:
self._current_stream_message += msg.content

try:
if self._needs_transform(msg):
# Transforming may change the meaning of msg.content from an *append* to
# a *replace*. So, update msg.content and the operation accordingly.
chunk_content = msg.content
msg.content = self._current_stream_message
operation = "replace"
msg = await self._transform_message(
msg, chunk=chunk, chunk_content=chunk_content
)
# Act like nothing happened if transformed to None
if msg is None:
return
if chunk == "end":
self._store_message(msg)
elif chunk == "end":
# When `operation="append"`, msg.content is just a chunk, but we must
# store the full message
self._store_message(
ChatMessage(content=self._current_stream_message, role=msg.role)
)

# Send the message to the client
await self._send_append_message(
message=msg,
chunk=chunk,
operation=operation,
icon=icon,
)
finally:
if chunk == "end":
self._current_stream_id = None
self._current_stream_message = ""

msg = await self._transform_message(
msg, chunk=chunk, chunk_content=chunk_content
)
if msg is None:
return
self._store_message(msg, chunk=chunk)
await self._send_append_message(
msg,
chunk=chunk,
icon=icon,
)
self._message_stream_checkpoint = ""

async def append_message_stream(
self,
@@ -714,8 +831,8 @@ def latest_message_stream(self) -> reactive.ExtendedTask[[], str]:
"""
React to changes in the latest message stream.

Reactively reads for the :class:`~shiny.reactive.ExtendedTask` behind the
latest message stream.
Reactively reads for the :class:`~shiny.reactive.ExtendedTask` behind an
`.append_message_stream()`.

From the return value (i.e., the extended task), you can then:

@@ -746,37 +863,38 @@ async def _append_message_stream(
id = _utils.private_random_id()

empty = ChatMessageDict(content="", role="assistant")
await self._append_message(empty, chunk="start", stream_id=id, icon=icon)
await self._append_message_chunk(empty, chunk="start", stream_id=id, icon=icon)

try:
async for msg in message:
await self._append_message(msg, chunk=True, stream_id=id)
await self._append_message_chunk(msg, chunk=True, stream_id=id)
return self._current_stream_message
finally:
await self._append_message(empty, chunk="end", stream_id=id)
await self._append_message_chunk(empty, chunk="end", stream_id=id)
await self._flush_pending_messages()

async def _flush_pending_messages(self):
still_pending: list[PendingMessage] = []
for msg, chunk, stream_id in self._pending_messages:
if self._can_append_message(stream_id):
await self._append_message(msg, chunk=chunk, stream_id=stream_id)
pending = self._pending_messages
self._pending_messages = []
for msg, chunk, stream_id in pending:
if chunk is False:
await self.append_message(msg)
else:
still_pending.append((msg, chunk, stream_id))
self._pending_messages = still_pending

def _can_append_message(self, stream_id: str | None) -> bool:
if self._current_stream_id is None:
return True
return self._current_stream_id == stream_id
await self._append_message_chunk(
msg, chunk=chunk, stream_id=cast(str, stream_id)
)

# Send a message to the UI
async def _send_append_message(
self,
message: TransformedMessage,
message: TransformedMessage | ChatMessage,
chunk: ChunkOption = False,
operation: Literal["append", "replace"] = "append",
icon: HTML | Tag | TagList | None = None,
):
if not isinstance(message, TransformedMessage):
message = TransformedMessage.from_chat_message(message)

if message.role == "system":
# System messages are not displayed in the UI
return
@@ -801,6 +919,7 @@ async def _send_append_message(
role=message.role,
content_type=content_type,
chunk_type=chunk_type,
operation=operation,
)

if icon is not None:
@@ -936,18 +1055,16 @@ async def _transform_message(
self,
message: ChatMessage,
chunk: ChunkOption = False,
chunk_content: str | None = None,
chunk_content: str = "",
) -> TransformedMessage | None:
res = TransformedMessage.from_chat_message(message)
key = res.transform_key

if message.role == "user" and self._transform_user is not None:
content = await self._transform_user(message.content)

elif message.role == "assistant" and self._transform_assistant is not None:
content = await self._transform_assistant(
message.content,
chunk_content or "",
chunk_content,
chunk == "end" or chunk is False,
)
else:
@@ -956,20 +1073,25 @@
if content is None:
return None

setattr(res, key, content)

setattr(res, res.transform_key, content)
return res

def _needs_transform(self, message: ChatMessage) -> bool:
if message.role == "user" and self._transform_user is not None:
return True
elif message.role == "assistant" and self._transform_assistant is not None:
return True
return False

# Just before storing, handle chunk msg type and calculate tokens
def _store_message(
self,
message: TransformedMessage,
chunk: ChunkOption = False,
message: TransformedMessage | ChatMessage,
index: int | None = None,
) -> None:
# Don't actually store chunks until the end
if chunk is True or chunk == "start":
return None

if not isinstance(message, TransformedMessage):
message = TransformedMessage.from_chat_message(message)

with reactive.isolate():
messages = self._messages()
@@ -1368,4 +1490,38 @@ def chat_ui(
return res


class MessageStream:
"""
An object yielded by the `.message_stream_context()` context manager.
"""

def __init__(self, chat: Chat, stream_id: str):
self._chat = chat
self._stream_id = stream_id

async def restore(self):
"""
Restore the stream back to its initial state.
"""
await self._chat._append_message_chunk(
"",
operation="replace",
stream_id=self._stream_id,
)

async def append(self, message_chunk: Any):
"""
Append a message chunk to the stream.

Parameters
----------
message_chunk
A message chunk to append to this stream.
"""
await self._chat._append_message_chunk(
message_chunk,
stream_id=self._stream_id,
)


CHAT_INSTANCES: WeakValueDictionary[str, Chat] = WeakValueDictionary()
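
For reference, a minimal Shiny Express app exercising the new context manager.
This is a sketch against the API added in this PR; the message text, timings,
and the `demo_chat` id are illustrative assumptions:

```python
import asyncio

from shiny import reactive
from shiny.express import ui

chat = ui.Chat(id="demo_chat")
chat.ui()

@reactive.effect
async def _():
    # Stream ephemeral progress, then replace it with the final content
    async with chat.message_stream_context() as msg:
        await msg.append("Crunching numbers")
        for _ in range(3):
            await asyncio.sleep(0.5)
            await msg.append(".")  # progress dots shown to the user
        await msg.restore()  # reset the stream to its (empty) initial state
        await msg.append("The answer is 42.")
```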