Skip to content

Commit 916d2c0

Browse files
committed
feat(Chat): Add .start_message_stream(), .end_message_stream(), and .inject_message_chunk(). Append instead of replace messages unless transforms are used
1 parent 5f0bcbf commit 916d2c0

File tree

5 files changed

+302
-138
lines changed

5 files changed

+302
-138
lines changed

shiny/ui/_chat.py

Lines changed: 156 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import inspect
4+
import warnings
45
from typing import (
56
Any,
67
AsyncIterable,
@@ -38,7 +39,7 @@
3839
as_provider_message,
3940
)
4041
from ._chat_tokenizer import TokenEncoding, TokenizersEncoding, get_default_tokenizer
41-
from ._chat_types import ChatMessage, ClientMessage, TransformedMessage
42+
from ._chat_types import ChatMessage, ClientMessage, Role, TransformedMessage
4243
from ._html_deps_py_shiny import chat_deps
4344
from .fill import as_fill_item, as_fillable_container
4445

@@ -231,18 +232,18 @@ async def _init_chat():
231232
@reactive.effect(priority=9999)
232233
@reactive.event(self._user_input)
233234
async def _on_user_input():
234-
msg = ChatMessage(content=self._user_input(), role="user")
235+
content = self._user_input()
235236
# It's possible that during the transform, a message is appended, so get
236237
# the length now, so we can insert the new message at the right index
237238
n_pre = len(self._messages())
238-
msg_post = await self._transform_message(msg)
239-
if msg_post is not None:
240-
self._store_message(msg_post)
239+
content, _ = await self._transform_content(content, role="user")
240+
if content is not None:
241+
self._store_content(content, role="user")
241242
self._suspend_input_handler = False
242243
else:
243244
# A transformed value of None is a special signal to suspend input
244245
# handling (i.e., don't generate a response)
245-
self._store_message(as_transformed_message(msg), index=n_pre)
246+
self._store_content(content or "", role="user", index=n_pre)
246247
await self._remove_loading_message()
247248
self._suspend_input_handler = True
248249

