diff --git a/README.md b/README.md index f70fe1f..4a298db 100644 --- a/README.md +++ b/README.md @@ -2,28 +2,46 @@ [![Github Actions Status](https://github.com/jupyter-ai-contrib/jupyter-ai-router/workflows/Build/badge.svg)](https://github.com/jupyter-ai-contrib/jupyter-ai-router/actions/workflows/build.yml) -Core routing layer of Jupyter AI +Core message routing layer for Jupyter AI -This extension is composed of a Python package named `jupyter_ai_router` -for the server extension and a NPM package named `@jupyter-ai/router` -for the frontend extension. +This extension provides the foundational message routing functionality for Jupyter AI. It automatically detects new chat sessions and routes messages to registered callbacks based on message type (slash commands vs regular messages). Extensions can register callbacks to handle specific chat events without needing to manage chat lifecycle directly. -## QUICK START +## Usage -Everything that follows after this section was from the extension template. We -will need to revise the rest of this README. +### Basic MessageRouter Setup -Development install: +```python +# The router is available in other extensions via settings +router = self.serverapp.web_app.settings.get("jupyter-ai", {}).get("router") +# Register callbacks for different event types +def on_new_chat(room_id: str, ychat: YChat): + print(f"New chat connected: {room_id}") + +def on_slash_command(room_id: str, message: Message): + print(f"Slash command in {room_id}: {message.body}") + +def on_regular_message(room_id: str, message: Message):` + print(f"Regular message in {room_id}: {message.body}") + +# Register the callbacks +router.observe_chat_init(on_new_chat) +router.observe_slash_cmd_msg("room-id", on_slash_command) +router.observe_chat_msg("room-id", on_regular_message) ``` -micromamba install uv jupyterlab nodejs=22 -jlpm -jlpm dev:install -``` -## Requirements +### Message Flow + +1. **Router detects new chats** - Automatically listens for chat room initialization events +2. **Router connects chats** - Establishes observers on YChat message streams +3. **Router routes messages** - Calls appropriate callbacks based on message type (slash vs regular) +4. **Extensions respond** - Your callbacks receive room_id and message data + +### Available Methods -- JupyterLab >= 4.0.0 +- `observe_chat_init(callback)` - Called when new chat sessions are initialized with `(room_id, ychat)` +- `observe_slash_cmd_msg(room_id, callback)` - Called for messages starting with `/` in a specific room +- `observe_chat_msg(room_id, callback)` - Called for regular (non-slash) messages in a specific room ## Install diff --git a/jupyter_ai_router/__init__.py b/jupyter_ai_router/__init__.py index f24b583..a846e2d 100644 --- a/jupyter_ai_router/__init__.py +++ b/jupyter_ai_router/__init__.py @@ -7,7 +7,8 @@ import warnings warnings.warn("Importing 'jupyter_ai_router' outside a proper installation.") __version__ = "dev" -from .handlers import setup_handlers + +from .extension import RouterExtension def _jupyter_labextension_paths(): @@ -18,19 +19,4 @@ def _jupyter_labextension_paths(): def _jupyter_server_extension_points(): - return [{ - "module": "jupyter_ai_router" - }] - - -def _load_jupyter_server_extension(server_app): - """Registers the API handler to receive HTTP requests from the frontend extension. - - Parameters - ---------- - server_app: jupyterlab.labapp.LabApp - JupyterLab application instance - """ - setup_handlers(server_app.web_app) - name = "jupyter_ai_router" - server_app.log.info(f"Registered {name} server extension") + return [{"module": "jupyter_ai_router", "app": RouterExtension}] diff --git a/jupyter_ai_router/extension.py b/jupyter_ai_router/extension.py new file mode 100644 index 0000000..5fb9b40 --- /dev/null +++ b/jupyter_ai_router/extension.py @@ -0,0 +1,108 @@ +from __future__ import annotations +import time +from jupyter_events import EventLogger +from jupyter_server.extension.application import ExtensionApp + +from jupyter_ai_router.handlers import RouteHandler + +from .router import MessageRouter + +# Check jupyter-collaboration version for compatibility +try: + from jupyter_collaboration import __version__ as jupyter_collaboration_version + + JCOLLAB_VERSION = int(jupyter_collaboration_version[0]) + if JCOLLAB_VERSION >= 3: + from jupyter_server_ydoc.utils import JUPYTER_COLLABORATION_EVENTS_URI + else: + from jupyter_collaboration.utils import JUPYTER_COLLABORATION_EVENTS_URI +except ImportError: + # Fallback if jupyter-collaboration is not available + JUPYTER_COLLABORATION_EVENTS_URI = ( + "https://events.jupyter.org/jupyter_collaboration" + ) + + +class RouterExtension(ExtensionApp): + """ + Jupyter AI Router Extension + """ + + name = "jupyter_ai_router" + handlers = [ + (r"jupyter-ai-router/health/?", RouteHandler), + ] + + def initialize_settings(self): + """Initialize router settings and event listeners.""" + start = time.time() + + # Create MessageRouter instance + self.router = MessageRouter(parent=self) + + # Make router available to other extensions + if "jupyter-ai" not in self.settings: + self.settings["jupyter-ai"] = {} + self.settings["jupyter-ai"]["router"] = self.router + + # Listen for new chat room events + if self.serverapp is not None: + self.event_logger = self.serverapp.web_app.settings["event_logger"] + self.event_logger.add_listener( + schema_id=JUPYTER_COLLABORATION_EVENTS_URI, listener=self._on_chat_event + ) + + elapsed = time.time() - start + self.log.info(f"Initialized RouterExtension in {elapsed:.2f}s") + + async def _on_chat_event( + self, logger: EventLogger, schema_id: str, data: dict + ) -> None: + """Handle chat room events and connect new chats to router.""" + # Only handle chat room initialization events + if not ( + data["room"].startswith("text:chat:") + and data["action"] == "initialize" + and data["msg"] == "Room initialized" + ): + return + + room_id = data["room"] + self.log.info(f"New chat room detected: {room_id}") + + # Get YChat document for the room + ychat = await self._get_chat(room_id) + if ychat is None: + self.log.error(f"Failed to get YChat for room {room_id}") + return + + # Connect chat to router + self.router.connect_chat(room_id, ychat) + + async def _get_chat(self, room_id: str): + """Get YChat instance for a room ID.""" + if not self.serverapp: + return None + + try: + if JCOLLAB_VERSION >= 3: + collaboration = self.serverapp.web_app.settings["jupyter_server_ydoc"] + document = await collaboration.get_document(room_id=room_id, copy=False) + else: + collaboration = self.serverapp.web_app.settings["jupyter_collaboration"] + server = collaboration.ywebsocket_server + room = await server.get_room(room_id) + document = room._document + + return document + except Exception as e: + self.log.error(f"Error getting chat document for {room_id}: {e}") + return None + + async def stop_extension(self): + """Clean up router when extension stops.""" + try: + if hasattr(self, "router"): + self.router.cleanup() + except Exception as e: + self.log.error(f"Error during router cleanup: {e}") diff --git a/jupyter_ai_router/handlers.py b/jupyter_ai_router/handlers.py index 7b59d65..55a3576 100644 --- a/jupyter_ai_router/handlers.py +++ b/jupyter_ai_router/handlers.py @@ -1,7 +1,6 @@ import json from jupyter_server.base.handlers import APIHandler -from jupyter_server.utils import url_path_join import tornado class RouteHandler(APIHandler): @@ -11,14 +10,5 @@ class RouteHandler(APIHandler): @tornado.web.authenticated def get(self): self.finish(json.dumps({ - "data": "This is /jupyter-ai-router/get-example endpoint!" + "data": "JupyterLab extension @jupyter-ai/router is activated!" })) - - -def setup_handlers(web_app): - host_pattern = ".*$" - - base_url = web_app.settings["base_url"] - route_pattern = url_path_join(base_url, "jupyter-ai-router", "get-example") - handlers = [(route_pattern, RouteHandler)] - web_app.add_handlers(host_pattern, handlers) diff --git a/jupyter_ai_router/router.py b/jupyter_ai_router/router.py new file mode 100644 index 0000000..a4146fa --- /dev/null +++ b/jupyter_ai_router/router.py @@ -0,0 +1,204 @@ +""" +MessageRouter that manages message routing with callbacks. + +This module provides a MessageRouter that: +- Handles new chat connections +- Routes slash commands and regular messages via callbacks +- Manages lifecycle and cleanup +""" + +from typing import Any, Callable, Dict, List, TYPE_CHECKING +from functools import partial +from jupyterlab_chat.models import Message +from pycrdt import ArrayEvent +from traitlets.config import LoggingConfigurable + +if TYPE_CHECKING: + from jupyterlab_chat.ychat import YChat + +from .utils import get_first_word + + +class MessageRouter(LoggingConfigurable): + """ + Router that manages ychat message routing. + + The Router provides three callback points: + 1. When new chats are initialized + 2. When slash commands are received + 3. When regular (non-slash) messages are received + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Callback lists + self.chat_init_observers: List[Callable[[str, "YChat"], Any]] = [] + self.slash_cmd_observers: Dict[str, List[Callable[[str, Message], Any]]] = {} + self.chat_msg_observers: Dict[str, List[Callable[[str, Message], Any]]] = {} + + # Active chat rooms + self.active_chats: Dict[str, "YChat"] = {} + + # Root observers for keeping track of incoming messages + self.message_observers: Dict[str, Callable] = {} + + def observe_chat_init(self, callback: Callable[[str, "YChat"], Any]) -> None: + """ + Register a callback for when new chats are initialized. + + Args: + callback: Function called with (room_id: str, ychat: YChat) when chat connects + """ + self.chat_init_observers.append(callback) + self.log.info("Registered new chat initialization callback") + + def observe_slash_cmd_msg( + self, room_id: str, callback: Callable[[str, Message], Any] + ) -> None: + """ + Register a callback for when slash commands are received. + + Args: + callback: Function called with (room_id: str, message: Message) for slash commands + """ + if room_id not in self.slash_cmd_observers: + self.slash_cmd_observers[room_id] = [] + + self.slash_cmd_observers[room_id].append(callback) + self.log.info("Registered slash command callback") + + def observe_chat_msg( + self, room_id: str, callback: Callable[[str, Message], Any] + ) -> None: + """ + Register a callback for when regular (non-slash) messages are received. + + Args: + callback: Function called with (room_id: str, message: Message) for regular messages + """ + if room_id not in self.chat_msg_observers: + self.chat_msg_observers[room_id] = [] + + self.chat_msg_observers[room_id].append(callback) + self.log.info("Registered message callback") + + def connect_chat(self, room_id: str, ychat: "YChat") -> None: + """ + Connect a new chat session to the router. + + Args: + room_id: Unique identifier for the chat room + ychat: YChat instance for the room + """ + if room_id in self.active_chats: + self.log.warning(f"Chat {room_id} already connected to router") + return + + self.active_chats[room_id] = ychat + + # Set up message observer + callback = partial(self._on_message_change, room_id, ychat) + ychat.ymessages.observe(callback) + self.message_observers[room_id] = callback + + self.log.info(f"Connected chat {room_id} to router") + + # Notify new chat observers + self._notify_chat_init_observers(room_id, ychat) + + def disconnect_chat(self, room_id: str) -> None: + """ + Disconnect a chat session from the router. + + Args: + room_id: Unique identifier for the chat room + """ + if room_id not in self.active_chats: + return + + # Remove message observer + if room_id in self.message_observers: + ychat = self.active_chats[room_id] + try: + ychat.ymessages.unobserve(self.message_observers[room_id]) + except Exception as e: + self.log.warning(f"Failed to unobserve chat {room_id}: {e}") + del self.message_observers[room_id] + + del self.active_chats[room_id] + self.log.info(f"Disconnected chat {room_id} from router") + + def _on_message_change( + self, room_id: str, ychat: "YChat", events: ArrayEvent + ) -> None: + """Handle incoming messages from YChat.""" + for change in events.delta: # type: ignore[attr-defined] + if "insert" not in change.keys(): + continue + + # Process new messages (filter out raw_time duplicates) + new_messages = [ + Message(**m) for m in change["insert"] if not m.get("raw_time", False) + ] + + for message in new_messages: + self._route_message(room_id, message) + + def _route_message(self, room_id: str, message: Message) -> None: + """ + Route an incoming message to appropriate observers. + + Args: + room_id: The chat room ID + message: The message to route + """ + first_word = get_first_word(message.body) + + # Check if it's a slash command + if first_word and first_word.startswith("/"): + self._notify_slash_cmd_observers(room_id, message) + else: + self._notify_msg_observers(room_id, message) + + def _notify_chat_init_observers(self, room_id: str, ychat: "YChat") -> None: + """Notify all new chat observers.""" + for callback in self.chat_init_observers: + try: + callback(room_id, ychat) + except Exception as e: + self.log.error(f"New chat observer error for {room_id}: {e}") + + def _notify_slash_cmd_observers(self, room_id: str, message: Message) -> None: + """Notify all slash command observers.""" + callbacks = self.slash_cmd_observers.get(room_id, []) + for callback in callbacks: + try: + callback(room_id, message) + except Exception as e: + self.log.error(f"Slash command observer error for {room_id}: {e}") + + def _notify_msg_observers(self, room_id: str, message: Message) -> None: + """Notify all message observers.""" + callbacks = self.chat_msg_observers.get(room_id, []) + for callback in callbacks: + try: + callback(room_id, message) + except Exception as e: + self.log.error(f"Message observer error for {room_id}: {e}") + + def cleanup(self) -> None: + """Clean up router resources.""" + self.log.info("Cleaning up MessageRouter...") + + # Disconnect all chats + room_ids = list(self.active_chats.keys()) + for room_id in room_ids: + self.disconnect_chat(room_id) + + # Clear callbacks + self.chat_init_observers.clear() + self.slash_cmd_observers.clear() + self.chat_msg_observers.clear() + + self.log.info("MessageRouter cleanup complete") diff --git a/jupyter_ai_router/tests/test_handlers.py b/jupyter_ai_router/tests/test_handlers.py index e4f8317..ff9a57b 100644 --- a/jupyter_ai_router/tests/test_handlers.py +++ b/jupyter_ai_router/tests/test_handlers.py @@ -1,13 +1,19 @@ -import json +""" +Tests for extension handlers (currently none). + +This extension provides server-side functionality only +and does not expose HTTP endpoints. +""" +import json -async def test_get_example(jp_fetch): +async def test_get_health(jp_fetch): # When - response = await jp_fetch("jupyter-ai-router", "get-example") + response = await jp_fetch("jupyter-ai-router", "health") # Then assert response.code == 200 payload = json.loads(response.body) assert payload == { - "data": "This is /jupyter-ai-router/get-example endpoint!" + "data": "JupyterLab extension @jupyter-ai/router is activated!" } \ No newline at end of file diff --git a/jupyter_ai_router/tests/test_message_router.py b/jupyter_ai_router/tests/test_message_router.py new file mode 100644 index 0000000..6e16b31 --- /dev/null +++ b/jupyter_ai_router/tests/test_message_router.py @@ -0,0 +1,122 @@ +""" +Tests for MessageRouter functionality. +""" + +import pytest +from unittest.mock import Mock, MagicMock +from jupyterlab_chat.models import Message +from jupyterlab_chat.ychat import YChat +from jupyter_ai_router.router import MessageRouter +from jupyter_ai_router.utils import get_first_word, is_persona + + +class TestUtils: + """Test utility functions.""" + + def test_get_first_word_normal(self): + """Test getting first word from normal string.""" + assert get_first_word("hello world") == "hello" + assert get_first_word(" hello world ") == "hello" + assert get_first_word("/refresh-personas") == "/refresh-personas" + + def test_get_first_word_edge_cases(self): + """Test edge cases for get_first_word.""" + assert get_first_word("") is None + assert get_first_word(" ") is None + assert get_first_word("single") == "single" + + def test_is_persona(self): + """Test persona username detection.""" + assert is_persona("jupyter-ai-personas::jupyter_ai::JupyternautPersona") is True + assert is_persona("human_user") is False + assert is_persona("jupyter-ai-personas::custom::MyPersona") is True + + +class TestMessageRouter: + """Test MessageRouter functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.router = MessageRouter() + self.mock_chat_init_callback = Mock() + self.mock_slash_cmd_callback = Mock() + self.mock_msg_callback = Mock() + self.mock_ychat = Mock(spec=YChat) + self.mock_ychat.ymessages = Mock() + + def test_router_initialization(self): + """Test router initializes correctly.""" + router = MessageRouter() + assert len(router.chat_init_observers) == 0 + assert len(router.slash_cmd_observers) == 0 + assert len(router.chat_msg_observers) == 0 + assert len(router.active_chats) == 0 + + def test_observe_chat_init(self): + """Test registering chat init callback.""" + self.router.observe_chat_init(self.mock_chat_init_callback) + assert self.mock_chat_init_callback in self.router.chat_init_observers + + def test_observe_slash_cmd_msg(self): + """Test registering slash command callback.""" + room_id = "test-room" + self.router.observe_slash_cmd_msg(room_id, self.mock_slash_cmd_callback) + assert self.mock_slash_cmd_callback in self.router.slash_cmd_observers[room_id] + + def test_observe_chat_msg(self): + """Test registering regular message callback.""" + room_id = "test-room" + self.router.observe_chat_msg(room_id, self.mock_msg_callback) + assert self.mock_msg_callback in self.router.chat_msg_observers[room_id] + + def test_connect_chat(self): + """Test connecting a chat to the router.""" + room_id = "test-room" + self.router.observe_chat_init(self.mock_chat_init_callback) + + self.router.connect_chat(room_id, self.mock_ychat) + + # Should store the chat and call init observers + assert room_id in self.router.active_chats + assert self.router.active_chats[room_id] == self.mock_ychat + self.mock_chat_init_callback.assert_called_once_with(room_id, self.mock_ychat) + + def test_disconnect_chat(self): + """Test disconnecting a chat from the router.""" + room_id = "test-room" + self.router.connect_chat(room_id, self.mock_ychat) + + self.router.disconnect_chat(room_id) + + # Should remove the chat + assert room_id not in self.router.active_chats + + def test_message_routing(self): + """Test message routing to appropriate callbacks.""" + room_id = "test-room" + self.router.observe_slash_cmd_msg(room_id, self.mock_slash_cmd_callback) + self.router.observe_chat_msg(room_id, self.mock_msg_callback) + + # Test slash command routing + slash_msg = Message(id="1", body="/test command", sender="user", time=123) + self.router._route_message(room_id, slash_msg) + self.mock_slash_cmd_callback.assert_called_once_with(room_id, slash_msg) + + # Test regular message routing + regular_msg = Message(id="2", body="Hello world", sender="user", time=124) + self.router._route_message(room_id, regular_msg) + self.mock_msg_callback.assert_called_once_with(room_id, regular_msg) + + def test_cleanup(self): + """Test router cleanup.""" + room_id = "test-room" + self.router.connect_chat(room_id, self.mock_ychat) + self.router.observe_chat_init(self.mock_chat_init_callback) + + self.router.cleanup() + + # Should clear all observers and active chats + assert len(self.router.active_chats) == 0 + assert len(self.router.chat_init_observers) == 0 + assert len(self.router.slash_cmd_observers) == 0 + assert len(self.router.chat_msg_observers) == 0 diff --git a/jupyter_ai_router/utils.py b/jupyter_ai_router/utils.py new file mode 100644 index 0000000..4d86729 --- /dev/null +++ b/jupyter_ai_router/utils.py @@ -0,0 +1,32 @@ +""" +Utility functions for message routing. +""" + + +from typing import Optional + + +def get_first_word(input_str: str) -> Optional[str]: + """ + Finds the first word in a given string, ignoring leading whitespace. + + Returns the first word, or None if there is no first word. + """ + start = 0 + + # Skip leading whitespace + while start < len(input_str) and input_str[start].isspace(): + start += 1 + + # Find end of first word + end = start + while end < len(input_str) and not input_str[end].isspace(): + end += 1 + + first_word = input_str[start:end] + return first_word if first_word else None + + +def is_persona(username: str) -> bool: + """Returns true if username belongs to a persona""" + return username.startswith("jupyter-ai-personas") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d118ed5..1a1f8e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,9 @@ classifiers = [ "Programming Language :: Python :: 3.13", ] dependencies = [ - "jupyter_server>=2.4.0,<3" + "jupyter_server>=2.4.0,<3", + "jupyterlab-chat>=0.17.0", + "jupyter-collaboration>=4.0.0" ] dynamic = ["version", "description", "authors", "urls", "keywords"] diff --git a/src/index.ts b/src/index.ts index 50f9ad3..1c0e31f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -15,7 +15,7 @@ const plugin: JupyterFrontEndPlugin = { activate: (app: JupyterFrontEnd) => { console.log('JupyterLab extension @jupyter-ai/router is activated!'); - requestAPI('get-example') + requestAPI('health') .then(data => { console.log(data); }) diff --git a/ui-tests/tests/jupyter_ai_router.spec.ts b/ui-tests/tests/jupyter_ai_router.spec.ts index 36f0b33..4f78f16 100644 --- a/ui-tests/tests/jupyter_ai_router.spec.ts +++ b/ui-tests/tests/jupyter_ai_router.spec.ts @@ -16,6 +16,8 @@ test('should emit an activation console message', async ({ page }) => { await page.goto(); expect( - logs.filter(s => s === 'JupyterLab extension @jupyter-ai/router is activated!') + logs.filter( + s => s === 'JupyterLab extension @jupyter-ai/router is activated!' + ) ).toHaveLength(1); });