diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 2d49a299f..8a14470c9 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,4 +1,6 @@ from __future__ import annotations + +import asyncio import os import time from asyncio import get_event_loop_policy @@ -30,7 +32,7 @@ from .secrets.secrets_rest_api import SecretsRestAPI if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any from asyncio import AbstractEventLoop from jupyter_collaboration import ( # type:ignore[import-untyped] # isort:skip @@ -216,14 +218,15 @@ def initialize(self, argv: Any = None) -> None: self.ychats_by_room: dict[str, YChat] = {} """Cache of YChat instances, indexed by room ID.""" - if self.serverapp is not None: - self.event_logger = self.serverapp.web_app.settings["event_logger"] - self.event_logger.add_listener( - schema_id=JUPYTER_COLLABORATION_EVENTS_URI, listener=self.connect_chat - ) + # Original direct event listener (commented out for router integration) + # if self.serverapp is not None: + # self.event_logger = self.serverapp.web_app.settings["event_logger"] + # self.event_logger.add_listener( + # schema_id=JUPYTER_COLLABORATION_EVENTS_URI, listener=self.connect_chat + # ) @property - def event_loop(self) -> "AbstractEventLoop": + def event_loop(self) -> AbstractEventLoop: """ Returns a reference to the asyncio event loop. """ @@ -347,6 +350,8 @@ def initialize_settings(self): # When a message's interrupt event is set, the response is halted. self.settings["jai_message_interrupted"] = {} + self.event_loop.create_task(self._setup_router_integration()) + # Log server extension startup time self.log.info(f"Registered {self.name} server extension") startup_time = round((time.time() - start) * 1000) @@ -377,14 +382,14 @@ async def _stop_extension(self): def _init_persona_manager( self, room_id: str, ychat: YChat - ) -> Optional[PersonaManager]: + ) -> PersonaManager | None: """ Initializes a `PersonaManager` instance scoped to a `YChat`. This method should not raise an exception. Upon encountering an exception, this method will catch it, log it, and return `None`. """ - persona_manager: Optional[PersonaManager] = None + persona_manager: PersonaManager | None = None try: config_manager = self.settings.get("jai_config_manager", None) @@ -428,6 +433,56 @@ def _init_persona_manager( finally: return persona_manager + async def _setup_router_integration(self) -> None: + """ + Set up integration with jupyter-ai-router if available. + This allows persona manager to work through the centralized MessageRouter. + """ + self.log.info("Waiting for the router to be ready") + + # Wait until the router field is available + while True: + router = self.serverapp.web_app.settings.get("jupyter-ai", {}).get("router") + if router is not None: + self.log.info("Router is ready, continuing with the integration") + break + await asyncio.sleep(0.1) # Check every 100ms + + try: + self.log.info("Found jupyter-ai-router, registering callbacks") + + # Register callback for new chat initialization + router.observe_chat_init(self._on_router_chat_init) + + # Store reference to router for later use + self.router = router + + except Exception as e: + self.log.error(f"Error setting up router integration: {e}") + + def _on_router_chat_init(self, room_id: str, ychat: YChat) -> None: + """ + Callback for when router detects a new chat initialization. + This replaces the direct event listener approach. + """ + self.log.info(f"Router detected new chat room: {room_id}") + + # Cache the YChat instance + self.ychats_by_room[room_id] = ychat + + # Initialize persona manager for this chat + persona_manager = self._init_persona_manager(room_id, ychat) + if not persona_manager: + self.log.error( + "Jupyter AI was unable to initialize its AI personas. They are not available for use in chat until this error is resolved. " + + "Please verify your configuration and open a new issue on GitHub if this error persists." + ) + return + + # Register persona manager callbacks with router + self.router.observe_chat_msg(room_id, persona_manager.on_chat_message) + self.router.observe_slash_cmd_msg(room_id, persona_manager.on_slash_cmd_message) + def _link_jupyter_server_extension(self, server_app: ServerApp): """Setup custom config needed by this extension.""" c = Config() diff --git a/packages/jupyter-ai/jupyter_ai/personas/persona_manager.py b/packages/jupyter-ai/jupyter_ai/personas/persona_manager.py index 0fb579cda..34580b066 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/persona_manager.py +++ b/packages/jupyter-ai/jupyter_ai/personas/persona_manager.py @@ -359,6 +359,12 @@ def get_mentioned_personas(self, new_message: Message) -> list[BasePersona]: persona_list.append(self.personas[mentioned_id]) return persona_list + def on_chat_message(self, room_id: str, message: Message): + self.route_message(message) + + def on_slash_cmd_message(self, room_id: str, message: Message): + self.route_slash_command(message) + def route_message(self, new_message: Message): """ Method that routes an incoming message to the correct personas by