Skip to content

Commit 649b775

Browse files
committed
Updates slash cmd to observers to filter on specific commands
1 parent a768d35 commit 649b775

File tree

3 files changed

+332
-21
lines changed

3 files changed

+332
-21
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,6 @@ dmypy.json
123123

124124
# Yarn cache
125125
.yarn/
126+
127+
# For local testing
128+
playground/

jupyter_ai_router/router.py

Lines changed: 81 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
from typing import Any, Callable, Dict, List, TYPE_CHECKING
1111
from functools import partial
12+
import re
13+
from dataclasses import replace
1214
from jupyterlab_chat.models import Message
1315
from pycrdt import ArrayEvent
1416
from traitlets.config import LoggingConfigurable
@@ -34,7 +36,7 @@ def __init__(self, *args, **kwargs):
3436

3537
# Callback lists
3638
self.chat_init_observers: List[Callable[[str, "YChat"], Any]] = []
37-
self.slash_cmd_observers: Dict[str, List[Callable[[str, Message], Any]]] = {}
39+
self.slash_cmd_observers: Dict[str, Dict[str, List[Callable[[str, str, Message], Any]]]] = {}
3840
self.chat_msg_observers: Dict[str, List[Callable[[str, Message], Any]]] = {}
3941

4042
# Active chat rooms
@@ -54,19 +56,27 @@ def observe_chat_init(self, callback: Callable[[str, "YChat"], Any]) -> None:
5456
self.log.info("Registered new chat initialization callback")
5557

5658
def observe_slash_cmd_msg(
57-
self, room_id: str, callback: Callable[[str, Message], Any]
59+
self, room_id: str, command_pattern: str, callback: Callable[[str, str, Message], Any]
5860
) -> None:
5961
"""
60-
Register a callback for when slash commands are received.
62+
Register a callback for when specific slash commands are received.
6163
6264
Args:
63-
callback: Function called with (room_id: str, message: Message) for slash commands
65+
room_id: The chat room ID
66+
command_pattern: Command pattern to match (without leading slash). Supports:
67+
- Exact match: "help" matches "/help"
68+
- Wildcard: "ai-*" matches "/ai-generate", "/ai-review", etc.
69+
- Regex: Any valid Python regex pattern like "export-(json|csv)"
70+
callback: Function called with (room_id: str, command: str, message: Message) for matching commands
6471
"""
6572
if room_id not in self.slash_cmd_observers:
66-
self.slash_cmd_observers[room_id] = []
73+
self.slash_cmd_observers[room_id] = {}
74+
75+
if command_pattern not in self.slash_cmd_observers[room_id]:
76+
self.slash_cmd_observers[room_id][command_pattern] = []
6777

68-
self.slash_cmd_observers[room_id].append(callback)
69-
self.log.info("Registered slash command callback")
78+
self.slash_cmd_observers[room_id][command_pattern].append(callback)
79+
self.log.info(f"Registered slash command callback for pattern: {command_pattern}")
7080

7181
def observe_chat_msg(
7282
self, room_id: str, callback: Callable[[str, Message], Any]
@@ -157,10 +167,73 @@ def _route_message(self, room_id: str, message: Message) -> None:
157167

158168
# Check if it's a slash command
159169
if first_word and first_word.startswith("/"):
160-
self._notify_slash_cmd_observers(room_id, message)
170+
# Extract command and create trimmed message
171+
parts = message.body.split(None, 1) # Split into max 2 parts
172+
command = parts[0] if parts else ""
173+
trimmed_body = parts[1] if len(parts) > 1 else ""
174+
175+
# Create a copy of the message with trimmed body (command removed)
176+
trimmed_message = replace(message, body=trimmed_body)
177+
178+
# Remove forward slash from command for cleaner API
179+
clean_command = command[1:] if command.startswith("/") else command
180+
181+
# Route to slash command observers
182+
self._notify_slash_cmd_observers(room_id, trimmed_message, command, clean_command)
161183
else:
162184
self._notify_msg_observers(room_id, message)
163185

186+
def _command_matches(self, command: str, pattern: str) -> bool:
187+
"""
188+
Check if a command matches a pattern.
189+
190+
Args:
191+
command: The actual command with slash (e.g., "/help")
192+
pattern: The pattern to match against without slash (e.g., "help", "ai-*", regex)
193+
194+
Returns:
195+
True if the command matches the pattern
196+
"""
197+
# Convert pattern to include slash for matching
198+
# Pattern "help" should match command "/help"
199+
if not pattern.startswith("/"):
200+
full_pattern = "/" + pattern
201+
else:
202+
# Handle case where pattern accidentally includes slash
203+
full_pattern = pattern
204+
205+
# Exact match
206+
if command == full_pattern:
207+
return True
208+
209+
# Wildcard pattern (convert to regex)
210+
if "*" in full_pattern:
211+
# Escape special regex chars except *, then convert * to .*
212+
escaped_pattern = re.escape(full_pattern).replace(r"\*", ".*")
213+
regex_pattern = f"^{escaped_pattern}$"
214+
try:
215+
return bool(re.match(regex_pattern, command))
216+
except re.error:
217+
return False
218+
219+
# Try as regex pattern (add slash if not present)
220+
try:
221+
return bool(re.match(full_pattern, command))
222+
except re.error:
223+
return False
224+
225+
def _notify_slash_cmd_observers(self, room_id: str, message: Message, original_command: str, clean_command: str) -> None:
226+
"""Notify observers registered for slash commands."""
227+
room_observers = self.slash_cmd_observers.get(room_id, {})
228+
229+
for registered_pattern, callbacks in room_observers.items():
230+
if self._command_matches(original_command, registered_pattern):
231+
for callback in callbacks:
232+
try:
233+
callback(room_id, clean_command, message)
234+
except Exception as e:
235+
self.log.error(f"Slash command observer error for pattern '{registered_pattern}': {e}")
236+
164237
def _notify_chat_init_observers(self, room_id: str, ychat: "YChat") -> None:
165238
"""Notify all new chat observers."""
166239
for callback in self.chat_init_observers:
@@ -169,15 +242,6 @@ def _notify_chat_init_observers(self, room_id: str, ychat: "YChat") -> None:
169242
except Exception as e:
170243
self.log.error(f"New chat observer error for {room_id}: {e}")
171244

172-
def _notify_slash_cmd_observers(self, room_id: str, message: Message) -> None:
173-
"""Notify all slash command observers."""
174-
callbacks = self.slash_cmd_observers.get(room_id, [])
175-
for callback in callbacks:
176-
try:
177-
callback(room_id, message)
178-
except Exception as e:
179-
self.log.error(f"Slash command observer error for {room_id}: {e}")
180-
181245
def _notify_msg_observers(self, room_id: str, message: Message) -> None:
182246
"""Notify all message observers."""
183247
callbacks = self.chat_msg_observers.get(room_id, [])

0 commit comments

Comments
 (0)