diff --git a/jupyter_ai_router/extension.py b/jupyter_ai_router/extension.py index 5fb9b40..d666f38 100644 --- a/jupyter_ai_router/extension.py +++ b/jupyter_ai_router/extension.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import TYPE_CHECKING import time from jupyter_events import EventLogger from jupyter_server.extension.application import ExtensionApp @@ -7,21 +8,26 @@ from .router import MessageRouter -# Check jupyter-collaboration version for compatibility try: - from jupyter_collaboration import __version__ as jupyter_collaboration_version - - JCOLLAB_VERSION = int(jupyter_collaboration_version[0]) - if JCOLLAB_VERSION >= 3: - from jupyter_server_ydoc.utils import JUPYTER_COLLABORATION_EVENTS_URI - else: - from jupyter_collaboration.utils import JUPYTER_COLLABORATION_EVENTS_URI + from jupyter_server_ydoc.utils import JUPYTER_COLLABORATION_EVENTS_URI except ImportError: # Fallback if jupyter-collaboration is not available JUPYTER_COLLABORATION_EVENTS_URI = ( "https://events.jupyter.org/jupyter_collaboration" ) +# Define `JSD_PRESENT` to indicate whether `jupyter_server_documents` is +# installed in the current environment. +JSD_PRESENT = False +try: + import jupyter_server_documents + JSD_PRESENT = True +except ImportError: + pass + +if TYPE_CHECKING: + from jupyterlab_chat.ychat import YChat + class RouterExtension(ExtensionApp): """ @@ -33,6 +39,8 @@ class RouterExtension(ExtensionApp): (r"jupyter-ai-router/health/?", RouteHandler), ] + router: MessageRouter + def initialize_settings(self): """Initialize router settings and event listeners.""" start = time.time() @@ -79,25 +87,61 @@ async def _on_chat_event( # Connect chat to router self.router.connect_chat(room_id, ychat) - async def _get_chat(self, room_id: str): - """Get YChat instance for a room ID.""" + async def _get_chat(self, room_id: str) -> YChat | None: + """ + Get YChat instance for a room ID. + + Dispatches to either `_get_chat_jcollab()` or `_get_chat_jsd()` based on + whether `jupyter_server_documents` is installed. + """ + + if JSD_PRESENT: + return await self._get_chat_jsd(room_id) + else: + return await self._get_chat_jcollab(room_id) + + async def _get_chat_jcollab(self, room_id: str) -> YChat | None: + """ + Method used to retrieve the `YChat` instance for a given room when + `jupyter_server_documents` **is not** installed. + """ if not self.serverapp: return None try: - if JCOLLAB_VERSION >= 3: - collaboration = self.serverapp.web_app.settings["jupyter_server_ydoc"] - document = await collaboration.get_document(room_id=room_id, copy=False) - else: - collaboration = self.serverapp.web_app.settings["jupyter_collaboration"] - server = collaboration.ywebsocket_server - room = await server.get_room(room_id) - document = room._document - + collaboration = self.serverapp.web_app.settings["jupyter_server_ydoc"] + document = await collaboration.get_document(room_id=room_id, copy=False) return document except Exception as e: self.log.error(f"Error getting chat document for {room_id}: {e}") return None + + async def _get_chat_jsd(self, room_id: str) -> YChat | None: + """ + Method used to retrieve the `YChat` instance for a given room when + `jupyter_server_documents` **is** installed. + + This method uniquely attaches a callback which is fired whenever the + `YChat` is reset. + """ + if not self.serverapp: + return None + + try: + jcollab_api = self.serverapp.web_app.settings["jupyter_server_ydoc"] + yroom_manager = jcollab_api.yroom_manager + yroom = yroom_manager.get_room(room_id) + + def _on_ychat_reset(new_ychat: YChat): + self.router._on_chat_reset(room_id, new_ychat) + + ychat = await yroom.get_jupyter_ydoc(on_reset=_on_ychat_reset) + return ychat + except Exception as e: + self.log.error(f"Error getting chat document for {room_id}: {e}") + return None + + async def stop_extension(self): """Clean up router when extension stops.""" diff --git a/jupyter_ai_router/router.py b/jupyter_ai_router/router.py index b55b2fc..7e1c35a 100644 --- a/jupyter_ai_router/router.py +++ b/jupyter_ai_router/router.py @@ -55,6 +55,7 @@ def __init__(self, *args, **kwargs): self.chat_init_observers: List[Callable[[str, "YChat"], Any]] = [] self.slash_cmd_observers: Dict[str, Dict[str, List[Callable[[str, str, Message], Any]]]] = {} self.chat_msg_observers: Dict[str, List[Callable[[str, Message], Any]]] = {} + self.chat_reset_observers: List[Callable[[str, "YChat"], Any]] = [] # Active chat rooms self.active_chats: Dict[str, "YChat"] = {} @@ -71,7 +72,18 @@ def observe_chat_init(self, callback: Callable[[str, "YChat"], Any]) -> None: """ self.chat_init_observers.append(callback) self.log.info("Registered new chat initialization callback") + + def observe_chat_reset(self, callback: Callable[[str, "YChat"], Any]) -> None: + """ + Register a callback for when a `YChat` document is reset. This will only + occur if `jupyter_server_documents` is installed. + Args: + callback: Function called with (room_id: str, new_ychat: YChat) when chat resets + """ + self.chat_reset_observers.append(callback) + self.log.info("Registered new chat reset callback") + def observe_slash_cmd_msg( self, room_id: str, command_pattern: str, callback: Callable[[str, str, Message], Any] ) -> None: @@ -230,6 +242,21 @@ def _notify_msg_observers(self, room_id: str, message: Message) -> None: callback(room_id, message) except Exception as e: self.log.error(f"Message observer error for {room_id}: {e}") + + def _on_chat_reset(self, room_id, ychat: "YChat") -> None: + """ + Method to call when the YChat undergoes a document reset, e.g. when the + `.chat` file is modified directly on disk. + + NOTE: Document resets will only occur when `jupyter_server_documents` is + installed. + """ + self.log.warning(f"Detected `YChat` document reset in room '{room_id}'.") + for callback in self.chat_reset_observers: + try: + callback(room_id, ychat) + except Exception as e: + self.log.error(f"Reset chat observer error for {room_id}: {e}") def cleanup(self) -> None: """Clean up router resources."""