Skip to content

Commit 6b3d406

Browse files
Implement stop streaming button in v3 (#1351)
* Add handler to stop message streaming * Add the stop streaming button in message footer * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Set bot user explicitly, and listen only to messages written by a bot to stop streaming * lint * Update jupyterlab-chat dependency * lint * Fix assertion error --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent dd21873 commit 6b3d406

File tree

10 files changed

+204
-18
lines changed

10 files changed

+204
-18
lines changed

packages/jupyter-ai/jupyter_ai/chat_handlers/base.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,19 +452,46 @@ async def stream_reply(
452452
chunk_generator = self.llm_chain.astream(input, config=merged_config)
453453
# TODO v3: re-implement stream interrupt
454454
stream_interrupted = False
455+
stream_id = None
456+
received_first_chunk = False
455457
async for chunk in chunk_generator:
458+
if (
459+
stream_id
460+
and stream_id in self.message_interrupted.keys()
461+
and self.message_interrupted[stream_id].is_set()
462+
):
463+
try:
464+
# notify the model provider that streaming was interrupted
465+
# (this is essential to allow the model to stop generating)
466+
await chunk_generator.athrow( # type:ignore[attr-defined]
467+
GenerationInterrupted()
468+
)
469+
except GenerationInterrupted:
470+
# do not let the exception bubble up in case if
471+
# the provider did not handle it
472+
pass
473+
stream_interrupted = True
474+
break
475+
456476
if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str):
457-
reply_stream.write(chunk.content)
477+
stream_id = reply_stream.write(chunk.content)
458478
elif isinstance(chunk, str):
459-
reply_stream.write(chunk)
479+
stream_id = reply_stream.write(chunk)
460480
else:
461481
self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}")
462482
break
463483

484+
if not received_first_chunk:
485+
# when receiving the first chunk, start the stream.
486+
received_first_chunk = True
487+
self.message_interrupted[stream_id] = asyncio.Event()
488+
464489
# if stream was interrupted, add a tombstone
465490
if stream_interrupted:
466491
stream_tombstone = "\n\n(AI response stopped by user)"
467492
reply_stream.write(stream_tombstone)
493+
if stream_id and stream_id in self.message_interrupted.keys():
494+
del self.message_interrupted[stream_id]
468495

469496

470497
class GenerationInterrupted(asyncio.CancelledError):

packages/jupyter-ai/jupyter_ai/extension.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
AutocompleteOptionsHandler,
2929
EmbeddingsModelProviderHandler,
3030
GlobalConfigHandler,
31+
InterruptStreamingHandler,
3132
ModelProviderHandler,
3233
SlashCommandsInfoHandler,
3334
)
@@ -77,6 +78,7 @@ class AiExtension(ExtensionApp):
7778
(r"api/ai/config/?", GlobalConfigHandler),
7879
(r"api/ai/chats/slash_commands?", SlashCommandsInfoHandler),
7980
(r"api/ai/chats/autocomplete_options?", AutocompleteOptionsHandler),
81+
(r"api/ai/chats/stop_streaming?", InterruptStreamingHandler),
8082
(r"api/ai/providers?", ModelProviderHandler),
8183
(r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler),
8284
(r"api/ai/completion/inline/?", DefaultInlineCompletionHandler),
@@ -625,17 +627,23 @@ def _init_persona_manager(self, ychat: YChat) -> Optional[PersonaManager]:
625627
This method should not raise an exception. Upon encountering an
626628
exception, this method will catch it, log it, and return `None`.
627629
"""
628-
persona_manager: Optional[PersonaManager]
630+
persona_manager: Optional[PersonaManager] = None
629631

630632
try:
631633
config_manager = self.settings.get("jai_config_manager", None)
632634
assert config_manager and isinstance(config_manager, ConfigManager)
633635

636+
message_interrupted = self.settings.get("jai_message_interrupted", None)
637+
assert message_interrupted is not None and isinstance(
638+
message_interrupted, dict
639+
)
640+
634641
persona_manager = PersonaManager(
635642
ychat=ychat,
636643
config_manager=config_manager,
637644
event_loop=self.event_loop,
638645
log=self.log,
646+
message_interrupted=message_interrupted,
639647
)
640648
except Exception as e:
641649
# TODO: how to stop the extension when this fails

packages/jupyter-ai/jupyter_ai/handlers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,17 @@ def delete(self, api_key_name: str):
204204
raise HTTPError(500, str(e))
205205

206206

207+
class InterruptStreamingHandler(BaseAPIHandler):
208+
"""Interrupt a current message streaming"""
209+
210+
@web.authenticated
211+
def post(self):
212+
message_id = self.get_json_body().get("message_id")
213+
message_interrupted = self.settings.get("jai_message_interrupted")
214+
if message_id and message_id in message_interrupted.keys():
215+
message_interrupted[message_id].set()
216+
217+
207218
class SlashCommandsInfoHandler(BaseAPIHandler):
208219
"""List slash commands that are currently available to the user."""
209220

packages/jupyter-ai/jupyter_ai/personas/base_persona.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from abc import ABC, abstractmethod
23
from dataclasses import asdict
34
from logging import Logger
@@ -82,6 +83,10 @@ class BasePersona(ABC):
8283
Automatically set by `BasePersona`.
8384
"""
8485

