1
+ import asyncio
1
2
from abc import ABC , abstractmethod
2
3
from dataclasses import asdict
3
4
from logging import Logger
@@ -82,6 +83,10 @@ class BasePersona(ABC):
82
83
Automatically set by `BasePersona`.
83
84
"""
84
85
86
+ message_interrupted : dict [str , asyncio .Event ]
87
+ """Dictionary mapping an agent message identifier to an asyncio Event
88
+ which indicates if the message generation/streaming was interrupted."""
89
+
85
90
################################################
86
91
# constructor
87
92
################################################
@@ -92,11 +97,13 @@ def __init__(
92
97
manager : "PersonaManager" ,
93
98
config : ConfigManager ,
94
99
log : Logger ,
100
+ message_interrupted : dict [str , asyncio .Event ],
95
101
):
96
102
self .ychat = ychat
97
103
self .manager = manager
98
104
self .config = config
99
105
self .log = log
106
+ self .message_interrupted = message_interrupted
100
107
self .awareness = PersonaAwareness (
101
108
ychat = self .ychat , log = self .log , user = self .as_user ()
102
109
)
@@ -221,14 +228,34 @@ async def stream_message(self, reply_stream: "AsyncIterator") -> None:
221
228
- Automatically manages its awareness state to show writing status.
222
229
"""
223
230
stream_id : Optional [str ] = None
224
-
231
+ stream_interrupted = False
225
232
try :
226
233
self .awareness .set_local_state_field ("isWriting" , True )
227
234
async for chunk in reply_stream :
235
+ if (
236
+ stream_id
237
+ and stream_id in self .message_interrupted .keys ()
238
+ and self .message_interrupted [stream_id ].is_set ()
239
+ ):
240
+ try :
241
+ # notify the model provider that streaming was interrupted
242
+ # (this is essential to allow the model to stop generating)
243
+ await reply_stream .athrow ( # type:ignore[attr-defined]
244
+ GenerationInterrupted ()
245
+ )
246
+ except GenerationInterrupted :
247
+ # do not let the exception bubble up in case if
248
+ # the provider did not handle it
249
+ pass
250
+ stream_interrupted = True
251
+ break
252
+
228
253
if not stream_id :
229
254
stream_id = self .ychat .add_message (
230
255
NewMessage (body = "" , sender = self .id )
231
256
)
257
+ self .message_interrupted [stream_id ] = asyncio .Event ()
258
+ self .awareness .set_local_state_field ("isWriting" , stream_id )
232
259
233
260
assert stream_id
234
261
self .ychat .update_message (
@@ -248,9 +275,29 @@ async def stream_message(self, reply_stream: "AsyncIterator") -> None:
248
275
self .log .exception (e )
249
276
finally :
250
277
self .awareness .set_local_state_field ("isWriting" , False )
278
+ if stream_id :
279
+ # if stream was interrupted, add a tombstone
280
+ if stream_interrupted :
281
+ stream_tombstone = "\n \n (AI response stopped by user)"
282
+ self .ychat .update_message (
283
+ Message (
284
+ id = stream_id ,
285
+ body = stream_tombstone ,
286
+ time = time (),
287
+ sender = self .id ,
288
+ raw_time = False ,
289
+ ),
290
+ append = True ,
291
+ )
292
+ if stream_id in self .message_interrupted .keys ():
293
+ del self .message_interrupted [stream_id ]
251
294
252
295
def send_message (self , body : str ) -> None :
253
296
"""
254
297
Sends a new message to the chat from this persona.
255
298
"""
256
299
self .ychat .add_message (NewMessage (body = body , sender = self .id ))
300
+
301
+
302
+ class GenerationInterrupted (asyncio .CancelledError ):
303
+ """Exception raised when streaming is cancelled by the user"""
0 commit comments