@@ -483,14 +484,15 @@ def messages(
483484
res: list[ChatMessage | ProviderMessage] = []
484485
for i, m in enumerate(messages):
485486
transform = False
486-
if m["role"] == "assistant":
487+
if m.role == "assistant":
487488
transform = transform_assistant
488-
elif m["role"] == "user":
489+
elif m.role == "user":
489490
transform = transform_user == "all" or (
490491
transform_user == "last" and i == len(messages) - 1
491492
)
492-
content_key = m["transform_key" if transform else "pre_transform_key"]
493-
chat_msg = ChatMessage(content=str(m[content_key]), role=m["role"])
493+
key = "transform_key" if transform else "pre_transform_key"
494+
content_val = getattr(m, getattr(m, key))
495+
chat_msg = ChatMessage(content=str(content_val), role=m.role)
494496
if not isinstance(format, MISSING_TYPE):
495497
chat_msg = as_provider_message(chat_msg, format)
496498
res.append(chat_msg)
@@ -550,11 +552,89 @@ async def append_message(
550552
"""
551553
await self._append_message(message, icon=icon)
552554

555+
async def inject_message_chunk(
    self,
    message_chunk: Any,
    *,
    operation: Literal["append", "replace"] = "append",
    force: bool = False,
):
    """
    Inject a chunk of message content into the current message stream.

    Sometimes when streaming a message (i.e., `.append_message_stream()`), you may
    want to inject content into the streaming message while the stream is
    busy doing other things (e.g., calling a tool). This method allows you to
    inject any content you want into the current message stream (assuming one is
    active).

    Parameters
    ----------
    message_chunk
        A message chunk to inject.
    operation
        Whether to append or replace the current message stream content.
    force
        Whether to start a new stream if one is not currently active.

    Raises
    ------
    ValueError
        If no message stream is active and `force` is `False`.
    """
    if self._current_stream_id is None:
        if not force:
            raise ValueError(
                "Can't inject a message chunk when no message stream is active. "
                "Use `force=True` to start a new stream if one is not currently active.",
            )
        await self.start_message_stream(force=True)

    # Read the stream id only *after* a stream may have been force-started above.
    # Capturing it before the `start_message_stream()` call would always forward a
    # stale `None` on the forced path instead of the id of the new stream.
    # NOTE(review): this relies on `start_message_stream()` updating
    # `self._current_stream_id`; that assignment is not visible in this block.
    return await self._append_message(
        message_chunk,
        chunk=True,
        stream_id=self._current_stream_id,
        operation=operation,
    )
595+
596+
async def start_message_stream(self, *, force: bool = False):
    """
    Start a new message stream.

    Parameters
    ----------
    force
        Whether to force starting a new stream even if one is already active

    Raises
    ------
    ValueError
        If a message stream is already active and `force` is `False`.
    """
    if self._current_stream_id is not None:
        if not force:
            raise ValueError(
                "Can't start a new message stream when a message stream is already active. "
                "Use `force=True` to end a currently active stream and start a new one.",
            )
        # Forced: close out the active stream before opening a new one
        await self.end_message_stream()

    # Named `new_stream_id` (not `id`) so the builtin `id()` is not shadowed
    new_stream_id = _utils.private_random_id()
    return await self._append_message("", chunk="start", stream_id=new_stream_id)
616+
617+
async def end_message_stream(self):
    """
    End the current message stream (if any).

    When a stream is active, the latest message stream is cancelled and an empty
    "end" chunk is appended for that stream. When no stream is active, a warning
    is emitted and nothing else happens.
    """
    active_id = self._current_stream_id

    if active_id is not None:
        # TODO: .cancel() method should probably just handle this
        with reactive.isolate():
            self.latest_message_stream.cancel()

        return await self._append_message("", chunk="end", stream_id=active_id)

    warnings.warn("No currently active stream to end.", stacklevel=2)
    return None
631+
553632
async def _append_message(
554633
self,
555634
message: Any,
556635
*,
557636
chunk: ChunkOption = False,
637+
operation: Literal["append", "replace"] = "append",
558638
stream_id: str | None = None,
559639
icon: HTML | Tag | TagList | None = None,
560640
) -> None:
@@ -570,27 +650,39 @@ async def _append_message(
570650

571651
if chunk is False:
572652
msg = normalize_message(message)
573-
chunk_content = None
574653
else:
575654
msg = normalize_message_chunk(message)
576-
# Update the current stream message
577-
chunk_content = msg["content"]
578-
self._current_stream_message += chunk_content
579-
msg["content"] = self._current_stream_message
580-
if chunk == "end":
655+
if operation == "replace":
581656
self._current_stream_message = ""
657+
self._current_stream_message += msg["content"]
582658

583-
msg = await self._transform_message(
584-
msg, chunk=chunk, chunk_content=chunk_content
585-
)
586-
if msg is None:
587-
return
588-
self._store_message(msg, chunk=chunk)
589-
await self._send_append_message(
590-
msg,
591-
chunk=chunk,
592-
icon=icon,
593-
)
659+
try:
660+
content, transformed = await self._transform_content(
661+
msg["content"], role=msg["role"], chunk=chunk
662+
)
663+
# Act like nothing happened if content transformed to None
664+
if content is None:
665+
return
666+
# Store if this is a whole message or the end of a streaming message
667+
if chunk is False:
668+
self._store_content(content, role=msg["role"])
669+
elif chunk == "end":
670+
# Transforming content requires replacing all the content, so take
671+
# it as is. Otherwise, store the accumulated stream message.
672+
self._store_content(
673+
content=content if transformed else self._current_stream_message,
674+
role=msg["role"],
675+
)
676+
await self._send_append_message(
677+
content=content,
678+
role=msg["role"],
679+
chunk=chunk,
680+
operation="replace" if transformed else operation,
681+
icon=icon,
682+
)
683+
finally:
684+
if chunk == "end":
685+
self._current_stream_message = ""
594686

595687
async def append_message_stream(
596688
self,
@@ -737,11 +829,13 @@ def _can_append_message(self, stream_id: str | None) -> bool:
737829
# Send a message to the UI
738830
async def _send_append_message(
739831
self,
740-
message: TransformedMessage,
832+
content: str | HTML,
833+
role: Role,
741834
chunk: ChunkOption = False,
835+
operation: Literal["append", "replace"] = "append",
742836
icon: HTML | Tag | TagList | None = None,
743837
):
744-
if message["role"] == "system":
838+
if role == "system":
745839
# System messages are not displayed in the UI
746840
return
747841

@@ -756,15 +850,15 @@ async def _send_append_message(
756850
elif chunk == "end":
757851
chunk_type = "message_end"
758852

759-
content = message["content_client"]
760853
content_type = "html" if isinstance(content, HTML) else "markdown"
761854

762855
# TODO: pass along dependencies for both content and icon (if any)
763856
msg = ClientMessage(
764857
content=str(content),
765-
role=message["role"],
858+
role=role,
766859
content_type=content_type,
767860
chunk_type=chunk_type,
861+
operation=operation,
768862
)
769863

770864
if icon is not None:
@@ -892,57 +986,50 @@ async def _transform_wrapper(content: str, chunk: str, done: bool):
892986
else:
893987
return _set_transform(fn)
894988

895-
async def _transform_message(
989+
async def _transform_content(
896990
self,
897-
message: ChatMessage,
991+
content: str,
992+
role: Role,
898993
chunk: ChunkOption = False,
899-
chunk_content: str | None = None,
900-
) -> TransformedMessage | None:
901-
res = as_transformed_message(message)
902-
key = res["transform_key"]
903-
904-
if message["role"] == "user" and self._transform_user is not None:
905-
content = await self._transform_user(message["content"])
906-
907-
elif message["role"] == "assistant" and self._transform_assistant is not None:
908-
content = await self._transform_assistant(
909-
message["content"],
910-
chunk_content or "",
994+
) -> tuple[str | HTML | None, bool]:
995+
content2 = content
996+
transformed = False
997+
if role == "user" and self._transform_user is not None:
998+
content2 = await self._transform_user(content)
999+
transformed = True
1000+
elif role == "assistant" and self._transform_assistant is not None:
1001+
all_content = content if chunk is False else self._current_stream_message
1002+
content2 = await self._transform_assistant(
1003+
all_content,
1004+
content,
9111005
chunk == "end" or chunk is False,
9121006
)
913-
else:
914-
return res
915-
916-
if content is None:
917-
return None
918-
919-
res[key] = content # type: ignore
1007+
transformed = True
9201008

921-
return res
1009+
return (content2, transformed)
9221010

9231011
# Just before storing, handle chunk msg type and calculate tokens
924-
def _store_message(
1012+
def _store_content(
9251013
self,
926-
message: TransformedMessage,
927-
chunk: ChunkOption = False,
1014+
content: str | HTML,
1015+
role: Role,
9281016
index: int | None = None,
9291017
) -> None:
930-
# Don't actually store chunks until the end
931-
if chunk is True or chunk == "start":
932-
return None
9331018

9341019
with reactive.isolate():
9351020
messages = self._messages()
9361021

9371022
if index is None:
9381023
index = len(messages)
9391024

1025+
msg = TransformedMessage.from_content(content=content, role=role)
1026+
9401027
messages = list(messages)
941-
messages.insert(index, message)
1028+
messages.insert(index, msg)
9421029

9431030
self._messages.set(tuple(messages))
944-
if message["role"] == "user":
945-
self._latest_user_input.set(message)
1031+
if role == "user":
1032+
self._latest_user_input.set(msg)
9461033

9471034
return None
9481035

@@ -966,9 +1053,9 @@ def _trim_messages(
9661053
n_other_messages: int = 0
9671054
token_counts: list[int] = []
9681055
for m in messages:
969-
count = self._get_token_count(m["content_server"])
1056+
count = self._get_token_count(m.content_server)
9701057
token_counts.append(count)
971-
if m["role"] == "system":
1058+
if m.role == "system":
9721059
n_system_tokens += count
9731060
n_system_messages += 1
9741061
else:
@@ -989,7 +1076,7 @@ def _trim_messages(
9891076
n_other_messages2: int = 0
9901077
token_counts.reverse()
9911078
for i, m in enumerate(reversed(messages)):
992-
if m["role"] == "system":
1079+
if m.role == "system":
9931080
messages2.append(m)
9941081
continue
9951082
remaining_non_system_tokens -= token_counts[i]
@@ -1012,13 +1099,13 @@ def _trim_anthropic_messages(
10121099
self,
10131100
messages: tuple[TransformedMessage, ...],
10141101
) -> tuple[TransformedMessage, ...]:
1015-
if any(m["role"] == "system" for m in messages):
1102+
if any(m.role == "system" for m in messages):
10161103
raise ValueError(
10171104
"Anthropic requires a system prompt to be specified in it's `.create()` method "
10181105
"(not in the chat messages with `role: system`)."
10191106
)
10201107
for i, m in enumerate(messages):
1021-
if m["role"] == "user":
1108+
if m.role == "user":
10221109
return messages[i:]
10231110

10241111
return ()
@@ -1064,7 +1151,8 @@ def user_input(self, transform: bool = False) -> str | None:
10641151
if msg is None:
10651152
return None
10661153
key = "content_server" if transform else "content_client"
1067-
return str(msg[key])
1154+
val = getattr(msg, key)
1155+
return str(val)
10681156

10691157
def _user_input(self) -> str:
10701158
id = self.user_input_id
@@ -1308,21 +1396,4 @@ def chat_ui(
13081396
return res
13091397

13101398

1311-
def as_transformed_message(message: ChatMessage) -> TransformedMessage:
1312-
if message["role"] == "user":
1313-
transform_key = "content_server"
1314-
pre_transform_key = "content_client"
1315-
else:
1316-
transform_key = "content_client"
1317-
pre_transform_key = "content_server"
1318-
1319-
return TransformedMessage(
1320-
content_client=message["content"],
1321-
content_server=message["content"],
1322-
role=message["role"],
1323-
transform_key=transform_key,
1324-
pre_transform_key=pre_transform_key,
1325-
)
1326-
1327-
13281399
CHAT_INSTANCES: WeakValueDictionary[str, Chat] = WeakValueDictionary()

0 commit comments

Comments
 (0)