Skip to content

Commit 3786b79

Browse files
committed
Support nested streams and simplify logic
1 parent 648d9cc commit 3786b79

File tree

2 files changed

+122
-135
lines changed

2 files changed

+122
-135
lines changed

shiny/ui/_chat.py

Lines changed: 89 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import inspect
4-
import warnings
54
from contextlib import asynccontextmanager
65
from typing import (
76
Any,
@@ -81,9 +80,7 @@
8180
UserSubmitFunction1,
8281
]
8382

84-
ChunkOption = Literal["start", "end", True, False]
85-
86-
PendingMessage = Tuple[Any, ChunkOption, Union[str, None]]
83+
PendingMessage = Tuple[Any, Literal["start", "end", True], Union[str, None]]
8784

8885

8986
@add_example(ex_dir="../templates/chat/starters/hello")
@@ -199,15 +196,12 @@ def __init__(
199196
self.on_error = on_error
200197

201198
# Chunked messages get accumulated (using this property) before changing state
202-
self._current_stream_message = ""
199+
self._current_stream_message: str = ""
203200
self._current_stream_id: str | None = None
204201
self._pending_messages: list[PendingMessage] = []
205202

206-
# Identifier for a manual stream (i.e., one started with `.start_message_stream()`)
207-
self._manual_stream_id: str | None = None
208-
# If a manual stream gets nested within another stream, we need to keep track of
209-
# the accumulated message separately
210-
self._nested_stream_message: str = ""
203+
# For tracking message stream state when entering/exiting nested streams
204+
self._message_stream_checkpoint: str = ""
211205

212206
# If a user input message is transformed into a response, we need to cancel
213207
# the next user input submit handling
@@ -576,7 +570,16 @@ async def append_message(
576570
similar) is specified in model's completion method.
577571
:::
578572
"""
579-
await self._append_message(message, icon=icon)
573+
msg = normalize_message(message)
574+
msg = await self._transform_message(msg)
575+
if msg is None:
576+
return
577+
self._store_message(msg)
578+
await self._send_append_message(
579+
message=msg,
580+
chunk=False,
581+
icon=icon,
582+
)
580583

581584
async def append_message_chunk(
582585
self,
@@ -618,9 +621,8 @@ async def append_message_chunk(
618621
"Use .message_stream() or .append_message_stream() to start one."
619622
)
620623

621-
return await self._append_message(
624+
return await self._append_message_chunk(
622625
message_chunk,
623-
chunk=True,
624626
stream_id=stream_id,
625627
operation=operation,
626628
)
@@ -641,75 +643,39 @@ async def message_stream(self):
641643
to display "ephemeral" content, then eventually show a final state
642644
with `.append_message_chunk(operation="replace")`.
643645
"""
644-
await self._start_stream()
646+
# Save the current stream state in a checkpoint (so that we can handle
647+
# `.append_message_chunk(operation="replace")` correctly)
648+
old_checkpoint = self._message_stream_checkpoint
649+
self._message_stream_checkpoint = self._current_stream_message
650+
651+
# No stream currently exists, start one
652+
is_root_stream = not self._current_stream_id
653+
if is_root_stream:
654+
await self._append_message_chunk(
655+
"",
656+
chunk="start",
657+
stream_id=_utils.private_random_id(),
658+
)
659+
645660
try:
646661
yield
647662
finally:
648-
await self._end_stream()
649-
650-
async def _start_stream(self):
651-
if self._manual_stream_id is not None:
652-
# TODO: support this?
653-
raise ValueError("Nested .message_stream() isn't currently supported.")
654-
# If we're currently streaming (i.e., through append_message_stream()), then
655-
# end the client message stream (since we start a new one below)
656-
if self._current_stream_id is not None:
657-
await self._send_append_message(
658-
message=ChatMessage(content="", role="assistant"),
659-
chunk="end",
660-
operation="append",
661-
)
662-
# Regardless whether this is an "inner" stream, we start a new message on the
663-
# client so it can handle `operation="replace"` without having to track where
664-
# the inner stream started.
665-
self._manual_stream_id = _utils.private_random_id()
666-
stream_id = self._current_stream_id or self._manual_stream_id
667-
return await self._append_message(
668-
"",
669-
chunk="start",
670-
stream_id=stream_id,
671-
# TODO: find a cleaner way to do this, and remove the gap between the messages
672-
icon=(
673-
HTML("<span class='border-0'><span>")
674-
if self._is_nested_stream
675-
else None
676-
),
677-
)
678-
679-
async def _end_stream(self):
680-
if self._manual_stream_id is None and self._current_stream_id is None:
681-
warnings.warn(
682-
"Tried to end a message stream, but one isn't currently active.",
683-
stacklevel=2,
684-
)
685-
return
686-
687-
if self._is_nested_stream:
688-
# If inside another stream, just update server-side message state
689-
self._current_stream_message += self._nested_stream_message
690-
self._nested_stream_message = ""
691-
else:
692-
# Otherwise, end this "manual" message stream
693-
await self._append_message(
694-
"", chunk="end", stream_id=self._manual_stream_id
695-
)
696-
697-
self._manual_stream_id = None
698-
return
699-
700-
@property
701-
def _is_nested_stream(self):
702-
return (
703-
self._current_stream_id is not None
704-
and self._manual_stream_id is not None
705-
and self._current_stream_id != self._manual_stream_id
706-
)
663+
# Restore the previous stream state
664+
self._message_stream_checkpoint = old_checkpoint
665+
666+
# If this was the root stream, end it
667+
if is_root_stream:
668+
await self._append_message_chunk(
669+
"",
670+
chunk="end",
671+
stream_id=self._current_stream_id,
672+
)
707673

708-
async def _append_message(
674+
async def _append_message_chunk(
709675
self,
710676
message: Any,
711677
*,
712-
chunk: ChunkOption = False,
678+
chunk: Literal[True, "start", "end"] = True,
713679
operation: Literal["append", "replace"] = "append",
714680
stream_id: str | None = None,
715681
icon: HTML | Tag | TagList | None = None,
@@ -724,37 +690,40 @@ async def _append_message(
724690
if chunk == "end":
725691
self._current_stream_id = None
726692

727-
if chunk is False:
728-
msg = normalize_message(message)
693+
# Normalize into a ChatMessage()
694+
msg = normalize_message_chunk(message)
695+
696+
# Remember this content chunk for passing to transformer
697+
this_chunk = msg.content
698+
699+
# Transforming requires replacing
700+
if self._needs_transform(msg):
701+
operation = "replace"
702+
703+
if operation == "replace":
704+
# Replace up to the latest checkpoint
705+
self._current_stream_message = self._message_stream_checkpoint + this_chunk
706+
msg.content = self._current_stream_message
729707
else:
730-
msg = normalize_message_chunk(message)
731-
if self._is_nested_stream:
732-
if operation == "replace":
733-
self._nested_stream_message = ""
734-
self._nested_stream_message += msg.content
735-
else:
736-
if operation == "replace":
737-
self._current_stream_message = ""
738-
self._current_stream_message += msg.content
708+
self._current_stream_message += msg.content
739709

740710
try:
741-
msg = await self._transform_message(msg, chunk=chunk)
742-
# Act like nothing happened if transformed to None
743-
if msg is None:
744-
return
745-
msg_store = msg
746-
# Transforming requires *replacing* content
747-
if isinstance(msg, TransformedMessage):
748-
operation = "replace"
711+
if self._needs_transform(msg):
712+
msg = await self._transform_message(
713+
msg, chunk=chunk, chunk_content=this_chunk
714+
)
715+
# Act like nothing happened if transformed to None
716+
if msg is None:
717+
return
718+
if chunk == "end":
719+
self._store_message(msg)
749720
elif chunk == "end":
750-
# When not transforming, ensure full message is stored
751-
msg_store = ChatMessage(
752-
content=self._current_stream_message,
753-
role="assistant",
721+
# When `operation="append"`, msg.content is just a chunk, but we must
722+
# store the full message
723+
self._store_message(
724+
ChatMessage(content=self._current_stream_message, role=msg.role)
754725
)
755-
# Only store full messages
756-
if chunk is False or chunk == "end":
757-
self._store_message(msg_store)
726+
758727
# Send the message to the client
759728
await self._send_append_message(
760729
message=msg,
@@ -764,10 +733,8 @@ async def _append_message(
764733
)
765734
finally:
766735
if chunk == "end":
767-
if self._is_nested_stream:
768-
self._nested_stream_message = ""
769-
else:
770-
self._current_stream_message = ""
736+
self._current_stream_message = ""
737+
self._message_stream_checkpoint = ""
771738

772739
async def append_message_stream(
773740
self,
@@ -898,21 +865,21 @@ async def _append_message_stream(
898865
id = _utils.private_random_id()
899866

900867
empty = ChatMessageDict(content="", role="assistant")
901-
await self._append_message(empty, chunk="start", stream_id=id, icon=icon)
868+
await self._append_message_chunk(empty, chunk="start", stream_id=id, icon=icon)
902869

903870
try:
904871
async for msg in message:
905-
await self._append_message(msg, chunk=True, stream_id=id)
872+
await self._append_message_chunk(msg, chunk=True, stream_id=id)
906873
return self._current_stream_message
907874
finally:
908-
await self._append_message(empty, chunk="end", stream_id=id)
875+
await self._append_message_chunk(empty, chunk="end", stream_id=id)
909876
await self._flush_pending_messages()
910877

911878
async def _flush_pending_messages(self):
912879
still_pending: list[PendingMessage] = []
913880
for msg, chunk, stream_id in self._pending_messages:
914881
if self._can_append_message(stream_id):
915-
await self._append_message(msg, chunk=chunk, stream_id=stream_id)
882+
await self._append_message_chunk(msg, chunk=chunk, stream_id=stream_id)
916883
else:
917884
still_pending.append((msg, chunk, stream_id))
918885
self._pending_messages = still_pending
@@ -926,7 +893,7 @@ def _can_append_message(self, stream_id: str | None) -> bool:
926893
async def _send_append_message(
927894
self,
928895
message: TransformedMessage | ChatMessage,
929-
chunk: ChunkOption = False,
896+
chunk: Literal["start", "end", True, False] = False,
930897
operation: Literal["append", "replace"] = "append",
931898
icon: HTML | Tag | TagList | None = None,
932899
):
@@ -1092,32 +1059,35 @@ async def _transform_wrapper(content: str, chunk: str, done: bool):
10921059
async def _transform_message(
10931060
self,
10941061
message: ChatMessage,
1095-
chunk: ChunkOption = False,
1096-
) -> ChatMessage | TransformedMessage | None:
1062+
chunk: Literal["start", "end", True, False] = False,
1063+
chunk_content: str = "",
1064+
) -> TransformedMessage | None:
10971065
res = TransformedMessage.from_chat_message(message)
10981066

10991067
if message.role == "user" and self._transform_user is not None:
11001068
content = await self._transform_user(message.content)
11011069
elif message.role == "assistant" and self._transform_assistant is not None:
1102-
all_content = (
1103-
message.content if chunk is False else self._current_stream_message
1104-
)
1105-
setattr(res, res.pre_transform_key, all_content)
11061070
content = await self._transform_assistant(
1107-
all_content,
11081071
message.content,
1072+
chunk_content,
11091073
chunk == "end" or chunk is False,
11101074
)
11111075
else:
1112-
return message
1076+
return res
11131077

11141078
if content is None:
11151079
return None
11161080

11171081
setattr(res, res.transform_key, content)
1118-
11191082
return res
11201083

1084+
def _needs_transform(self, message: ChatMessage) -> bool:
1085+
if message.role == "user" and self._transform_user is not None:
1086+
return True
1087+
elif message.role == "assistant" and self._transform_assistant is not None:
1088+
return True
1089+
return False
1090+
11211091
# Just before storing, handle chunk msg type and calculate tokens
11221092
def _store_message(
11231093
self,

0 commit comments

Comments
 (0)