Skip to content

Commit 4c32a0d

Browse files
committed
Go back to a more minimal change
1 parent 916d2c0 commit 4c32a0d

File tree

2 files changed

+55
-41
lines changed

2 files changed

+55
-41
lines changed

shiny/ui/_chat.py

Lines changed: 50 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -232,18 +232,20 @@ async def _init_chat():
232232
@reactive.effect(priority=9999)
233233
@reactive.event(self._user_input)
234234
async def _on_user_input():
235-
content = self._user_input()
235+
msg = ChatMessage(content=self._user_input(), role="user")
236236
# It's possible that during the transform, a message is appended, so get
237237
# the length now, so we can insert the new message at the right index
238238
n_pre = len(self._messages())
239-
content, _ = await self._transform_content(content, role="user")
240-
if content is not None:
241-
self._store_content(content, role="user")
239+
msg_post, _ = await self._transform_message(msg)
240+
if msg_post is not None:
241+
self._store_message(msg_post)
242242
self._suspend_input_handler = False
243243
else:
244244
# A transformed value of None is a special signal to suspend input
245245
# handling (i.e., don't generate a response)
246-
self._store_content(content or "", role="user", index=n_pre)
246+
self._store_message(
247+
TransformedMessage.from_message(msg), index=n_pre
248+
)
247249
await self._remove_loading_message()
248250
self._suspend_input_handler = True
249251

@@ -657,25 +659,29 @@ async def _append_message(
657659
self._current_stream_message += msg["content"]
658660

659661
try:
660-
content, transformed = await self._transform_content(
661-
msg["content"], role=msg["role"], chunk=chunk
662-
)
662+
msg_t, transformed = await self._transform_message(msg, chunk=chunk)
663663
# Act like nothing happened if content transformed to None
664-
if content is None:
664+
if msg_t is None:
665665
return
666666
# Store if this is a whole message or the end of a streaming message
667667
if chunk is False:
668-
self._store_content(content, role=msg["role"])
668+
self._store_message(msg_t)
669669
elif chunk == "end":
670670
# Transforming content requires replacing all the content, so take
671671
# it as is. Otherwise, store the accumulated stream message.
672-
self._store_content(
673-
content=content if transformed else self._current_stream_message,
674-
role=msg["role"],
675-
)
672+
if transformed:
673+
self._store_message(msg_t)
674+
else:
675+
self._store_message(
676+
TransformedMessage.from_message(
677+
ChatMessage(
678+
content=self._current_stream_message, role="assistant"
679+
)
680+
)
681+
)
676682
await self._send_append_message(
677-
content=content,
678-
role=msg["role"],
683+
content=msg_t.content_client,
684+
role=msg_t.role,
679685
chunk=chunk,
680686
operation="replace" if transformed else operation,
681687
icon=icon,
@@ -986,33 +992,43 @@ async def _transform_wrapper(content: str, chunk: str, done: bool):
986992
else:
987993
return _set_transform(fn)
988994

989-
async def _transform_content(
995+
async def _transform_message(
990996
self,
991-
content: str,
992-
role: Role,
997+
message: ChatMessage,
993998
chunk: ChunkOption = False,
994-
) -> tuple[str | HTML | None, bool]:
995-
content2 = content
999+
) -> tuple[TransformedMessage | None, bool]:
1000+
res = TransformedMessage.from_message(message)
1001+
key = res.transform_key
9961002
transformed = False
997-
if role == "user" and self._transform_user is not None:
998-
content2 = await self._transform_user(content)
1003+
1004+
if message["role"] == "user" and self._transform_user is not None:
1005+
content = await self._transform_user(message["content"])
9991006
transformed = True
1000-
elif role == "assistant" and self._transform_assistant is not None:
1001-
all_content = content if chunk is False else self._current_stream_message
1002-
content2 = await self._transform_assistant(
1007+
elif message["role"] == "assistant" and self._transform_assistant is not None:
1008+
all_content = (
1009+
message["content"] if chunk is False else self._current_stream_message
1010+
)
1011+
setattr(res, res.pre_transform_key, all_content)
1012+
content = await self._transform_assistant(
10031013
all_content,
1004-
content,
1014+
message["content"],
10051015
chunk == "end" or chunk is False,
10061016
)
10071017
transformed = True
1018+
else:
1019+
return (res, transformed)
1020+
1021+
if content is None:
1022+
return (None, transformed)
1023+
1024+
setattr(res, key, content)
10081025

1009-
return (content2, transformed)
1026+
return (res, transformed)
10101027

10111028
# Just before storing, handle chunk msg type and calculate tokens
1012-
def _store_content(
1029+
def _store_message(
10131030
self,
1014-
content: str | HTML,
1015-
role: Role,
1031+
message: TransformedMessage,
10161032
index: int | None = None,
10171033
) -> None:
10181034

@@ -1022,14 +1038,12 @@ def _store_content(
10221038
if index is None:
10231039
index = len(messages)
10241040

1025-
msg = TransformedMessage.from_content(content=content, role=role)
1026-
10271041
messages = list(messages)
1028-
messages.insert(index, msg)
1042+
messages.insert(index, message)
10291043

10301044
self._messages.set(tuple(messages))
1031-
if role == "user":
1032-
self._latest_user_input.set(msg)
1045+
if message.role == "user":
1046+
self._latest_user_input.set(message)
10331047

10341048
return None
10351049

shiny/ui/_chat_types.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,18 @@ class TransformedMessage:
2727
pre_transform_key: Literal["content_client", "content_server"]
2828

2929
@classmethod
30-
def from_content(cls, content: str | HTML, role: Role) -> TransformedMessage:
31-
if role == "user":
30+
def from_message(cls, message: ChatMessage) -> TransformedMessage:
31+
if message["role"] == "user":
3232
transform_key = "content_server"
3333
pre_transform_key = "content_client"
3434
else:
3535
transform_key = "content_client"
3636
pre_transform_key = "content_server"
3737

3838
return cls(
39-
content_client=content,
40-
content_server=str(content),
41-
role=role,
39+
content_client=message["content"],
40+
content_server=message["content"],
41+
role=message["role"],
4242
transform_key=transform_key,
4343
pre_transform_key=pre_transform_key,
4444
)

0 commit comments

Comments
 (0)