@@ -590,12 +590,7 @@ async def append_message(
590590 icon = icon ,
591591 )
592592
593- async def append_message_chunk (
594- self ,
595- message_chunk : Any ,
596- * ,
597- operation : Literal ["append" , "replace" ] = "append" ,
598- ):
593+ async def append_message_chunk (self , message_chunk : Any ):
599594 """
600595 Append a message chunk to the current message stream.
601596
@@ -606,8 +601,6 @@ async def append_message_chunk(
606601 ----------
607602 message_chunk
608603 A message chunk to inject.
609- operation
610- Whether to append or replace the *current* message stream content.
611604
612605 Note
613606 ----
@@ -633,7 +626,6 @@ async def append_message_chunk(
633626 return await self ._append_message_chunk (
634627 message_chunk ,
635628 stream_id = stream_id ,
636- operation = operation ,
637629 )
638630
639631 @asynccontextmanager
@@ -644,6 +636,12 @@ async def message_stream(self):
644636 A context manager for streaming messages into the chat. Note this stream
645637 can occur within a longer running `.append_message_stream()` or used on its own.
646638
639+ Yields
640+ ------
641+ :
642+ A `MessageStream` instance with a method for `.append()`ing message chunks
643+ and a method for `.restore()`ing the stream back to it's initial state.
644+
647645 Note
648646 ----
649647 A useful pattern for displaying tool calls in a chat interface is for the
@@ -658,16 +656,14 @@ async def message_stream(self):
658656 self ._message_stream_checkpoint = self ._current_stream_message
659657
660658 # No stream currently exists, start one
661- is_root_stream = not self ._current_stream_id
659+ stream_id = self ._current_stream_id
660+ is_root_stream = stream_id is None
662661 if is_root_stream :
663- await self ._append_message_chunk (
664- "" ,
665- chunk = "start" ,
666- stream_id = _utils .private_random_id (),
667- )
662+ stream_id = _utils .private_random_id ()
663+ await self ._append_message_chunk ("" , chunk = "start" , stream_id = stream_id )
668664
669665 try :
670- yield
666+ yield MessageStream ( self , stream_id )
671667 finally :
672668 # Restore the previous stream state
673669 self ._message_stream_checkpoint = old_checkpoint
@@ -677,7 +673,7 @@ async def message_stream(self):
677673 await self ._append_message_chunk (
678674 "" ,
679675 chunk = "end" ,
680- stream_id = cast ( str , self . _current_stream_id ) ,
676+ stream_id = stream_id ,
681677 )
682678
683679 async def _append_message_chunk (
@@ -1496,4 +1492,36 @@ def chat_ui(
14961492 return res
14971493
14981494
1495+ class MessageStream :
1496+ """"""
1497+
1498+ def __init__ (self , chat : Chat , stream_id : str ):
1499+ self ._chat = chat
1500+ self ._stream_id = stream_id
1501+
1502+ async def restore (self ):
1503+ """
1504+ Restore the stream back to its initial state.
1505+ """
1506+ await self ._chat ._append_message_chunk (
1507+ "" ,
1508+ operation = "replace" ,
1509+ stream_id = self ._stream_id ,
1510+ )
1511+
1512+ async def append (self , message_chunk : Any ):
1513+ """
1514+ Append a message chunk to the stream.
1515+
1516+ Parameters
1517+ -----------
1518+ message_chunk
1519+ A message chunk to append to this stream
1520+ """
1521+ await self ._chat ._append_message_chunk (
1522+ message_chunk ,
1523+ stream_id = self ._stream_id ,
1524+ )
1525+
1526+
14991527CHAT_INSTANCES : WeakValueDictionary [str , Chat ] = WeakValueDictionary ()
0 commit comments