Skip to content

Commit 648d9cc

Browse files
committed
wip first pass at properly nested streams
1 parent 28f148f commit 648d9cc

File tree

2 files changed

+143
-65
lines changed

2 files changed

+143
-65
lines changed

shiny/ui/_chat.py

Lines changed: 106 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import inspect
44
import warnings
5+
from contextlib import asynccontextmanager
56
from typing import (
67
Any,
78
AsyncIterable,
@@ -202,7 +203,11 @@ def __init__(
202203
self._current_stream_id: str | None = None
203204
self._pending_messages: list[PendingMessage] = []
204205

206+
# Identifier for a manual stream (i.e., one started with `.start_message_stream()`)
205207
self._manual_stream_id: str | None = None
208+
# If a manual stream gets nested within another stream, we need to keep track of
209+
# the accumulated message separately
210+
self._nested_stream_message: str = ""
206211

207212
# If a user input message is transformed into a response, we need to cancel
208213
# the next user input submit handling
@@ -578,34 +583,40 @@ async def append_message_chunk(
578583
message_chunk: Any,
579584
*,
580585
operation: Literal["append", "replace"] = "append",
581-
force: bool = False,
582586
):
583587
"""
584588
Append a message chunk to the current message stream.
585589
586-
Append a chunk of message content to either the currently running
587-
`.append_message_stream()` or to one that was manually started with
588-
`.start_message_stream()`.
590+
Append a chunk of message content to either a stream started with
591+
`.message_stream()` or an active `.append_message_stream()`.
589592
590593
Parameters
591594
----------
592595
message_chunk
593596
A message chunk to inject.
594597
operation
595-
Whether to append or replace the current message stream content.
596-
force
597-
Whether to start a new stream if one is not currently active.
598+
Whether to append or replace the *current* message stream content.
599+
600+
Note
601+
----
602+
A useful pattern for displaying tool calls in a chat is for the tools to display
603+
content using an "inner" `.message_stream()` while the response generation is
604+
happening in an "outer" `.append_message_stream()`. This allows the inner stream
605+
to display "ephemeral" content, then eventually show a final state with
606+
`.append_message_chunk(operation="replace")`.
607+
608+
Raises
609+
------
610+
ValueError
611+
If there is active stream (i.e., no `.message_stream()` or
612+
`.append_message_stream()`)
598613
"""
599-
# Can append to either an active `.start_message_stream()` or a
600-
# # `.append_message_stream()`
601-
stream_id = self._manual_stream_id or self._current_stream_id
614+
stream_id = self._current_stream_id
602615
if stream_id is None:
603-
if not force:
604-
raise ValueError(
605-
"Can't append a message chunk without an active message stream. "
606-
"Use `force=True` to start a new message stream if one is not currently active.",
607-
)
608-
await self.start_message_stream()
616+
raise ValueError(
617+
"Can't .append_message_chunk() without an active message stream. "
618+
"Use .message_stream() or .append_message_stream() to start one."
619+
)
609620

610621
return await self._append_message(
611622
message_chunk,
@@ -614,40 +625,84 @@ async def append_message_chunk(
614625
operation=operation,
615626
)
616627

617-
async def start_message_stream(self):
628+
@asynccontextmanager
629+
async def message_stream(self):
618630
"""
619-
Start a new message stream.
631+
Message stream context manager.
632+
633+
A context manager for streaming messages into the chat. Note this stream
634+
can occur within a longer running `.append_message_stream()` or used on its own.
620635
621-
Starts a new message stream which can then be appended to using
622-
`.append_message_chunk()`.
636+
Note
637+
----
638+
A useful pattern for displaying tool calls in a chat interface is for the
639+
tool to display using `.message_stream()` while the the response generation
640+
is happening through `.append_message_stream()`. This allows the inner stream
641+
to display "ephemeral" content, then eventually show a final state
642+
with `.append_message_chunk(operation="replace")`.
623643
"""
624-
# Since `._append_message()` manages a queue of message streams, we can just
625-
# start a new stream here. Note that, if a stream is already active, this
626-
# stream should start once the current stream ends.
627-
stream_id = _utils.private_random_id()
628-
# Separately track the stream id so ``.append_message_chunk()``/`.end_message_stream()`
629-
self._manual_stream_id = stream_id
644+
await self._start_stream()
645+
try:
646+
yield
647+
finally:
648+
await self._end_stream()
649+
650+
async def _start_stream(self):
651+
if self._manual_stream_id is not None:
652+
# TODO: support this?
653+
raise ValueError("Nested .message_stream() isn't currently supported.")
654+
# If we're currently streaming (i.e., through append_message_stream()), then
655+
# end the client message stream (since we start a new one below)
656+
if self._current_stream_id is not None:
657+
await self._send_append_message(
658+
message=ChatMessage(content="", role="assistant"),
659+
chunk="end",
660+
operation="append",
661+
)
662+
# Regardless whether this is an "inner" stream, we start a new message on the
663+
# client so it can handle `operation="replace"` without having to track where
664+
# the inner stream started.
665+
self._manual_stream_id = _utils.private_random_id()
666+
stream_id = self._current_stream_id or self._manual_stream_id
630667
return await self._append_message(
631668
"",
632669
chunk="start",
633670
stream_id=stream_id,
671+
# TODO: find a cleaner way to do this, and remove the gap between the messages
672+
icon=(
673+
HTML("<span class='border-0'><span>")
674+
if self._is_nested_stream
675+
else None
676+
),
634677
)
635678

636-
async def end_message_stream(self):
637-
"""
638-
End the current message stream (if any).
639-
640-
Ends a message stream that was started with `.start_message_stream()`.
641-
"""
642-
stream_id = self._manual_stream_id
643-
if stream_id is None:
644-
warnings.warn("No currently active stream to end.", stacklevel=2)
679+
async def _end_stream(self):
680+
if self._manual_stream_id is None and self._current_stream_id is None:
681+
warnings.warn(
682+
"Tried to end a message stream, but one isn't currently active.",
683+
stacklevel=2,
684+
)
645685
return
646686

647-
return await self._append_message(
648-
"",
649-
chunk="end",
650-
stream_id=stream_id,
687+
if self._is_nested_stream:
688+
# If inside another stream, just update server-side message state
689+
self._current_stream_message += self._nested_stream_message
690+
self._nested_stream_message = ""
691+
else:
692+
# Otherwise, end this "manual" message stream
693+
await self._append_message(
694+
"", chunk="end", stream_id=self._manual_stream_id
695+
)
696+
697+
self._manual_stream_id = None
698+
return
699+
700+
@property
701+
def _is_nested_stream(self):
702+
return (
703+
self._current_stream_id is not None
704+
and self._manual_stream_id is not None
705+
and self._current_stream_id != self._manual_stream_id
651706
)
652707

653708
async def _append_message(
@@ -673,9 +728,14 @@ async def _append_message(
673728
msg = normalize_message(message)
674729
else:
675730
msg = normalize_message_chunk(message)
676-
if operation == "replace":
677-
self._current_stream_message = ""
678-
self._current_stream_message += msg.content
731+
if self._is_nested_stream:
732+
if operation == "replace":
733+
self._nested_stream_message = ""
734+
self._nested_stream_message += msg.content
735+
else:
736+
if operation == "replace":
737+
self._current_stream_message = ""
738+
self._current_stream_message += msg.content
679739

680740
try:
681741
msg = await self._transform_message(msg, chunk=chunk)
@@ -704,7 +764,10 @@ async def _append_message(
704764
)
705765
finally:
706766
if chunk == "end":
707-
self._current_stream_message = ""
767+
if self._is_nested_stream:
768+
self._nested_stream_message = ""
769+
else:
770+
self._current_stream_message = ""
708771

709772
async def append_message_stream(
710773
self,

tests/playwright/shiny/components/chat/inject/app.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,42 +9,57 @@
99
chat.ui()
1010

1111

12-
async def generator():
13-
yield "Starting stream..."
14-
await asyncio.sleep(0.5)
15-
yield "...stream complete"
16-
17-
12+
# Launch a stream on load
1813
@reactive.effect
1914
async def _():
20-
await chat.append_message_stream(generator())
15+
await chat.append_message_stream(mock_stream())
2116

2217

23-
@reactive.effect
24-
async def _():
25-
await chat.append_message_chunk("injected chunk")
18+
async def mock_stream():
19+
yield "Starting outer stream...\n\n"
20+
await asyncio.sleep(0.5)
21+
await mock_tool()
22+
await asyncio.sleep(0.5)
23+
yield "\n\n...outer stream complete"
2624

2725

28-
ui.input_action_button("run_test", "Run test")
26+
# While the "outer" `.append_message_stream()` is running,
27+
# start an "inner" stream with .message_stream()
28+
async def mock_tool():
29+
steps = [
30+
"Starting inner stream 🔄...\n\n",
31+
"Progress: 0%...",
32+
"Progress: 50%...",
33+
"Progress: 100%...",
34+
]
35+
async with chat.message_stream():
36+
for chunk in steps:
37+
await chat.append_message_chunk(chunk)
38+
await asyncio.sleep(0.5)
39+
await chat.append_message_chunk(
40+
"Completed inner stream ✅",
41+
operation="replace",
42+
)
2943

3044

31-
@reactive.effect
32-
@reactive.event(input.run_test)
33-
async def _():
34-
await chat.start_message_stream()
35-
for chunk in ["can ", "inject ", "chunks"]:
36-
await asyncio.sleep(0.2)
37-
await chat.append_message_chunk(chunk)
38-
await chat.end_message_stream()
45+
@chat.on_user_submit
46+
async def _(user_input: str):
47+
await chat.append_message_stream(f"You said: {user_input}")
3948

4049

41-
ui.input_action_button("run_test2", "Run test 2")
50+
ui.input_action_button("add_stream_basic", "Add .message_stream()")
4251

4352

4453
@reactive.effect
45-
@reactive.event(input.run_test2)
54+
@reactive.event(input.add_stream_basic)
4655
async def _():
47-
await chat.append_message_stream(["can ", "append ", "chunks"])
56+
async with chat.message_stream():
57+
await chat.append_message_chunk("Running test...")
58+
await asyncio.sleep(1)
59+
await chat.append_message_chunk("Test complete!")
60+
61+
62+
# TODO: more tests, like submitting input, etc.
4863

4964

5065
@render.code

0 commit comments

Comments
 (0)