11from __future__ import annotations
22
33import inspect
4- import warnings
54from contextlib import asynccontextmanager
65from typing import (
76 Any ,
8180 UserSubmitFunction1 ,
8281]
8382
84- ChunkOption = Literal ["start" , "end" , True , False ]
85-
86- PendingMessage = Tuple [Any , ChunkOption , Union [str , None ]]
83+ PendingMessage = Tuple [Any , Literal ["start" , "end" , True ], Union [str , None ]]
8784
8885
8986@add_example (ex_dir = "../templates/chat/starters/hello" )
@@ -199,15 +196,12 @@ def __init__(
199196 self .on_error = on_error
200197
201198 # Chunked messages get accumulated (using this property) before changing state
202- self ._current_stream_message = ""
199+ self ._current_stream_message : str = ""
203200 self ._current_stream_id : str | None = None
204201 self ._pending_messages : list [PendingMessage ] = []
205202
206- # Identifier for a manual stream (i.e., one started with `.start_message_stream()`)
207- self ._manual_stream_id : str | None = None
208- # If a manual stream gets nested within another stream, we need to keep track of
209- # the accumulated message separately
210- self ._nested_stream_message : str = ""
203+ # For tracking message stream state when entering/exiting nested streams
204+ self ._message_stream_checkpoint : str = ""
211205
212206 # If a user input message is transformed into a response, we need to cancel
213207 # the next user input submit handling
@@ -576,7 +570,16 @@ async def append_message(
576570 similar) is specified in model's completion method.
577571 :::
578572 """
579- await self ._append_message (message , icon = icon )
573+ msg = normalize_message (message )
574+ msg = await self ._transform_message (msg )
575+ if msg is None :
576+ return
577+ self ._store_message (msg )
578+ await self ._send_append_message (
579+ message = msg ,
580+ chunk = False ,
581+ icon = icon ,
582+ )
580583
581584 async def append_message_chunk (
582585 self ,
@@ -618,9 +621,8 @@ async def append_message_chunk(
618621 "Use .message_stream() or .append_message_stream() to start one."
619622 )
620623
621- return await self ._append_message (
624+ return await self ._append_message_chunk (
622625 message_chunk ,
623- chunk = True ,
624626 stream_id = stream_id ,
625627 operation = operation ,
626628 )
@@ -641,75 +643,39 @@ async def message_stream(self):
641643 to display "ephemeral" content, then eventually show a final state
642644 with `.append_message_chunk(operation="replace")`.
643645 """
644- await self ._start_stream ()
646+ # Save the current stream state in a checkpoint (so that we can handle
647+ # ``.append_message_chunk(operation="replace")` correctly)
648+ old_checkpoint = self ._message_stream_checkpoint
649+ self ._message_stream_checkpoint = self ._current_stream_message
650+
651+ # No stream currently exists, start one
652+ is_root_stream = not self ._current_stream_id
653+ if is_root_stream :
654+ await self ._append_message_chunk (
655+ "" ,
656+ chunk = "start" ,
657+ stream_id = _utils .private_random_id (),
658+ )
659+
645660 try :
646661 yield
647662 finally :
648- await self ._end_stream ()
649-
650- async def _start_stream (self ):
651- if self ._manual_stream_id is not None :
652- # TODO: support this?
653- raise ValueError ("Nested .message_stream() isn't currently supported." )
654- # If we're currently streaming (i.e., through append_message_stream()), then
655- # end the client message stream (since we start a new one below)
656- if self ._current_stream_id is not None :
657- await self ._send_append_message (
658- message = ChatMessage (content = "" , role = "assistant" ),
659- chunk = "end" ,
660- operation = "append" ,
661- )
662- # Regardless whether this is an "inner" stream, we start a new message on the
663- # client so it can handle `operation="replace"` without having to track where
664- # the inner stream started.
665- self ._manual_stream_id = _utils .private_random_id ()
666- stream_id = self ._current_stream_id or self ._manual_stream_id
667- return await self ._append_message (
668- "" ,
669- chunk = "start" ,
670- stream_id = stream_id ,
671- # TODO: find a cleaner way to do this, and remove the gap between the messages
672- icon = (
673- HTML ("<span class='border-0'><span>" )
674- if self ._is_nested_stream
675- else None
676- ),
677- )
678-
679- async def _end_stream (self ):
680- if self ._manual_stream_id is None and self ._current_stream_id is None :
681- warnings .warn (
682- "Tried to end a message stream, but one isn't currently active." ,
683- stacklevel = 2 ,
684- )
685- return
686-
687- if self ._is_nested_stream :
688- # If inside another stream, just update server-side message state
689- self ._current_stream_message += self ._nested_stream_message
690- self ._nested_stream_message = ""
691- else :
692- # Otherwise, end this "manual" message stream
693- await self ._append_message (
694- "" , chunk = "end" , stream_id = self ._manual_stream_id
695- )
696-
697- self ._manual_stream_id = None
698- return
699-
700- @property
701- def _is_nested_stream (self ):
702- return (
703- self ._current_stream_id is not None
704- and self ._manual_stream_id is not None
705- and self ._current_stream_id != self ._manual_stream_id
706- )
663+ # Restore the previous stream state
664+ self ._message_stream_checkpoint = old_checkpoint
665+
666+ # If this was the root stream, end it
667+ if is_root_stream :
668+ await self ._append_message_chunk (
669+ "" ,
670+ chunk = "end" ,
671+ stream_id = self ._current_stream_id ,
672+ )
707673
708- async def _append_message (
674+ async def _append_message_chunk (
709675 self ,
710676 message : Any ,
711677 * ,
712- chunk : ChunkOption = False ,
678+ chunk : Literal [ True , "start" , "end" ] = True ,
713679 operation : Literal ["append" , "replace" ] = "append" ,
714680 stream_id : str | None = None ,
715681 icon : HTML | Tag | TagList | None = None ,
@@ -724,37 +690,40 @@ async def _append_message(
724690 if chunk == "end" :
725691 self ._current_stream_id = None
726692
727- if chunk is False :
728- msg = normalize_message (message )
693+ # Normalize into a ChatMessage()
694+ msg = normalize_message_chunk (message )
695+
696+ # Remember this content chunk for passing to transformer
697+ this_chunk = msg .content
698+
699+ # Transforming requires replacing
700+ if self ._needs_transform (msg ):
701+ operation = "replace"
702+
703+ if operation == "replace" :
704+ # Replace up to the latest checkpoint
705+ self ._current_stream_message = self ._message_stream_checkpoint + this_chunk
706+ msg .content = self ._current_stream_message
729707 else :
730- msg = normalize_message_chunk (message )
731- if self ._is_nested_stream :
732- if operation == "replace" :
733- self ._nested_stream_message = ""
734- self ._nested_stream_message += msg .content
735- else :
736- if operation == "replace" :
737- self ._current_stream_message = ""
738- self ._current_stream_message += msg .content
708+ self ._current_stream_message += msg .content
739709
740710 try :
741- msg = await self ._transform_message (msg , chunk = chunk )
742- # Act like nothing happened if transformed to None
743- if msg is None :
744- return
745- msg_store = msg
746- # Transforming requires *replacing* content
747- if isinstance (msg , TransformedMessage ):
748- operation = "replace"
711+ if self ._needs_transform (msg ):
712+ msg = await self ._transform_message (
713+ msg , chunk = chunk , chunk_content = this_chunk
714+ )
715+ # Act like nothing happened if transformed to None
716+ if msg is None :
717+ return
718+ if chunk == "end" :
719+ self ._store_message (msg )
749720 elif chunk == "end" :
750- # When not transforming, ensure full message is stored
751- msg_store = ChatMessage (
752- content = self ._current_stream_message ,
753- role = "assistant" ,
721+ # When `operation="append"`, msg.content is just a chunk, but we must
722+ # store the full message
723+ self ._store_message (
724+ ChatMessage ( content = self . _current_stream_message , role = msg . role )
754725 )
755- # Only store full messages
756- if chunk is False or chunk == "end" :
757- self ._store_message (msg_store )
726+
758727 # Send the message to the client
759728 await self ._send_append_message (
760729 message = msg ,
@@ -764,10 +733,8 @@ async def _append_message(
764733 )
765734 finally :
766735 if chunk == "end" :
767- if self ._is_nested_stream :
768- self ._nested_stream_message = ""
769- else :
770- self ._current_stream_message = ""
736+ self ._current_stream_message = ""
737+ self ._message_stream_checkpoint = ""
771738
772739 async def append_message_stream (
773740 self ,
@@ -898,21 +865,21 @@ async def _append_message_stream(
898865 id = _utils .private_random_id ()
899866
900867 empty = ChatMessageDict (content = "" , role = "assistant" )
901- await self ._append_message (empty , chunk = "start" , stream_id = id , icon = icon )
868+ await self ._append_message_chunk (empty , chunk = "start" , stream_id = id , icon = icon )
902869
903870 try :
904871 async for msg in message :
905- await self ._append_message (msg , chunk = True , stream_id = id )
872+ await self ._append_message_chunk (msg , chunk = True , stream_id = id )
906873 return self ._current_stream_message
907874 finally :
908- await self ._append_message (empty , chunk = "end" , stream_id = id )
875+ await self ._append_message_chunk (empty , chunk = "end" , stream_id = id )
909876 await self ._flush_pending_messages ()
910877
911878 async def _flush_pending_messages (self ):
912879 still_pending : list [PendingMessage ] = []
913880 for msg , chunk , stream_id in self ._pending_messages :
914881 if self ._can_append_message (stream_id ):
915- await self ._append_message (msg , chunk = chunk , stream_id = stream_id )
882+ await self ._append_message_chunk (msg , chunk = chunk , stream_id = stream_id )
916883 else :
917884 still_pending .append ((msg , chunk , stream_id ))
918885 self ._pending_messages = still_pending
@@ -926,7 +893,7 @@ def _can_append_message(self, stream_id: str | None) -> bool:
926893 async def _send_append_message (
927894 self ,
928895 message : TransformedMessage | ChatMessage ,
929- chunk : ChunkOption = False ,
896+ chunk : Literal [ "start" , "end" , True , False ] = False ,
930897 operation : Literal ["append" , "replace" ] = "append" ,
931898 icon : HTML | Tag | TagList | None = None ,
932899 ):
@@ -1092,32 +1059,35 @@ async def _transform_wrapper(content: str, chunk: str, done: bool):
10921059 async def _transform_message (
10931060 self ,
10941061 message : ChatMessage ,
1095- chunk : ChunkOption = False ,
1096- ) -> ChatMessage | TransformedMessage | None :
1062+ chunk : Literal ["start" , "end" , True , False ] = False ,
1063+ chunk_content : str = "" ,
1064+ ) -> TransformedMessage | None :
10971065 res = TransformedMessage .from_chat_message (message )
10981066
10991067 if message .role == "user" and self ._transform_user is not None :
11001068 content = await self ._transform_user (message .content )
11011069 elif message .role == "assistant" and self ._transform_assistant is not None :
1102- all_content = (
1103- message .content if chunk is False else self ._current_stream_message
1104- )
1105- setattr (res , res .pre_transform_key , all_content )
11061070 content = await self ._transform_assistant (
1107- all_content ,
11081071 message .content ,
1072+ chunk_content ,
11091073 chunk == "end" or chunk is False ,
11101074 )
11111075 else :
1112- return message
1076+ return res
11131077
11141078 if content is None :
11151079 return None
11161080
11171081 setattr (res , res .transform_key , content )
1118-
11191082 return res
11201083
1084+ def _needs_transform (self , message : ChatMessage ) -> bool :
1085+ if message .role == "user" and self ._transform_user is not None :
1086+ return True
1087+ elif message .role == "assistant" and self ._transform_assistant is not None :
1088+ return True
1089+ return False
1090+
11211091 # Just before storing, handle chunk msg type and calculate tokens
11221092 def _store_message (
11231093 self ,
0 commit comments