Skip to content

Commit bb171f7

Browse files
committed
Yield a MessageStream() instance with a .append() and .restore() method
1 parent a05b447 commit bb171f7

File tree

2 files changed

+56
-23
lines changed
  • shiny/ui
  • tests/playwright/shiny/components/chat/message-stream

2 files changed

+56
-23
lines changed

shiny/ui/_chat.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -590,12 +590,7 @@ async def append_message(
590590
icon=icon,
591591
)
592592

593-
async def append_message_chunk(
594-
self,
595-
message_chunk: Any,
596-
*,
597-
operation: Literal["append", "replace"] = "append",
598-
):
593+
async def append_message_chunk(self, message_chunk: Any):
599594
"""
600595
Append a message chunk to the current message stream.
601596
@@ -606,8 +601,6 @@ async def append_message_chunk(
606601
----------
607602
message_chunk
608603
A message chunk to inject.
609-
operation
610-
Whether to append or replace the *current* message stream content.
611604
612605
Note
613606
----
@@ -633,7 +626,6 @@ async def append_message_chunk(
633626
return await self._append_message_chunk(
634627
message_chunk,
635628
stream_id=stream_id,
636-
operation=operation,
637629
)
638630

639631
@asynccontextmanager
@@ -644,6 +636,12 @@ async def message_stream(self):
644636
A context manager for streaming messages into the chat. Note this stream
645637
can occur within a longer running `.append_message_stream()` or used on its own.
646638
639+
Yields
640+
------
641+
:
642+
A `MessageStream` instance with a method for `.append()`ing message chunks
643+
and a method for `.restore()`ing the stream back to it's initial state.
644+
647645
Note
648646
----
649647
A useful pattern for displaying tool calls in a chat interface is for the
@@ -658,16 +656,14 @@ async def message_stream(self):
658656
self._message_stream_checkpoint = self._current_stream_message
659657

660658
# No stream currently exists, start one
661-
is_root_stream = not self._current_stream_id
659+
stream_id = self._current_stream_id
660+
is_root_stream = stream_id is None
662661
if is_root_stream:
663-
await self._append_message_chunk(
664-
"",
665-
chunk="start",
666-
stream_id=_utils.private_random_id(),
667-
)
662+
stream_id = _utils.private_random_id()
663+
await self._append_message_chunk("", chunk="start", stream_id=stream_id)
668664

669665
try:
670-
yield
666+
yield MessageStream(self, stream_id)
671667
finally:
672668
# Restore the previous stream state
673669
self._message_stream_checkpoint = old_checkpoint
@@ -677,7 +673,7 @@ async def message_stream(self):
677673
await self._append_message_chunk(
678674
"",
679675
chunk="end",
680-
stream_id=cast(str, self._current_stream_id),
676+
stream_id=stream_id,
681677
)
682678

683679
async def _append_message_chunk(
@@ -1496,4 +1492,36 @@ def chat_ui(
14961492
return res
14971493

14981494

1495+
class MessageStream:
1496+
""""""
1497+
1498+
def __init__(self, chat: Chat, stream_id: str):
1499+
self._chat = chat
1500+
self._stream_id = stream_id
1501+
1502+
async def restore(self):
1503+
"""
1504+
Restore the stream back to its initial state.
1505+
"""
1506+
await self._chat._append_message_chunk(
1507+
"",
1508+
operation="replace",
1509+
stream_id=self._stream_id,
1510+
)
1511+
1512+
async def append(self, message_chunk: Any):
1513+
"""
1514+
Append a message chunk to the stream.
1515+
1516+
Parameters
1517+
-----------
1518+
message_chunk
1519+
A message chunk to append to this stream
1520+
"""
1521+
await self._chat._append_message_chunk(
1522+
message_chunk,
1523+
stream_id=self._stream_id,
1524+
)
1525+
1526+
14991527
CHAT_INSTANCES: WeakValueDictionary[str, Chat] = WeakValueDictionary()

tests/playwright/shiny/components/chat/message-stream/app.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from shiny import reactive
44
from shiny.express import input, render, ui
55

6-
SLEEP_TIME = 0.75
6+
SLEEP_TIME = 0.25
77

88
ui.page_opts(title="Hello chat message streams")
99

@@ -37,11 +37,12 @@ async def _():
3737
" Progress: 50%",
3838
" Progress: 100%",
3939
]
40-
async with chat.message_stream():
40+
async with chat.message_stream() as stream:
4141
for chunk in chunks:
42-
await chat.append_message_chunk(chunk)
42+
await stream.append(chunk)
4343
await asyncio.sleep(SLEEP_TIME)
44-
await chat.append_message_chunk("Completed stream 1 ✅", operation="replace")
44+
await stream.restore()
45+
await stream.append("Completed stream 1 ✅")
4546

4647

4748
# TODO: add test here for nested .message_stream()
@@ -68,8 +69,12 @@ async def mock_tool():
6869
" Progress: 50%",
6970
" Progress: 100%",
7071
]
71-
for chunk in chunks:
72-
await chat.append_message_chunk(chunk, operation="replace")
72+
async with chat.message_stream() as stream:
73+
for chunk in chunks:
74+
await stream.append(chunk)
75+
await asyncio.sleep(SLEEP_TIME)
76+
await stream.restore()
77+
await stream.append("Completed inner stream ✅")
7378

7479

7580
# TODO: more tests, like submitting input, etc.

0 commit comments

Comments
 (0)