Skip to content

Commit bd6d40f

Browse files
committed
Fix/simplify transform logic
1 parent 260f902 commit bd6d40f

File tree

1 file changed

+10
-15
lines changed

1 file changed

+10
-15
lines changed

shiny/ui/_chat.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -692,32 +692,26 @@ async def _append_message_chunk(
692692
self._pending_messages.append((message, chunk, stream_id))
693693
return
694694

695-
# Update current stream state
696695
self._current_stream_id = stream_id
697-
if chunk == "end":
698-
self._current_stream_id = None
699696

700-
# Normalize into a ChatMessage()
697+
# Normalize various message types into a ChatMessage()
701698
msg = normalize_message_chunk(message)
702699

703-
# Remember this content chunk for passing to transformer
704-
this_chunk = msg.content
705-
706-
# Transforming requires replacing
707-
if self._needs_transform(msg):
708-
operation = "replace"
709-
710700
if operation == "replace":
711-
# Replace up to the latest checkpoint
712-
self._current_stream_message = self._message_stream_checkpoint + this_chunk
701+
self._current_stream_message = self._message_stream_checkpoint + msg.content
713702
msg.content = self._current_stream_message
714703
else:
715704
self._current_stream_message += msg.content
716705

717706
try:
718707
if self._needs_transform(msg):
708+
# Transforming may change the meaning of msg.content to be a *replace*
709+
# not *append*. So, update msg.content and the operation accordingly.
710+
chunk_content = msg.content
711+
msg.content = self._current_stream_message
712+
operation = "replace"
719713
msg = await self._transform_message(
720-
msg, chunk=chunk, chunk_content=this_chunk
714+
msg, chunk=chunk, chunk_content=chunk_content
721715
)
722716
# Act like nothing happened if transformed to None
723717
if msg is None:
@@ -740,6 +734,7 @@ async def _append_message_chunk(
740734
)
741735
finally:
742736
if chunk == "end":
737+
self._current_stream_id = None
743738
self._current_stream_message = ""
744739
self._message_stream_checkpoint = ""
745740

@@ -1063,7 +1058,7 @@ async def _transform_wrapper(content: str, chunk: str, done: bool):
10631058
async def _transform_message(
10641059
self,
10651060
message: ChatMessage,
1066-
chunk: Literal["start", "end", True, False] = False,
1061+
chunk: ChunkOption = False,
10671062
chunk_content: str = "",
10681063
) -> TransformedMessage | None:
10691064
res = TransformedMessage.from_chat_message(message)

0 commit comments

Comments
 (0)