11from __future__ import annotations
22
33import inspect
4- import warnings
54from contextlib import asynccontextmanager
65from typing import (
76 Any ,
8382
8483ChunkOption = Literal ["start" , "end" , True , False ]
8584
86- PendingMessage = Tuple [Any , ChunkOption , Union [str , None ]]
85+ PendingMessage = Tuple [Any , Literal [ "start" , "end" , True ] , Union [str , None ]]
8786
8887
8988@add_example (ex_dir = "../templates/chat/starters/hello" )
@@ -199,15 +198,12 @@ def __init__(
199198 self .on_error = on_error
200199
201200 # Chunked messages get accumulated (using this property) before changing state
202- self ._current_stream_message = ""
201+ self ._current_stream_message : str = ""
203202 self ._current_stream_id : str | None = None
204203 self ._pending_messages : list [PendingMessage ] = []
205204
206- # Identifier for a manual stream (i.e., one started with `.start_message_stream()`)
207- self ._manual_stream_id : str | None = None
208- # If a manual stream gets nested within another stream, we need to keep track of
209- # the accumulated message separately
210- self ._nested_stream_message : str = ""
205+ # For tracking message stream state when entering/exiting nested streams
206+ self ._message_stream_checkpoint : str = ""
211207
212208 # If a user input message is transformed into a response, we need to cancel
213209 # the next user input submit handling
@@ -576,7 +572,16 @@ async def append_message(
576572 similar) is specified in model's completion method.
577573 :::
578574 """
579- await self ._append_message (message , icon = icon )
575+ msg = normalize_message (message )
576+ msg = await self ._transform_message (msg )
577+ if msg is None :
578+ return
579+ self ._store_message (msg )
580+ await self ._send_append_message (
581+ message = msg ,
582+ chunk = False ,
583+ icon = icon ,
584+ )
580585
581586 async def append_message_chunk (
582587 self ,
@@ -618,9 +623,8 @@ async def append_message_chunk(
618623 "Use .message_stream() or .append_message_stream() to start one."
619624 )
620625
621- return await self ._append_message (
626+ return await self ._append_message_chunk (
622627 message_chunk ,
623- chunk = True ,
624628 stream_id = stream_id ,
625629 operation = operation ,
626630 )
@@ -641,75 +645,39 @@ async def message_stream(self):
641645 to display "ephemeral" content, then eventually show a final state
642646 with `.append_message_chunk(operation="replace")`.
643647 """
644- await self ._start_stream ()
648+ # Save the current stream state in a checkpoint (so that we can handle
649+ # ``.append_message_chunk(operation="replace")` correctly)
650+ old_checkpoint = self ._message_stream_checkpoint
651+ self ._message_stream_checkpoint = self ._current_stream_message
652+
653+ # No stream currently exists, start one
654+ is_root_stream = not self ._current_stream_id
655+ if is_root_stream :
656+ await self ._append_message_chunk (
657+ "" ,
658+ chunk = "start" ,
659+ stream_id = _utils .private_random_id (),
660+ )
661+
645662 try :
646663 yield
647664 finally :
648- await self ._end_stream ()
649-
650- async def _start_stream (self ):
651- if self ._manual_stream_id is not None :
652- # TODO: support this?
653- raise ValueError ("Nested .message_stream() isn't currently supported." )
654- # If we're currently streaming (i.e., through append_message_stream()), then
655- # end the client message stream (since we start a new one below)
656- if self ._current_stream_id is not None :
657- await self ._send_append_message (
658- message = ChatMessage (content = "" , role = "assistant" ),
659- chunk = "end" ,
660- operation = "append" ,
661- )
662- # Regardless whether this is an "inner" stream, we start a new message on the
663- # client so it can handle `operation="replace"` without having to track where
664- # the inner stream started.
665- self ._manual_stream_id = _utils .private_random_id ()
666- stream_id = self ._current_stream_id or self ._manual_stream_id
667- return await self ._append_message (
668- "" ,
669- chunk = "start" ,
670- stream_id = stream_id ,
671- # TODO: find a cleaner way to do this, and remove the gap between the messages
672- icon = (
673- HTML ("<span class='border-0'><span>" )
674- if self ._is_nested_stream
675- else None
676- ),
677- )
678-
679- async def _end_stream (self ):
680- if self ._manual_stream_id is None and self ._current_stream_id is None :
681- warnings .warn (
682- "Tried to end a message stream, but one isn't currently active." ,
683- stacklevel = 2 ,
684- )
685- return
686-
687- if self ._is_nested_stream :
688- # If inside another stream, just update server-side message state
689- self ._current_stream_message += self ._nested_stream_message
690- self ._nested_stream_message = ""
691- else :
692- # Otherwise, end this "manual" message stream
693- await self ._append_message (
694- "" , chunk = "end" , stream_id = self ._manual_stream_id
695- )
696-
697- self ._manual_stream_id = None
698- return
699-
700- @property
701- def _is_nested_stream (self ):
702- return (
703- self ._current_stream_id is not None
704- and self ._manual_stream_id is not None
705- and self ._current_stream_id != self ._manual_stream_id
706- )
665+ # Restore the previous stream state
666+ self ._message_stream_checkpoint = old_checkpoint
667+
668+ # If this was the root stream, end it
669+ if is_root_stream :
670+ await self ._append_message_chunk (
671+ "" ,
672+ chunk = "end" ,
673+ stream_id = self ._current_stream_id ,
674+ )
707675
708- async def _append_message (
676+ async def _append_message_chunk (
709677 self ,
710678 message : Any ,
711679 * ,
712- chunk : ChunkOption = False ,
680+ chunk : Literal [ True , "start" , "end" ] = True ,
713681 operation : Literal ["append" , "replace" ] = "append" ,
714682 stream_id : str | None = None ,
715683 icon : HTML | Tag | TagList | None = None ,
@@ -724,37 +692,40 @@ async def _append_message(
724692 if chunk == "end" :
725693 self ._current_stream_id = None
726694
727- if chunk is False :
728- msg = normalize_message (message )
695+ # Normalize into a ChatMessage()
696+ msg = normalize_message_chunk (message )
697+
698+ # Remember this content chunk for passing to transformer
699+ this_chunk = msg .content
700+
701+ # Transforming requires replacing
702+ if self ._needs_transform (msg ):
703+ operation = "replace"
704+
705+ if operation == "replace" :
706+ # Replace up to the latest checkpoint
707+ self ._current_stream_message = self ._message_stream_checkpoint + this_chunk
708+ msg .content = self ._current_stream_message
729709 else :
730- msg = normalize_message_chunk (message )
731- if self ._is_nested_stream :
732- if operation == "replace" :
733- self ._nested_stream_message = ""
734- self ._nested_stream_message += msg .content
735- else :
736- if operation == "replace" :
737- self ._current_stream_message = ""
738- self ._current_stream_message += msg .content
710+ self ._current_stream_message += msg .content
739711
740712 try :
741- msg = await self ._transform_message (msg , chunk = chunk )
742- # Act like nothing happened if transformed to None
743- if msg is None :
744- return
745- msg_store = msg
746- # Transforming requires *replacing* content
747- if isinstance (msg , TransformedMessage ):
748- operation = "replace"
713+ if self ._needs_transform (msg ):
714+ msg = await self ._transform_message (
715+ msg , chunk = chunk , chunk_content = this_chunk
716+ )
717+ # Act like nothing happened if transformed to None
718+ if msg is None :
719+ return
720+ if chunk == "end" :
721+ self ._store_message (msg )
749722 elif chunk == "end" :
750- # When not transforming, ensure full message is stored
751- msg_store = ChatMessage (
752- content = self ._current_stream_message ,
753- role = "assistant" ,
723+ # When `operation="append"`, msg.content is just a chunk, but we must
724+ # store the full message
725+ self ._store_message (
726+ ChatMessage ( content = self . _current_stream_message , role = msg . role )
754727 )
755- # Only store full messages
756- if chunk is False or chunk == "end" :
757- self ._store_message (msg_store )
728+
758729 # Send the message to the client
759730 await self ._send_append_message (
760731 message = msg ,
@@ -764,10 +735,8 @@ async def _append_message(
764735 )
765736 finally :
766737 if chunk == "end" :
767- if self ._is_nested_stream :
768- self ._nested_stream_message = ""
769- else :
770- self ._current_stream_message = ""
738+ self ._current_stream_message = ""
739+ self ._message_stream_checkpoint = ""
771740
772741 async def append_message_stream (
773742 self ,
@@ -898,21 +867,21 @@ async def _append_message_stream(
898867 id = _utils .private_random_id ()
899868
900869 empty = ChatMessageDict (content = "" , role = "assistant" )
901- await self ._append_message (empty , chunk = "start" , stream_id = id , icon = icon )
870+ await self ._append_message_chunk (empty , chunk = "start" , stream_id = id , icon = icon )
902871
903872 try :
904873 async for msg in message :
905- await self ._append_message (msg , chunk = True , stream_id = id )
874+ await self ._append_message_chunk (msg , chunk = True , stream_id = id )
906875 return self ._current_stream_message
907876 finally :
908- await self ._append_message (empty , chunk = "end" , stream_id = id )
877+ await self ._append_message_chunk (empty , chunk = "end" , stream_id = id )
909878 await self ._flush_pending_messages ()
910879
911880 async def _flush_pending_messages (self ):
912881 still_pending : list [PendingMessage ] = []
913882 for msg , chunk , stream_id in self ._pending_messages :
914883 if self ._can_append_message (stream_id ):
915- await self ._append_message (msg , chunk = chunk , stream_id = stream_id )
884+ await self ._append_message_chunk (msg , chunk = chunk , stream_id = stream_id )
916885 else :
917886 still_pending .append ((msg , chunk , stream_id ))
918887 self ._pending_messages = still_pending
@@ -1093,23 +1062,20 @@ async def _transform_message(
10931062 self ,
10941063 message : ChatMessage ,
10951064 chunk : ChunkOption = False ,
1096- ) -> ChatMessage | TransformedMessage | None :
1065+ chunk_content : str = "" ,
1066+ ) -> TransformedMessage | None :
10971067 res = TransformedMessage .from_chat_message (message )
10981068
10991069 if message .role == "user" and self ._transform_user is not None :
11001070 content = await self ._transform_user (message .content )
11011071 elif message .role == "assistant" and self ._transform_assistant is not None :
1102- all_content = (
1103- message .content if chunk is False else self ._current_stream_message
1104- )
1105- setattr (res , res .pre_transform_key , all_content )
11061072 content = await self ._transform_assistant (
1107- all_content ,
11081073 message .content ,
1074+ chunk_content ,
11091075 chunk == "end" or chunk is False ,
11101076 )
11111077 else :
1112- return message
1078+ return res
11131079
11141080 if content is None :
11151081 return None
@@ -1118,6 +1084,13 @@ async def _transform_message(
11181084
11191085 return res
11201086
1087+ def _needs_transform (self , message : ChatMessage ) -> bool :
1088+ if message .role == "user" and self ._transform_user is not None :
1089+ return True
1090+ elif message .role == "assistant" and self ._transform_assistant is not None :
1091+ return True
1092+ return False
1093+
11211094 # Just before storing, handle chunk msg type and calculate tokens
11221095 def _store_message (
11231096 self ,
0 commit comments