Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 63 additions & 19 deletions jupyter_ai_router/extension.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import time
from jupyter_events import EventLogger
from jupyter_server.extension.application import ExtensionApp
Expand All @@ -7,21 +8,26 @@

from .router import MessageRouter

# Check jupyter-collaboration version for compatibility
try:
from jupyter_collaboration import __version__ as jupyter_collaboration_version

JCOLLAB_VERSION = int(jupyter_collaboration_version[0])
if JCOLLAB_VERSION >= 3:
from jupyter_server_ydoc.utils import JUPYTER_COLLABORATION_EVENTS_URI
else:
from jupyter_collaboration.utils import JUPYTER_COLLABORATION_EVENTS_URI
from jupyter_server_ydoc.utils import JUPYTER_COLLABORATION_EVENTS_URI
except ImportError:
# Fallback if jupyter-collaboration is not available
JUPYTER_COLLABORATION_EVENTS_URI = (
"https://events.jupyter.org/jupyter_collaboration"
)

# Define `JSD_PRESENT` to indicate whether `jupyter_server_documents` is
# installed in the current environment.
JSD_PRESENT = False
try:
import jupyter_server_documents
JSD_PRESENT = True
except ImportError:
pass

if TYPE_CHECKING:
from jupyterlab_chat.ychat import YChat


class RouterExtension(ExtensionApp):
"""
Expand All @@ -33,6 +39,8 @@ class RouterExtension(ExtensionApp):
(r"jupyter-ai-router/health/?", RouteHandler),
]

router: MessageRouter

def initialize_settings(self):
"""Initialize router settings and event listeners."""
start = time.time()
Expand Down Expand Up @@ -79,25 +87,61 @@ async def _on_chat_event(
# Connect chat to router
self.router.connect_chat(room_id, ychat)

async def _get_chat(self, room_id: str):
"""Get YChat instance for a room ID."""
async def _get_chat(self, room_id: str) -> YChat | None:
"""
Get YChat instance for a room ID.

Dispatches to either `_get_chat_jcollab()` or `_get_chat_jsd()` based on
whether `jupyter_server_documents` is installed.
"""

if JSD_PRESENT:
return await self._get_chat_jsd(room_id)
else:
return await self._get_chat_jcollab(room_id)

async def _get_chat_jcollab(self, room_id: str) -> YChat | None:
"""
Method used to retrieve the `YChat` instance for a given room when
`jupyter_server_documents` **is not** installed.
"""
if not self.serverapp:
return None

try:
if JCOLLAB_VERSION >= 3:
collaboration = self.serverapp.web_app.settings["jupyter_server_ydoc"]
document = await collaboration.get_document(room_id=room_id, copy=False)
else:
collaboration = self.serverapp.web_app.settings["jupyter_collaboration"]
server = collaboration.ywebsocket_server
room = await server.get_room(room_id)
document = room._document

collaboration = self.serverapp.web_app.settings["jupyter_server_ydoc"]
document = await collaboration.get_document(room_id=room_id, copy=False)
return document
except Exception as e:
self.log.error(f"Error getting chat document for {room_id}: {e}")
return None

async def _get_chat_jsd(self, room_id: str) -> YChat | None:
"""
Method used to retrieve the `YChat` instance for a given room when
`jupyter_server_documents` **is** installed.

This method uniquely attaches a callback which is fired whenever the
`YChat` is reset.
"""
if not self.serverapp:
return None

try:
jcollab_api = self.serverapp.web_app.settings["jupyter_server_ydoc"]
yroom_manager = jcollab_api.yroom_manager
yroom = yroom_manager.get_room(room_id)

def _on_ychat_reset(new_ychat: YChat):
self.router._on_chat_reset(room_id, new_ychat)

ychat = await yroom.get_jupyter_ydoc(on_reset=_on_ychat_reset)
return ychat
except Exception as e:
self.log.error(f"Error getting chat document for {room_id}: {e}")
return None



async def stop_extension(self):
"""Clean up router when extension stops."""
Expand Down
27 changes: 27 additions & 0 deletions jupyter_ai_router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self, *args, **kwargs):
self.chat_init_observers: List[Callable[[str, "YChat"], Any]] = []
self.slash_cmd_observers: Dict[str, Dict[str, List[Callable[[str, str, Message], Any]]]] = {}
self.chat_msg_observers: Dict[str, List[Callable[[str, Message], Any]]] = {}
self.chat_reset_observers: List[Callable[[str, "YChat"], Any]] = []

# Active chat rooms
self.active_chats: Dict[str, "YChat"] = {}
Expand All @@ -71,7 +72,18 @@ def observe_chat_init(self, callback: Callable[[str, "YChat"], Any]) -> None:
"""
self.chat_init_observers.append(callback)
self.log.info("Registered new chat initialization callback")

def observe_chat_reset(self, callback: Callable[[str, "YChat"], Any]) -> None:
"""
Register a callback for when a `YChat` document is reset. This will only
occur if `jupyter_server_documents` is installed.

Args:
callback: Function called with (room_id: str, new_ychat: YChat) when chat resets
"""
self.chat_reset_observers.append(callback)
self.log.info("Registered new chat reset callback")

def observe_slash_cmd_msg(
self, room_id: str, command_pattern: str, callback: Callable[[str, str, Message], Any]
) -> None:
Expand Down Expand Up @@ -230,6 +242,21 @@ def _notify_msg_observers(self, room_id: str, message: Message) -> None:
callback(room_id, message)
except Exception as e:
self.log.error(f"Message observer error for {room_id}: {e}")

def _on_chat_reset(self, room_id, ychat: "YChat") -> None:
"""
Method to call when the YChat undergoes a document reset, e.g. when the
`.chat` file is modified directly on disk.

NOTE: Document resets will only occur when `jupyter_server_documents` is
installed.
"""
self.log.warning(f"Detected `YChat` document reset in room '{room_id}'.")
for callback in self.chat_reset_observers:
try:
callback(room_id, ychat)
except Exception as e:
self.log.error(f"Reset chat observer error for {room_id}: {e}")

def cleanup(self) -> None:
"""Clean up router resources."""
Expand Down
Loading