86+
message_interrupted: dict[str, asyncio.Event]
87+
"""Dictionary mapping an agent message identifier to an asyncio Event
88+
which indicates if the message generation/streaming was interrupted."""
89+
8590
################################################
8691
# constructor
8792
################################################
@@ -92,11 +97,13 @@ def __init__(
9297
manager: "PersonaManager",
9398
config: ConfigManager,
9499
log: Logger,
100+
message_interrupted: dict[str, asyncio.Event],
95101
):
96102
self.ychat = ychat
97103
self.manager = manager
98104
self.config = config
99105
self.log = log
106+
self.message_interrupted = message_interrupted
100107
self.awareness = PersonaAwareness(
101108
ychat=self.ychat, log=self.log, user=self.as_user()
102109
)
@@ -221,14 +228,34 @@ async def stream_message(self, reply_stream: "AsyncIterator") -> None:
221228
- Automatically manages its awareness state to show writing status.
222229
"""
223230
stream_id: Optional[str] = None
224-
231+
stream_interrupted = False
225232
try:
226233
self.awareness.set_local_state_field("isWriting", True)
227234
async for chunk in reply_stream:
235+
if (
236+
stream_id
237+
and stream_id in self.message_interrupted.keys()
238+
and self.message_interrupted[stream_id].is_set()
239+
):
240+
try:
241+
# notify the model provider that streaming was interrupted
242+
# (this is essential to allow the model to stop generating)
243+
await reply_stream.athrow( # type:ignore[attr-defined]
244+
GenerationInterrupted()
245+
)
246+
except GenerationInterrupted:
247+
# do not let the exception bubble up in case if
248+
# the provider did not handle it
249+
pass
250+
stream_interrupted = True
251+
break
252+
228253
if not stream_id:
229254
stream_id = self.ychat.add_message(
230255
NewMessage(body="", sender=self.id)
231256
)
257+
self.message_interrupted[stream_id] = asyncio.Event()
258+
self.awareness.set_local_state_field("isWriting", stream_id)
232259

233260
assert stream_id
234261
self.ychat.update_message(
@@ -248,9 +275,29 @@ async def stream_message(self, reply_stream: "AsyncIterator") -> None:
248275
self.log.exception(e)
249276
finally:
250277
self.awareness.set_local_state_field("isWriting", False)
278+
if stream_id:
279+
# if stream was interrupted, add a tombstone
280+
if stream_interrupted:
281+
stream_tombstone = "\n\n(AI response stopped by user)"
282+
self.ychat.update_message(
283+
Message(
284+
id=stream_id,
285+
body=stream_tombstone,
286+
time=time(),
287+
sender=self.id,
288+
raw_time=False,
289+
),
290+
append=True,
291+
)
292+
if stream_id in self.message_interrupted.keys():
293+
del self.message_interrupted[stream_id]
251294

252295
def send_message(self, body: str) -> None:
253296
"""
254297
Sends a new message to the chat from this persona.
255298
"""
256299
self.ychat.add_message(NewMessage(body=body, sender=self.id))
300+
301+
302+
class GenerationInterrupted(asyncio.CancelledError):
303+
"""Exception raised when streaming is cancelled by the user"""

packages/jupyter-ai/jupyter_ai/personas/persona_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from logging import Logger
23
from time import time_ns
34
from typing import TYPE_CHECKING, ClassVar, Optional
@@ -37,11 +38,13 @@ def __init__(
3738
config_manager: ConfigManager,
3839
event_loop: "AbstractEventLoop",
3940
log: Logger,
41+
message_interrupted: dict[str, asyncio.Event],
4042
):
4143
self.ychat = ychat
4244
self.config_manager = config_manager
4345
self.event_loop = event_loop
4446
self.log = log
47+
self.message_interrupted = message_interrupted
4548

4649
if not isinstance(PersonaManager._persona_classes, list):
4750
self._init_persona_classes()
@@ -125,6 +128,7 @@ def _init_personas(self) -> dict[str, BasePersona]:
125128
manager=self,
126129
config=self.config_manager,
127130
log=self.log,
131+
message_interrupted=self.message_interrupted,
128132
)
129133
except Exception:
130134
self.log.exception(

packages/jupyter-ai/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
"@emotion/react": "^11.10.5",
6363
"@emotion/styled": "^11.10.5",
6464
"@jupyter-notebook/application": "^7.2.0",
65-
"@jupyter/chat": "^0.11.0",
65+
"@jupyter/chat": "^0.12.0",
6666
"@jupyterlab/application": "^4.2.0",
6767
"@jupyterlab/apputils": "^4.2.0",
6868
"@jupyterlab/codeeditor": "^4.2.0",

packages/jupyter-ai/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ dependencies = [
3737
# traitlets>=5.6 is required in JL4
3838
"traitlets>=5.6",
3939
"deepmerge>=2.0,<3",
40-
"jupyterlab-chat>=0.11.0,<0.12.0",
40+
"jupyterlab-chat>=0.12.0,<0.13.0",
4141
]
4242

4343
dynamic = ["version", "description", "authors", "urls", "keywords"]
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import {
2+
IChatModel,
3+
MessageFooterSectionProps,
4+
TooltippedButton
5+
} from '@jupyter/chat';
6+
import StopIcon from '@mui/icons-material/Stop';
7+
import React, { useEffect, useState } from 'react';
8+
import { requestAPI } from '../../handler';
9+
10+
/**
11+
* The stop button.
12+
*/
13+
export function StopButton(props: MessageFooterSectionProps): JSX.Element {
14+
const { message, model } = props;
15+
const [visible, setVisible] = useState<boolean>(false);
16+
const tooltip = 'Stop streaming';
17+
18+
useEffect(() => {
19+
const writerChanged = (_: IChatModel, writers: IChatModel.IWriter[]) => {
20+
const w = writers.filter(w => w.messageID === message.id);
21+
if (w.length > 0) {
22+
setVisible(true);
23+
} else {
24+
setVisible(false);
25+
}
26+
};
27+
28+
// Listen only the messages that are from a bot.
29+
if (
30+
message.sender.username !== model.user?.username &&
31+
message.sender.bot
32+
) {
33+
model.writersChanged?.connect(writerChanged);
34+
35+
// Check if the message is currently being edited.
36+
writerChanged(model, model.writers);
37+
}
38+
39+
return () => {
40+
model.writersChanged?.disconnect(writerChanged);
41+
};
42+
}, [model]);
43+
44+
const onClick = () => {
45+
// Post request to the stop streaming handler.
46+
requestAPI('chats/stop_streaming', {
47+
method: 'POST',
48+
body: JSON.stringify({
49+
message_id: message.id
50+
}),
51+
headers: {
52+
'Content-Type': 'application/json'
53+
}
54+
});
55+
};
56+
57+
return visible ? (
58+
<TooltippedButton
59+
onClick={onClick}
60+
tooltip={tooltip}
61+
buttonProps={{
62+
size: 'small',
63+
variant: 'contained',
64+
title: tooltip
65+
}}
66+
sx={{ display: visible ? 'inline-flex' : 'none' }}
67+
>
68+
<StopIcon />
69+
</TooltippedButton>
70+
) : (
71+
<></>
72+
);
73+
}

packages/jupyter-ai/src/index.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { INotebookShell } from '@jupyter-notebook/application';
2+
import { IMessageFooterRegistry } from '@jupyter/chat';
23
import {
34
JupyterFrontEnd,
45
JupyterFrontEndPlugin
@@ -16,6 +17,7 @@ import { SingletonLayout, Widget } from '@lumino/widgets';
1617

1718
import { chatCommandPlugins } from './chat-commands';
1819
import { completionPlugin } from './completions';
20+
import { StopButton } from './components/message-footer/stop-button';
1921
import { statusItemPlugin } from './status';
2022
import { IJaiCompletionProvider } from './tokens';
2123
import { buildErrorWidget } from './widgets/chat-error';
@@ -104,10 +106,23 @@ const plugin: JupyterFrontEndPlugin<void> = {
104106
}
105107
};
106108

109+
const stopStreaming: JupyterFrontEndPlugin<void> = {
110+
id: '@jupyter-ai/core:stop-streaming',
111+
autoStart: true,
112+
requires: [IMessageFooterRegistry],
113+
activate: (app: JupyterFrontEnd, registry: IMessageFooterRegistry) => {
114+
registry.addSection({
115+
component: StopButton,
116+
position: 'center'
117+
});
118+
}
119+
};
120+
107121
export default [
108122
plugin,
109123
statusItemPlugin,
110124
completionPlugin,
125+
stopStreaming,
111126
...chatCommandPlugins
112127
];
113128

0 commit comments

Comments
 (0)