Skip to content

Commit 81a99d2

Browse files
committed
Cleanup
1 parent 4c32a0d commit 81a99d2

File tree

1 file changed

+36
-39
lines changed

1 file changed

+36
-39
lines changed

shiny/ui/_chat.py

Lines changed: 36 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -236,16 +236,14 @@ async def _on_user_input():
236236
# It's possible that during the transform, a message is appended, so get
237237
# the length now, so we can insert the new message at the right index
238238
n_pre = len(self._messages())
239-
msg_post, _ = await self._transform_message(msg)
239+
msg_post = await self._transform_message(msg)
240240
if msg_post is not None:
241241
self._store_message(msg_post)
242242
self._suspend_input_handler = False
243243
else:
244244
# A transformed value of None is a special signal to suspend input
245245
# handling (i.e., don't generate a response)
246-
self._store_message(
247-
TransformedMessage.from_message(msg), index=n_pre
248-
)
246+
self._store_message(msg, index=n_pre)
249247
await self._remove_loading_message()
250248
self._suspend_input_handler = True
251249

@@ -659,31 +657,28 @@ async def _append_message(
659657
self._current_stream_message += msg["content"]
660658

661659
try:
662-
msg_t, transformed = await self._transform_message(msg, chunk=chunk)
663-
# Act like nothing happened if content transformed to None
664-
if msg_t is None:
660+
msg = await self._transform_message(msg, chunk=chunk)
661+
# Act like nothing happened if transformed to None
662+
if msg is None:
665663
return
666-
# Store if this is a whole message or the end of a streaming message
667-
if chunk is False:
668-
self._store_message(msg_t)
664+
msg_store = msg
665+
# Transforming requires *replacing* content
666+
if isinstance(msg, TransformedMessage):
667+
operation = "replace"
669668
elif chunk == "end":
670-
# Transforming content requires replacing all the content, so take
671-
# it as is. Otherwise, store the accumulated stream message.
672-
if transformed:
673-
self._store_message(msg_t)
674-
else:
675-
self._store_message(
676-
TransformedMessage.from_message(
677-
ChatMessage(
678-
content=self._current_stream_message, role="assistant"
679-
)
680-
)
681-
)
669+
# When not transforming, ensure full message is stored
670+
msg_store = ChatMessage(
671+
content=self._current_stream_message,
672+
role="assistant",
673+
)
674+
# Only store full messages
675+
if chunk is False or chunk == "end":
676+
self._store_message(msg_store)
677+
# Send the message to the client
682678
await self._send_append_message(
683-
content=msg_t.content_client,
684-
role=msg_t.role,
679+
message=msg,
685680
chunk=chunk,
686-
operation="replace" if transformed else operation,
681+
operation=operation,
687682
icon=icon,
688683
)
689684
finally:
@@ -835,13 +830,15 @@ def _can_append_message(self, stream_id: str | None) -> bool:
835830
# Send a message to the UI
836831
async def _send_append_message(
837832
self,
838-
content: str | HTML,
839-
role: Role,
833+
message: TransformedMessage | ChatMessage,
840834
chunk: ChunkOption = False,
841835
operation: Literal["append", "replace"] = "append",
842836
icon: HTML | Tag | TagList | None = None,
843837
):
844-
if role == "system":
838+
if not isinstance(message, TransformedMessage):
839+
message = TransformedMessage.from_message(message)
840+
841+
if message.role == "system":
845842
# System messages are not displayed in the UI
846843
return
847844

@@ -856,12 +853,13 @@ async def _send_append_message(
856853
elif chunk == "end":
857854
chunk_type = "message_end"
858855

856+
content = message.content_client
859857
content_type = "html" if isinstance(content, HTML) else "markdown"
860858

861859
# TODO: pass along dependencies for both content and icon (if any)
862860
msg = ClientMessage(
863861
content=str(content),
864-
role=role,
862+
role=message.role,
865863
content_type=content_type,
866864
chunk_type=chunk_type,
867865
operation=operation,
@@ -996,14 +994,11 @@ async def _transform_message(
996994
self,
997995
message: ChatMessage,
998996
chunk: ChunkOption = False,
999-
) -> tuple[TransformedMessage | None, bool]:
997+
) -> ChatMessage | TransformedMessage | None:
1000998
res = TransformedMessage.from_message(message)
1001-
key = res.transform_key
1002-
transformed = False
1003999

10041000
if message["role"] == "user" and self._transform_user is not None:
10051001
content = await self._transform_user(message["content"])
1006-
transformed = True
10071002
elif message["role"] == "assistant" and self._transform_assistant is not None:
10081003
all_content = (
10091004
message["content"] if chunk is False else self._current_stream_message
@@ -1014,24 +1009,26 @@ async def _transform_message(
10141009
message["content"],
10151010
chunk == "end" or chunk is False,
10161011
)
1017-
transformed = True
10181012
else:
1019-
return (res, transformed)
1013+
return message
10201014

10211015
if content is None:
1022-
return (None, transformed)
1016+
return None
10231017

1024-
setattr(res, key, content)
1018+
setattr(res, res.transform_key, content)
10251019

1026-
return (res, transformed)
1020+
return res
10271021

10281022
# Just before storing, handle chunk msg type and calculate tokens
10291023
def _store_message(
10301024
self,
1031-
message: TransformedMessage,
1025+
message: TransformedMessage | ChatMessage,
10321026
index: int | None = None,
10331027
) -> None:
10341028

1029+
if not isinstance(message, TransformedMessage):
1030+
message = TransformedMessage.from_message(message)
1031+
10351032
with reactive.isolate():
10361033
messages = self._messages()
10371034

0 commit comments

Comments
 (0)