Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 64 additions & 9 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from __future__ import annotations

import asyncio
import os
import time
from asyncio import get_event_loop_policy
Expand Down Expand Up @@ -30,7 +32,7 @@
from .secrets.secrets_rest_api import SecretsRestAPI

if TYPE_CHECKING:
from typing import Any, Optional
from typing import Any
from asyncio import AbstractEventLoop

from jupyter_collaboration import ( # type:ignore[import-untyped] # isort:skip
Expand Down Expand Up @@ -216,14 +218,15 @@ def initialize(self, argv: Any = None) -> None:
self.ychats_by_room: dict[str, YChat] = {}
"""Cache of YChat instances, indexed by room ID."""

if self.serverapp is not None:
self.event_logger = self.serverapp.web_app.settings["event_logger"]
self.event_logger.add_listener(
schema_id=JUPYTER_COLLABORATION_EVENTS_URI, listener=self.connect_chat
)
# Original direct event listener (commented out for router integration)
# if self.serverapp is not None:
# self.event_logger = self.serverapp.web_app.settings["event_logger"]
# self.event_logger.add_listener(
# schema_id=JUPYTER_COLLABORATION_EVENTS_URI, listener=self.connect_chat
# )

@property
def event_loop(self) -> "AbstractEventLoop":
def event_loop(self) -> AbstractEventLoop:
"""
Returns a reference to the asyncio event loop.
"""
Expand Down Expand Up @@ -347,6 +350,8 @@ def initialize_settings(self):
# When a message's interrupt event is set, the response is halted.
self.settings["jai_message_interrupted"] = {}

self.event_loop.create_task(self._setup_router_integration())

# Log server extension startup time
self.log.info(f"Registered {self.name} server extension")
startup_time = round((time.time() - start) * 1000)
Expand Down Expand Up @@ -377,14 +382,14 @@ async def _stop_extension(self):

def _init_persona_manager(
self, room_id: str, ychat: YChat
) -> Optional[PersonaManager]:
) -> PersonaManager | None:
"""
Initializes a `PersonaManager` instance scoped to a `YChat`.

This method should not raise an exception. Upon encountering an
exception, this method will catch it, log it, and return `None`.
"""
persona_manager: Optional[PersonaManager] = None
persona_manager: PersonaManager | None = None

try:
config_manager = self.settings.get("jai_config_manager", None)
Expand Down Expand Up @@ -428,6 +433,56 @@ def _init_persona_manager(
finally:
return persona_manager

async def _setup_router_integration(self) -> None:
"""
Set up integration with jupyter-ai-router if available.
This allows persona manager to work through the centralized MessageRouter.
"""
self.log.info("Waiting for the router to be ready")

# Wait until the router field is available
while True:
router = self.serverapp.web_app.settings.get("jupyter-ai", {}).get("router")
if router is not None:
self.log.info("Router is ready, continuing with the integration")
break
await asyncio.sleep(0.1) # Check every 100ms

try:
self.log.info("Found jupyter-ai-router, registering callbacks")

# Register callback for new chat initialization
router.observe_chat_init(self._on_router_chat_init)

# Store reference to router for later use
self.router = router

except Exception as e:
self.log.error(f"Error setting up router integration: {e}")

def _on_router_chat_init(self, room_id: str, ychat: YChat) -> None:
"""
Callback for when router detects a new chat initialization.
This replaces the direct event listener approach.
"""
self.log.info(f"Router detected new chat room: {room_id}")

# Cache the YChat instance
self.ychats_by_room[room_id] = ychat

# Initialize persona manager for this chat
persona_manager = self._init_persona_manager(room_id, ychat)
if not persona_manager:
self.log.error(
"Jupyter AI was unable to initialize its AI personas. They are not available for use in chat until this error is resolved. "
+ "Please verify your configuration and open a new issue on GitHub if this error persists."
)
return

# Register persona manager callbacks with router
self.router.observe_chat_msg(room_id, persona_manager.on_chat_message)
self.router.observe_slash_cmd_msg(room_id, persona_manager.on_slash_cmd_message)

def _link_jupyter_server_extension(self, server_app: ServerApp):
"""Setup custom config needed by this extension."""
c = Config()
Expand Down
6 changes: 6 additions & 0 deletions packages/jupyter-ai/jupyter_ai/personas/persona_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,12 @@ def get_mentioned_personas(self, new_message: Message) -> list[BasePersona]:
persona_list.append(self.personas[mentioned_id])
return persona_list

def on_chat_message(self, room_id: str, message: Message):
self.route_message(message)

def on_slash_cmd_message(self, room_id: str, message: Message):
self.route_slash_command(message)

def route_message(self, new_message: Message):
"""
Method that routes an incoming message to the correct personas by
Expand Down
Loading