Skip to content

Commit 1458af6

Browse files
authored
Introduce AI persona framework (#1324)
* add dev-reinstall script for convenience * add PersonaManager and JupyternautPersona * add PersonaAwareness to allow awareness on >1 persona * add DebugPersona to jupyter_ai_test * upgrade to Jupyter Chat v0.10.0 * automatically generate IDs for personas * add mention-based routing to PersonaManager * implement streaming replies in Jupyternaut persona * correctly identify AI messages in YChatHistory * fix mention routing to actually work * add logging capability to PersonaAwareness * update BasePersona.id to not depend on module path * improve logging in PersonaManager * log time elapsed in PersonaManager * simplify jupyternaut prompt template * pre-commit * add comment to ref jupyter-chat#212 * have _init_persona_manager() explicitly return None on exception * pre-commit * update docstrings on BasePersona * pre-commit * add return type annotation to forward_reply_stream() * remove unused return * add comment explaining _init_persona_classes() * add docstrings to persona_manager methods * remove redundant error logs * pre-commit * fix mypy errors * remove tests on chat handlers as they are superseded by personas * fix type annotation for Py39 compat * rename forward_reply_stream() to stream_message() * add send_message() method to BasePersona * pre-commit
1 parent 06338f1 commit 1458af6

File tree

16 files changed

+805
-197
lines changed

16 files changed

+805
-197
lines changed

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"clean:all": "lerna run clean:all",
3131
"dev": "jupyter lab --config playground/config.py",
3232
"dev-install": "lerna run dev-install --stream",
33+
"dev-reinstall": "jlpm dev-uninstall && jlpm dev-install",
3334
"dev-uninstall": "lerna run dev-uninstall --stream",
3435
"install-from-src": "lerna run install-from-src --stream",
3536
"lint": "jlpm && lerna run prettier && lerna run eslint",
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from jupyter_ai.personas.base_persona import BasePersona, PersonaDefaults
2+
from jupyterlab_chat.models import Message, NewMessage
3+
4+
5+
class DebugPersona(BasePersona):
6+
"""
7+
The Jupyternaut persona, the main persona provided by Jupyter AI.
8+
"""
9+
10+
def __init__(self, *args, **kwargs):
11+
super().__init__(*args, **kwargs)
12+
13+
@property
14+
def defaults(self):
15+
return PersonaDefaults(
16+
name="DebugPersona",
17+
avatar_path="/api/ai/static/jupyternaut.svg",
18+
description="A mock persona used for debugging in local dev environments.",
19+
system_prompt="...",
20+
)
21+
22+
async def process_message(self, message: Message):
23+
self.ychat.add_message(NewMessage(body="Hello!", sender=self.id))
24+
return

packages/jupyter-ai-test/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ test-provider-ask-learn-unsupported = "jupyter_ai_test.test_providers:TestProvid
3535
[project.entry-points."jupyter_ai.chat_handlers"]
3636
test-slash-command = "jupyter_ai_test.test_slash_commands:TestSlashCommand"
3737

38+
[project.entry-points."jupyter_ai.personas"]
39+
debug-persona = "jupyter_ai_test.debug_persona:DebugPersona"
40+
3841
[tool.hatch.build.hooks.version]
3942
path = "jupyter_ai_test/_version.py"
4043

packages/jupyter-ai/jupyter_ai/extension.py

Lines changed: 71 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import re
33
import time
44
import types
5+
from asyncio import get_event_loop_policy
56
from functools import partial
6-
from typing import Dict
7+
from typing import TYPE_CHECKING, Dict, Optional
78

89
import traitlets
910
from dask.distributed import Client as DaskClient
@@ -12,7 +13,6 @@
1213
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
1314
from jupyter_events import EventLogger
1415
from jupyter_server.extension.application import ExtensionApp
15-
from jupyter_server.utils import url_path_join
1616
from jupyterlab_chat.models import Message
1717
from jupyterlab_chat.ychat import YChat
1818
from pycrdt import ArrayEvent
@@ -22,7 +22,6 @@
2222
from .chat_handlers.base import BaseChatHandler
2323
from .completions.handlers import DefaultInlineCompletionHandler
2424
from .config_manager import ConfigManager
25-
from .constants import BOT
2625
from .context_providers import BaseCommandContextProvider, FileContextProvider
2726
from .handlers import (
2827
ApiKeysHandler,
@@ -33,6 +32,10 @@
3332
SlashCommandsInfoHandler,
3433
)
3534
from .history import YChatHistory
35+
from .personas import PersonaManager
36+
37+
if TYPE_CHECKING:
38+
from asyncio import AbstractEventLoop
3639

3740
from jupyter_collaboration import ( # type:ignore[import-untyped] # isort:skip
3841
__version__ as jupyter_collaboration_version,
@@ -244,6 +247,13 @@ def initialize(self):
244247
schema_id=JUPYTER_COLLABORATION_EVENTS_URI, listener=self.connect_chat
245248
)
246249

250+
@property
251+
def event_loop(self) -> "AbstractEventLoop":
252+
"""
253+
Returns a reference to the asyncio event loop.
254+
"""
255+
return get_event_loop_policy().get_event_loop()
256+
247257
async def connect_chat(
248258
self, logger: EventLogger, schema_id: str, data: dict
249259
) -> None:
@@ -264,17 +274,19 @@ async def connect_chat(
264274
if ychat is None:
265275
return
266276

267-
# Add the bot user to the chat document awareness.
268-
BOT["avatar_url"] = url_path_join(
269-
self.settings.get("base_url", "/"), "api/ai/static/jupyternaut.svg"
270-
)
271-
if ychat.awareness is not None:
272-
ychat.awareness.set_local_state_field("user", BOT)
273-
274277
# initialize chat handlers for new chat
275278
self.chat_handlers_by_room[room_id] = self._init_chat_handlers(ychat)
276279

277-
callback = partial(self.on_change, room_id)
280+
# initialize persona manager
281+
persona_manager = self._init_persona_manager(ychat)
282+
if not persona_manager:
283+
self.log.error(
284+
"Jupyter AI was unable to initialize its AI personas. They are not available for use in chat until this error is resolved. "
285+
+ "Please verify your configuration and open a new issue on GitHub if this error persists."
286+
)
287+
return
288+
289+
callback = partial(self.on_change, room_id, persona_manager)
278290
ychat.ymessages.observe(callback)
279291

280292
async def get_chat(self, room_id: str) -> YChat:
@@ -301,21 +313,26 @@ async def get_chat(self, room_id: str) -> YChat:
301313
self.ychats_by_room[room_id] = document
302314
return document
303315

304-
def on_change(self, room_id: str, events: ArrayEvent) -> None:
316+
def on_change(
317+
self, room_id: str, persona_manager: PersonaManager, events: ArrayEvent
318+
) -> None:
305319
assert self.serverapp
306320

307321
for change in events.delta: # type:ignore[attr-defined]
308322
if not "insert" in change.keys():
309323
continue
310-
messages = change["insert"]
311-
for message_dict in messages:
312-
message = Message(**message_dict)
313-
if message.sender == BOT["username"] or message.raw_time:
314-
continue
315324

316-
self.serverapp.io_loop.asyncio_loop.create_task( # type:ignore[attr-defined]
317-
self.route_human_message(room_id, message)
318-
)
325+
# the "if not m['raw_time']" clause is necessary because every new
326+
# message triggers 2 events, one with `raw_time` set to `True` and
327+
# another with `raw_time` set to `False` milliseconds later.
328+
# we should explore fixing this quirk in Jupyter Chat.
329+
#
330+
# Ref: https://github.com/jupyterlab/jupyter-chat/issues/212
331+
new_messages = [
332+
Message(**m) for m in change["insert"] if not m.get("raw_time", False)
333+
]
334+
for new_message in new_messages:
335+
persona_manager.route_message(new_message)
319336

320337
async def route_human_message(self, room_id: str, message: Message):
321338
"""
@@ -400,18 +417,15 @@ def initialize_settings(self):
400417

401418
self.log.info(f"Registered {self.name} server extension")
402419

403-
# get reference to event loop
404-
# `asyncio.get_event_loop()` is deprecated in Python 3.11+, in favor of
405-
# the more readable `asyncio.get_event_loop_policy().get_event_loop()`.
406-
# it's easier to just reference the loop directly.
407-
loop = self.serverapp.io_loop.asyncio_loop
408-
self.settings["jai_event_loop"] = loop
420+
self.settings["jai_event_loop"] = self.event_loop
409421

410422
# We cannot instantiate the Dask client directly here because it
411423
# requires the event loop to be running on init. So instead we schedule
412424
# this as a task that is run as soon as the loop starts, and pass
413425
# consumers a Future that resolves to the Dask client when awaited.
414-
self.settings["dask_client_future"] = loop.create_task(self._get_dask_client())
426+
self.settings["dask_client_future"] = self.event_loop.create_task(
427+
self._get_dask_client()
428+
)
415429

416430
# Create empty context providers dict to be filled later.
417431
# This is created early to use as kwargs for chat handlers.
@@ -456,10 +470,7 @@ async def _stop_extension(self):
456470

457471
def _init_chat_handlers(self, ychat: YChat) -> Dict[str, BaseChatHandler]:
458472
"""
459-
Initializes a set of chat handlers. May accept a YChat instance for
460-
collaborative chats.
461-
462-
TODO: Make `ychat` required once Jupyter Chat migration is complete.
473+
Initializes a set of chat handlers for a given `YChat` instance.
463474
"""
464475
assert self.serverapp
465476

@@ -606,3 +617,32 @@ def _init_context_providers(self):
606617
**context_providers_kwargs
607618
)
608619
self.log.info(f"Registered context provider `{context_provider.id}`.")
620+
621+
def _init_persona_manager(self, ychat: YChat) -> Optional[PersonaManager]:
622+
"""
623+
Initializes a `PersonaManager` instance scoped to a `YChat`.
624+
625+
This method should not raise an exception. Upon encountering an
626+
exception, this method will catch it, log it, and return `None`.
627+
"""
628+
persona_manager: Optional[PersonaManager]
629+
630+
try:
631+
config_manager = self.settings.get("jai_config_manager", None)
632+
assert config_manager and isinstance(config_manager, ConfigManager)
633+
634+
persona_manager = PersonaManager(
635+
ychat=ychat,
636+
config_manager=config_manager,
637+
event_loop=self.event_loop,
638+
log=self.log,
639+
)
640+
except Exception as e:
641+
# TODO: how to stop the extension when this fails
642+
# also why do uncaught exceptions produce an empty error log in Jupyter Server?
643+
self.log.error(
644+
f"Unable to initialize PersonaManager in YChat with ID '{ychat.get_id()}' due to an exception printed below."
645+
)
646+
self.log.exception(e)
647+
finally:
648+
return persona_manager

packages/jupyter-ai/jupyter_ai/history.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import List, Optional
22

3-
from jupyter_ai.constants import BOT
43
from jupyterlab_chat.models import Message as JChatMessage
54
from jupyterlab_chat.ychat import YChat
65
from langchain_core.chat_history import BaseChatMessageHistory
@@ -46,7 +45,7 @@ def _convert_to_langchain_messages(self, jchat_messages: List[JChatMessage]):
4645
"""
4746
messages: List[BaseMessage] = []
4847
for jchat_message in jchat_messages:
49-
if jchat_message.sender == BOT["username"]:
48+
if jchat_message.sender.startswith("jupyter-ai-personas::"):
5049
messages.append(AIMessage(content=jchat_message.body))
5150
else:
5251
messages.append(HumanMessage(content=jchat_message.body))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .base_persona import BasePersona, PersonaDefaults
2+
from .persona_manager import PersonaManager

0 commit comments

Comments
 (0)