2
2
import re
3
3
import time
4
4
import types
5
+ from asyncio import get_event_loop_policy
5
6
from functools import partial
6
- from typing import Dict
7
+ from typing import TYPE_CHECKING , Dict , Optional
7
8
8
9
import traitlets
9
10
from dask .distributed import Client as DaskClient
12
13
from jupyter_ai_magics .utils import get_em_providers , get_lm_providers
13
14
from jupyter_events import EventLogger
14
15
from jupyter_server .extension .application import ExtensionApp
15
- from jupyter_server .utils import url_path_join
16
16
from jupyterlab_chat .models import Message
17
17
from jupyterlab_chat .ychat import YChat
18
18
from pycrdt import ArrayEvent
22
22
from .chat_handlers .base import BaseChatHandler
23
23
from .completions .handlers import DefaultInlineCompletionHandler
24
24
from .config_manager import ConfigManager
25
- from .constants import BOT
26
25
from .context_providers import BaseCommandContextProvider , FileContextProvider
27
26
from .handlers import (
28
27
ApiKeysHandler ,
33
32
SlashCommandsInfoHandler ,
34
33
)
35
34
from .history import YChatHistory
35
+ from .personas import PersonaManager
36
+
37
+ if TYPE_CHECKING :
38
+ from asyncio import AbstractEventLoop
36
39
37
40
from jupyter_collaboration import ( # type:ignore[import-untyped] # isort:skip
38
41
__version__ as jupyter_collaboration_version ,
@@ -244,6 +247,13 @@ def initialize(self):
244
247
schema_id = JUPYTER_COLLABORATION_EVENTS_URI , listener = self .connect_chat
245
248
)
246
249
250
+ @property
251
+ def event_loop (self ) -> "AbstractEventLoop" :
252
+ """
253
+ Returns a reference to the asyncio event loop.
254
+ """
255
+ return get_event_loop_policy ().get_event_loop ()
256
+
247
257
async def connect_chat (
248
258
self , logger : EventLogger , schema_id : str , data : dict
249
259
) -> None :
@@ -264,17 +274,19 @@ async def connect_chat(
264
274
if ychat is None :
265
275
return
266
276
267
- # Add the bot user to the chat document awareness.
268
- BOT ["avatar_url" ] = url_path_join (
269
- self .settings .get ("base_url" , "/" ), "api/ai/static/jupyternaut.svg"
270
- )
271
- if ychat .awareness is not None :
272
- ychat .awareness .set_local_state_field ("user" , BOT )
273
-
274
277
# initialize chat handlers for new chat
275
278
self .chat_handlers_by_room [room_id ] = self ._init_chat_handlers (ychat )
276
279
277
- callback = partial (self .on_change , room_id )
280
+ # initialize persona manager
281
+ persona_manager = self ._init_persona_manager (ychat )
282
+ if not persona_manager :
283
+ self .log .error (
284
+ "Jupyter AI was unable to initialize its AI personas. They are not available for use in chat until this error is resolved. "
285
+ + "Please verify your configuration and open a new issue on GitHub if this error persists."
286
+ )
287
+ return
288
+
289
+ callback = partial (self .on_change , room_id , persona_manager )
278
290
ychat .ymessages .observe (callback )
279
291
280
292
async def get_chat (self , room_id : str ) -> YChat :
@@ -301,21 +313,26 @@ async def get_chat(self, room_id: str) -> YChat:
301
313
self .ychats_by_room [room_id ] = document
302
314
return document
303
315
304
- def on_change (self , room_id : str , events : ArrayEvent ) -> None :
316
+ def on_change (
317
+ self , room_id : str , persona_manager : PersonaManager , events : ArrayEvent
318
+ ) -> None :
305
319
assert self .serverapp
306
320
307
321
for change in events .delta : # type:ignore[attr-defined]
308
322
if not "insert" in change .keys ():
309
323
continue
310
- messages = change ["insert" ]
311
- for message_dict in messages :
312
- message = Message (** message_dict )
313
- if message .sender == BOT ["username" ] or message .raw_time :
314
- continue
315
324
316
- self .serverapp .io_loop .asyncio_loop .create_task ( # type:ignore[attr-defined]
317
- self .route_human_message (room_id , message )
318
- )
325
+ # the "if not m['raw_time']" clause is necessary because every new
326
+ # message triggers 2 events, one with `raw_time` set to `True` and
327
+ # another with `raw_time` set to `False` milliseconds later.
328
+ # we should explore fixing this quirk in Jupyter Chat.
329
+ #
330
+ # Ref: https://github.com/jupyterlab/jupyter-chat/issues/212
331
+ new_messages = [
332
+ Message (** m ) for m in change ["insert" ] if not m .get ("raw_time" , False )
333
+ ]
334
+ for new_message in new_messages :
335
+ persona_manager .route_message (new_message )
319
336
320
337
async def route_human_message (self , room_id : str , message : Message ):
321
338
"""
@@ -400,18 +417,15 @@ def initialize_settings(self):
400
417
401
418
self .log .info (f"Registered { self .name } server extension" )
402
419
403
- # get reference to event loop
404
- # `asyncio.get_event_loop()` is deprecated in Python 3.11+, in favor of
405
- # the more readable `asyncio.get_event_loop_policy().get_event_loop()`.
406
- # it's easier to just reference the loop directly.
407
- loop = self .serverapp .io_loop .asyncio_loop
408
- self .settings ["jai_event_loop" ] = loop
420
+ self .settings ["jai_event_loop" ] = self .event_loop
409
421
410
422
# We cannot instantiate the Dask client directly here because it
411
423
# requires the event loop to be running on init. So instead we schedule
412
424
# this as a task that is run as soon as the loop starts, and pass
413
425
# consumers a Future that resolves to the Dask client when awaited.
414
- self .settings ["dask_client_future" ] = loop .create_task (self ._get_dask_client ())
426
+ self .settings ["dask_client_future" ] = self .event_loop .create_task (
427
+ self ._get_dask_client ()
428
+ )
415
429
416
430
# Create empty context providers dict to be filled later.
417
431
# This is created early to use as kwargs for chat handlers.
@@ -456,10 +470,7 @@ async def _stop_extension(self):
456
470
457
471
def _init_chat_handlers (self , ychat : YChat ) -> Dict [str , BaseChatHandler ]:
458
472
"""
459
- Initializes a set of chat handlers. May accept a YChat instance for
460
- collaborative chats.
461
-
462
- TODO: Make `ychat` required once Jupyter Chat migration is complete.
473
+ Initializes a set of chat handlers for a given `YChat` instance.
463
474
"""
464
475
assert self .serverapp
465
476
@@ -606,3 +617,32 @@ def _init_context_providers(self):
606
617
** context_providers_kwargs
607
618
)
608
619
self .log .info (f"Registered context provider `{ context_provider .id } `." )
620
+
621
+ def _init_persona_manager (self , ychat : YChat ) -> Optional [PersonaManager ]:
622
+ """
623
+ Initializes a `PersonaManager` instance scoped to a `YChat`.
624
+
625
+ This method should not raise an exception. Upon encountering an
626
+ exception, this method will catch it, log it, and return `None`.
627
+ """
628
+ persona_manager : Optional [PersonaManager ]
629
+
630
+ try :
631
+ config_manager = self .settings .get ("jai_config_manager" , None )
632
+ assert config_manager and isinstance (config_manager , ConfigManager )
633
+
634
+ persona_manager = PersonaManager (
635
+ ychat = ychat ,
636
+ config_manager = config_manager ,
637
+ event_loop = self .event_loop ,
638
+ log = self .log ,
639
+ )
640
+ except Exception as e :
641
+ # TODO: how to stop the extension when this fails
642
+ # also why do uncaught exceptions produce an empty error log in Jupyter Server?
643
+ self .log .error (
644
+ f"Unable to initialize PersonaManager in YChat with ID '{ ychat .get_id ()} ' due to an exception printed below."
645
+ )
646
+ self .log .exception (e )
647
+ finally :
648
+ return persona_manager
0 commit comments