11from __future__ import annotations
22
33import inspect
4+ import warnings
45from typing import (
56 Any ,
67 AsyncIterable ,
3839 as_provider_message ,
3940)
4041from ._chat_tokenizer import TokenEncoding , TokenizersEncoding , get_default_tokenizer
41- from ._chat_types import ChatMessage , ClientMessage , TransformedMessage
42+ from ._chat_types import ChatMessage , ClientMessage , Role , TransformedMessage
4243from ._html_deps_py_shiny import chat_deps
4344from .fill import as_fill_item , as_fillable_container
4445
@@ -231,18 +232,18 @@ async def _init_chat():
231232 @reactive .effect (priority = 9999 )
232233 @reactive .event (self ._user_input )
233234 async def _on_user_input ():
234- msg = ChatMessage ( content = self ._user_input (), role = "user" )
235+ content = self ._user_input ()
235236 # It's possible that during the transform, a message is appended, so get
236237 # the length now, so we can insert the new message at the right index
237238 n_pre = len (self ._messages ())
238- msg_post = await self ._transform_message ( msg )
239- if msg_post is not None :
240- self ._store_message ( msg_post )
239+ content , _ = await self ._transform_content ( content , role = "user" )
240+ if content is not None :
241+ self ._store_content ( content , role = "user" )
241242 self ._suspend_input_handler = False
242243 else :
243244 # A transformed value of None is a special signal to suspend input
244245 # handling (i.e., don't generate a response)
245- self ._store_message ( as_transformed_message ( msg ) , index = n_pre )
246+ self ._store_content ( content or "" , role = "user" , index = n_pre )
246247 await self ._remove_loading_message ()
247248 self ._suspend_input_handler = True
248249
@@ -483,14 +484,15 @@ def messages(
483484 res : list [ChatMessage | ProviderMessage ] = []
484485 for i , m in enumerate (messages ):
485486 transform = False
486- if m [ " role" ] == "assistant" :
487+ if m . role == "assistant" :
487488 transform = transform_assistant
488- elif m [ " role" ] == "user" :
489+ elif m . role == "user" :
489490 transform = transform_user == "all" or (
490491 transform_user == "last" and i == len (messages ) - 1
491492 )
492- content_key = m ["transform_key" if transform else "pre_transform_key" ]
493- chat_msg = ChatMessage (content = str (m [content_key ]), role = m ["role" ])
493+ key = "transform_key" if transform else "pre_transform_key"
494+ content_val = getattr (m , getattr (m , key ))
495+ chat_msg = ChatMessage (content = str (content_val ), role = m .role )
494496 if not isinstance (format , MISSING_TYPE ):
495497 chat_msg = as_provider_message (chat_msg , format )
496498 res .append (chat_msg )
@@ -550,11 +552,89 @@ async def append_message(
550552 """
551553 await self ._append_message (message , icon = icon )
552554
555+ async def inject_message_chunk (
556+ self ,
557+ message_chunk : Any ,
558+ * ,
559+ operation : Literal ["append" , "replace" ] = "append" ,
560+ force : bool = False ,
561+ ):
562+ """
563+ Inject a chunk of message content into the current message stream.
564+
565+ Sometimes when streaming a message (i.e., `.append_message_stream()`), you may
566+ want to inject a content into the streaming message while the stream is
567+ busy doing other things (e.g., calling a tool). This method allows you to
568+ inject any content you want into the current message stream (assuming one is
569+ active).
570+
571+ Parameters
572+ ----------
573+ message_chunk
574+ A message chunk to inject.
575+ operation
576+ Whether to append or replace the current message stream content.
577+ force
578+ Whether to start a new stream if one is not currently active.
579+ """
580+ stream_id = self ._current_stream_id
581+ if stream_id is None :
582+ if not force :
583+ raise ValueError (
584+ "Can't inject a message chunk when no message stream is active. "
585+ "Use `force=True` to start a new stream if one is not currently active." ,
586+ )
587+ await self .start_message_stream (force = True )
588+
589+ return await self ._append_message (
590+ message_chunk ,
591+ chunk = True ,
592+ stream_id = stream_id ,
593+ operation = operation ,
594+ )
595+
596+ async def start_message_stream (self , * , force : bool = False ):
597+ """
598+ Start a new message stream.
599+
600+ Parameters
601+ ----------
602+ force
603+ Whether to force starting a new stream even if one is already active
604+ """
605+ stream_id = self ._current_stream_id
606+ if stream_id is not None :
607+ if not force :
608+ raise ValueError (
609+ "Can't start a new message stream when a message stream is already active. "
610+ "Use `force=True` to end a currently active stream and start a new one." ,
611+ )
612+ await self .end_message_stream ()
613+
614+ id = _utils .private_random_id ()
615+ return await self ._append_message ("" , chunk = "start" , stream_id = id )
616+
617+ async def end_message_stream (self ):
618+ """
619+ End the current message stream (if any).
620+ """
621+ stream_id = self ._current_stream_id
622+ if stream_id is None :
623+ warnings .warn ("No currently active stream to end." , stacklevel = 2 )
624+ return
625+
626+ with reactive .isolate ():
627+ # TODO: .cancel() method should probably just handle this
628+ self .latest_message_stream .cancel ()
629+
630+ return await self ._append_message ("" , chunk = "end" , stream_id = stream_id )
631+
553632 async def _append_message (
554633 self ,
555634 message : Any ,
556635 * ,
557636 chunk : ChunkOption = False ,
637+ operation : Literal ["append" , "replace" ] = "append" ,
558638 stream_id : str | None = None ,
559639 icon : HTML | Tag | TagList | None = None ,
560640 ) -> None :
@@ -570,27 +650,39 @@ async def _append_message(
570650
571651 if chunk is False :
572652 msg = normalize_message (message )
573- chunk_content = None
574653 else :
575654 msg = normalize_message_chunk (message )
576- # Update the current stream message
577- chunk_content = msg ["content" ]
578- self ._current_stream_message += chunk_content
579- msg ["content" ] = self ._current_stream_message
580- if chunk == "end" :
655+ if operation == "replace" :
581656 self ._current_stream_message = ""
657+ self ._current_stream_message += msg ["content" ]
582658
583- msg = await self ._transform_message (
584- msg , chunk = chunk , chunk_content = chunk_content
585- )
586- if msg is None :
587- return
588- self ._store_message (msg , chunk = chunk )
589- await self ._send_append_message (
590- msg ,
591- chunk = chunk ,
592- icon = icon ,
593- )
659+ try :
660+ content , transformed = await self ._transform_content (
661+ msg ["content" ], role = msg ["role" ], chunk = chunk
662+ )
663+ # Act like nothing happened if content transformed to None
664+ if content is None :
665+ return
666+ # Store if this is a whole message or the end of a streaming message
667+ if chunk is False :
668+ self ._store_content (content , role = msg ["role" ])
669+ elif chunk == "end" :
670+ # Transforming content requires replacing all the content, so take
671+ # it as is. Otherwise, store the accumulated stream message.
672+ self ._store_content (
673+ content = content if transformed else self ._current_stream_message ,
674+ role = msg ["role" ],
675+ )
676+ await self ._send_append_message (
677+ content = content ,
678+ role = msg ["role" ],
679+ chunk = chunk ,
680+ operation = "replace" if transformed else operation ,
681+ icon = icon ,
682+ )
683+ finally :
684+ if chunk == "end" :
685+ self ._current_stream_message = ""
594686
595687 async def append_message_stream (
596688 self ,
@@ -737,11 +829,13 @@ def _can_append_message(self, stream_id: str | None) -> bool:
737829 # Send a message to the UI
738830 async def _send_append_message (
739831 self ,
740- message : TransformedMessage ,
832+ content : str | HTML ,
833+ role : Role ,
741834 chunk : ChunkOption = False ,
835+ operation : Literal ["append" , "replace" ] = "append" ,
742836 icon : HTML | Tag | TagList | None = None ,
743837 ):
744- if message [ " role" ] == "system" :
838+ if role == "system" :
745839 # System messages are not displayed in the UI
746840 return
747841
@@ -756,15 +850,15 @@ async def _send_append_message(
756850 elif chunk == "end" :
757851 chunk_type = "message_end"
758852
759- content = message ["content_client" ]
760853 content_type = "html" if isinstance (content , HTML ) else "markdown"
761854
762855 # TODO: pass along dependencies for both content and icon (if any)
763856 msg = ClientMessage (
764857 content = str (content ),
765- role = message [ " role" ] ,
858+ role = role ,
766859 content_type = content_type ,
767860 chunk_type = chunk_type ,
861+ operation = operation ,
768862 )
769863
770864 if icon is not None :
@@ -892,57 +986,50 @@ async def _transform_wrapper(content: str, chunk: str, done: bool):
892986 else :
893987 return _set_transform (fn )
894988
895- async def _transform_message (
989+ async def _transform_content (
896990 self ,
897- message : ChatMessage ,
991+ content : str ,
992+ role : Role ,
898993 chunk : ChunkOption = False ,
899- chunk_content : str | None = None ,
900- ) -> TransformedMessage | None :
901- res = as_transformed_message (message )
902- key = res ["transform_key" ]
903-
904- if message ["role" ] == "user" and self ._transform_user is not None :
905- content = await self ._transform_user (message ["content" ])
906-
907- elif message ["role" ] == "assistant" and self ._transform_assistant is not None :
908- content = await self ._transform_assistant (
909- message ["content" ],
910- chunk_content or "" ,
994+ ) -> tuple [str | HTML | None , bool ]:
995+ content2 = content
996+ transformed = False
997+ if role == "user" and self ._transform_user is not None :
998+ content2 = await self ._transform_user (content )
999+ transformed = True
1000+ elif role == "assistant" and self ._transform_assistant is not None :
1001+ all_content = content if chunk is False else self ._current_stream_message
1002+ content2 = await self ._transform_assistant (
1003+ all_content ,
1004+ content ,
9111005 chunk == "end" or chunk is False ,
9121006 )
913- else :
914- return res
915-
916- if content is None :
917- return None
918-
919- res [key ] = content # type: ignore
1007+ transformed = True
9201008
921- return res
1009+ return ( content2 , transformed )
9221010
9231011 # Just before storing, handle chunk msg type and calculate tokens
924- def _store_message (
1012+ def _store_content (
9251013 self ,
926- message : TransformedMessage ,
927- chunk : ChunkOption = False ,
1014+ content : str | HTML ,
1015+ role : Role ,
9281016 index : int | None = None ,
9291017 ) -> None :
930- # Don't actually store chunks until the end
931- if chunk is True or chunk == "start" :
932- return None
9331018
9341019 with reactive .isolate ():
9351020 messages = self ._messages ()
9361021
9371022 if index is None :
9381023 index = len (messages )
9391024
1025+ msg = TransformedMessage .from_content (content = content , role = role )
1026+
9401027 messages = list (messages )
941- messages .insert (index , message )
1028+ messages .insert (index , msg )
9421029
9431030 self ._messages .set (tuple (messages ))
944- if message [ " role" ] == "user" :
945- self ._latest_user_input .set (message )
1031+ if role == "user" :
1032+ self ._latest_user_input .set (msg )
9461033
9471034 return None
9481035
@@ -966,9 +1053,9 @@ def _trim_messages(
9661053 n_other_messages : int = 0
9671054 token_counts : list [int ] = []
9681055 for m in messages :
969- count = self ._get_token_count (m [ " content_server" ] )
1056+ count = self ._get_token_count (m . content_server )
9701057 token_counts .append (count )
971- if m [ " role" ] == "system" :
1058+ if m . role == "system" :
9721059 n_system_tokens += count
9731060 n_system_messages += 1
9741061 else :
@@ -989,7 +1076,7 @@ def _trim_messages(
9891076 n_other_messages2 : int = 0
9901077 token_counts .reverse ()
9911078 for i , m in enumerate (reversed (messages )):
992- if m [ " role" ] == "system" :
1079+ if m . role == "system" :
9931080 messages2 .append (m )
9941081 continue
9951082 remaining_non_system_tokens -= token_counts [i ]
@@ -1012,13 +1099,13 @@ def _trim_anthropic_messages(
10121099 self ,
10131100 messages : tuple [TransformedMessage , ...],
10141101 ) -> tuple [TransformedMessage , ...]:
1015- if any (m [ " role" ] == "system" for m in messages ):
1102+ if any (m . role == "system" for m in messages ):
10161103 raise ValueError (
10171104 "Anthropic requires a system prompt to be specified in it's `.create()` method "
10181105 "(not in the chat messages with `role: system`)."
10191106 )
10201107 for i , m in enumerate (messages ):
1021- if m [ " role" ] == "user" :
1108+ if m . role == "user" :
10221109 return messages [i :]
10231110
10241111 return ()
@@ -1064,7 +1151,8 @@ def user_input(self, transform: bool = False) -> str | None:
10641151 if msg is None :
10651152 return None
10661153 key = "content_server" if transform else "content_client"
1067- return str (msg [key ])
1154+ val = getattr (msg , key )
1155+ return str (val )
10681156
10691157 def _user_input (self ) -> str :
10701158 id = self .user_input_id
@@ -1308,21 +1396,4 @@ def chat_ui(
13081396 return res
13091397
13101398
1311- def as_transformed_message (message : ChatMessage ) -> TransformedMessage :
1312- if message ["role" ] == "user" :
1313- transform_key = "content_server"
1314- pre_transform_key = "content_client"
1315- else :
1316- transform_key = "content_client"
1317- pre_transform_key = "content_server"
1318-
1319- return TransformedMessage (
1320- content_client = message ["content" ],
1321- content_server = message ["content" ],
1322- role = message ["role" ],
1323- transform_key = transform_key ,
1324- pre_transform_key = pre_transform_key ,
1325- )
1326-
1327-
# Module-level registry of live Chat instances keyed by id; weak values let
# instances be garbage-collected when no longer referenced elsewhere.
CHAT_INSTANCES: WeakValueDictionary[str, Chat] = WeakValueDictionary()
0 commit comments