Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,6 @@ dmypy.json

# Yarn cache
.yarn/

# For local testing
playground/
31 changes: 26 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ router = self.serverapp.web_app.settings.get("jupyter-ai", {}).get("router")
def on_new_chat(room_id: str, ychat: YChat):
print(f"New chat connected: {room_id}")

def on_slash_command(room_id: str, message: Message):
print(f"Slash command in {room_id}: {message.body}")
def on_slash_command(room_id: str, command: str, message: Message):
print(f"Slash command '{command}' in {room_id}: {message.body}")

def on_regular_message(room_id: str, message: Message):`
def on_regular_message(room_id: str, message: Message):
print(f"Regular message in {room_id}: {message.body}")

# Register the callbacks
router.observe_chat_init(on_new_chat)
router.observe_slash_cmd_msg("room-id", on_slash_command)
router.observe_slash_cmd_msg("room-id", "help", on_slash_command) # Only /help commands
router.observe_chat_msg("room-id", on_regular_message)
```

Expand All @@ -40,9 +40,30 @@ router.observe_chat_msg("room-id", on_regular_message)
### Available Methods

- `observe_chat_init(callback)` - Called when new chat sessions are initialized with `(room_id, ychat)`
- `observe_slash_cmd_msg(room_id, callback)` - Called for messages starting with `/` in a specific room
- `observe_slash_cmd_msg(room_id, command_pattern, callback)` - Called for specific slash commands matching the pattern in a specific room
- `observe_chat_msg(room_id, callback)` - Called for regular (non-slash) messages in a specific room

### Command Pattern Matching

The `observe_slash_cmd_msg` method supports regex pattern matching:

```python
# Exact match: Only matches "/help"
router.observe_slash_cmd_msg("room-id", "help", callback)

# Regex pattern: Matches "/ai-generate", "/ai-review", etc.
router.observe_slash_cmd_msg("room-id", "ai-.*", callback)

# Regex with groups: Matches "/export-json", "/export-csv", "/export-xml"
router.observe_slash_cmd_msg("room-id", r"export-(json|csv|xml)", callback)
```

**Callback signature**: `callback(room_id: str, command: str, message: Message)`

- `room_id`: The chat room identifier
- `command`: The matched command without the leading slash (e.g., "help", "ai-generate")
- `message`: Message object with the command removed from the body (only arguments remain)

## Install

To install the extension, execute:
Expand Down
78 changes: 61 additions & 17 deletions jupyter_ai_router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from typing import Any, Callable, Dict, List, TYPE_CHECKING
from functools import partial
import re
from dataclasses import replace
from jupyterlab_chat.models import Message
from pycrdt import ArrayEvent
from traitlets.config import LoggingConfigurable
Expand All @@ -19,6 +21,23 @@
from .utils import get_first_word


def matches_pattern(word: str, pattern: str) -> bool:
"""
Check if a word matches a regex pattern.

Args:
word: The word to match (e.g., "help", "ai-generate")
pattern: The regex pattern to match against (e.g., "help", "ai-.*", "export-(json|csv)")

Returns:
True if the word matches the pattern
"""
try:
return bool(re.match(f"^{pattern}$", word))
except re.error:
return False


class MessageRouter(LoggingConfigurable):
"""
Router that manages ychat message routing.
Expand All @@ -34,7 +53,7 @@ def __init__(self, *args, **kwargs):

# Callback lists
self.chat_init_observers: List[Callable[[str, "YChat"], Any]] = []
self.slash_cmd_observers: Dict[str, List[Callable[[str, Message], Any]]] = {}
self.slash_cmd_observers: Dict[str, Dict[str, List[Callable[[str, str, Message], Any]]]] = {}
self.chat_msg_observers: Dict[str, List[Callable[[str, Message], Any]]] = {}

# Active chat rooms
Expand All @@ -54,19 +73,28 @@ def observe_chat_init(self, callback: Callable[[str, "YChat"], Any]) -> None:
self.log.info("Registered new chat initialization callback")

def observe_slash_cmd_msg(
self, room_id: str, callback: Callable[[str, Message], Any]
self, room_id: str, command_pattern: str, callback: Callable[[str, str, Message], Any]
) -> None:
"""
Register a callback for when slash commands are received.
Register a callback for when specific slash commands are received.

Args:
callback: Function called with (room_id: str, message: Message) for slash commands
room_id: The chat room ID
command_pattern: Regex pattern to match commands (without leading slash).
Examples:
- Exact match: "help" matches "/help"
- Pattern match: "ai-.*" matches "/ai-generate", "/ai-review", etc.
- Multiple options: "export-(json|csv)" matches "/export-json", "/export-csv"
callback: Function called with (room_id: str, command: str, message: Message) for matching commands
"""
if room_id not in self.slash_cmd_observers:
self.slash_cmd_observers[room_id] = []
self.slash_cmd_observers[room_id] = {}

if command_pattern not in self.slash_cmd_observers[room_id]:
self.slash_cmd_observers[room_id][command_pattern] = []

self.slash_cmd_observers[room_id].append(callback)
self.log.info("Registered slash command callback")
self.slash_cmd_observers[room_id][command_pattern].append(callback)
self.log.info(f"Registered slash command callback for pattern: {command_pattern}")

def observe_chat_msg(
self, room_id: str, callback: Callable[[str, Message], Any]
Expand Down Expand Up @@ -157,10 +185,35 @@ def _route_message(self, room_id: str, message: Message) -> None:

# Check if it's a slash command
if first_word and first_word.startswith("/"):
self._notify_slash_cmd_observers(room_id, message)
# Extract command and create trimmed message
parts = message.body.split(None, 1) # Split into max 2 parts
command = parts[0] if parts else ""
trimmed_body = parts[1] if len(parts) > 1 else ""

# Create a copy of the message with trimmed body (command removed)
trimmed_message = replace(message, body=trimmed_body)

# Remove forward slash from command for cleaner API
clean_command = command[1:] if command.startswith("/") else command

# Route to slash command observers
self._notify_slash_cmd_observers(room_id, trimmed_message, clean_command)
else:
self._notify_msg_observers(room_id, message)


def _notify_slash_cmd_observers(self, room_id: str, message: Message, clean_command: str) -> None:
"""Notify observers registered for slash commands."""
room_observers = self.slash_cmd_observers.get(room_id, {})

for registered_pattern, callbacks in room_observers.items():
if matches_pattern(clean_command, registered_pattern):
for callback in callbacks:
try:
callback(room_id, clean_command, message)
except Exception as e:
self.log.error(f"Slash command observer error for pattern '{registered_pattern}': {e}")

def _notify_chat_init_observers(self, room_id: str, ychat: "YChat") -> None:
"""Notify all new chat observers."""
for callback in self.chat_init_observers:
Expand All @@ -169,15 +222,6 @@ def _notify_chat_init_observers(self, room_id: str, ychat: "YChat") -> None:
except Exception as e:
self.log.error(f"New chat observer error for {room_id}: {e}")

def _notify_slash_cmd_observers(self, room_id: str, message: Message) -> None:
"""Notify all slash command observers."""
callbacks = self.slash_cmd_observers.get(room_id, [])
for callback in callbacks:
try:
callback(room_id, message)
except Exception as e:
self.log.error(f"Slash command observer error for {room_id}: {e}")

def _notify_msg_observers(self, room_id: str, message: Message) -> None:
"""Notify all message observers."""
callbacks = self.chat_msg_observers.get(room_id, [])
Expand Down
Loading
Loading