Skip to content

Commit 2f573fd

Browse files
authored
Introduce AI personas (see PR #1324) (#1341)
add dev-reinstall script for convenience add PersonaManager and JupyternautPersona add PersonaAwareness to allow awareness on >1 persona add DebugPersona to jupyter_ai_test upgrade to Jupyter Chat v0.10.0 automatically generate IDs for personas add mention-based routing to PersonaManager implement streaming replies in Jupyternaut persona correctly identify AI messages in YChatHistory fix mention routing to actually work add logging capability to PersonaAwareness update BasePersona.id to not depend on module path improve logging in PersonaManager log time elapsed in PersonaManager simplify jupyternaut prompt template pre-commit add comment to ref jupyter-chat#212 have _init_persona_manager() explicitly return None on exception pre-commit update docstrings on BasePersona pre-commit add return type annotation to forward_reply_stream() remove unused return add comment explaining _init_persona_classes() add docstrings to persona_manager methods remove redundant error logs pre-commit fix mypy errors remove tests on chat handlers as they are superseded by personas fix type annotation for Py39 compat rename forward_reply_stream() to stream_message() add send_message() method to BasePersona pre-commit
1 parent 0a38783 commit 2f573fd

File tree

16 files changed

+805
-197
lines changed

16 files changed

+805
-197
lines changed

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"clean:all": "lerna run clean:all",
3131
"dev": "jupyter lab --config playground/config.py",
3232
"dev-install": "lerna run dev-install --stream",
33+
"dev-reinstall": "jlpm dev-uninstall && jlpm dev-install",
3334
"dev-uninstall": "lerna run dev-uninstall --stream",
3435
"install-from-src": "lerna run install-from-src --stream",
3536
"lint": "jlpm && lerna run prettier && lerna run eslint",
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from jupyter_ai.personas.base_persona import BasePersona, PersonaDefaults
2+
from jupyterlab_chat.models import Message, NewMessage
3+
4+
5+
class DebugPersona(BasePersona):
6+
"""
7+
The Jupyternaut persona, the main persona provided by Jupyter AI.
8+
"""
9+
10+
def __init__(self, *args, **kwargs):
11+
super().__init__(*args, **kwargs)
12+
13+
@property
14+
def defaults(self):
15+
return PersonaDefaults(
16+
name="DebugPersona",
17+
avatar_path="/api/ai/static/jupyternaut.svg",
18+
description="A mock persona used for debugging in local dev environments.",
19+
system_prompt="...",
20+
)
21+
22+
async def process_message(self, message: Message):
23+
self.ychat.add_message(NewMessage(body="Hello!", sender=self.id))
24+
return

packages/jupyter-ai-test/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ test-provider-ask-learn-unsupported = "jupyter_ai_test.test_providers:TestProvid
3535
[project.entry-points."jupyter_ai.chat_handlers"]
3636
test-slash-command = "jupyter_ai_test.test_slash_commands:TestSlashCommand"
3737

38+
[project.entry-points."jupyter_ai.personas"]
39+
debug-persona = "jupyter_ai_test.debug_persona:DebugPersona"
40+
3841
[tool.hatch.build.hooks.version]
3942
path = "jupyter_ai_test/_version.py"
4043

packages/jupyter-ai/jupyter_ai/extension.py

Lines changed: 71 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import re
33
import time
44
import types
5+
from asyncio import get_event_loop_policy
56
from functools import partial
6-
from typing import Dict
7+
from typing import TYPE_CHECKING, Dict, Optional
78

89
import traitlets
910
from dask.distributed import Client as DaskClient
@@ -12,7 +13,6 @@
1213
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
1314
from jupyter_events import EventLogger
1415
from jupyter_server.extension.application import ExtensionApp
15-
from jupyter_server.utils import url_path_join
1616
from jupyterlab_chat.models import Message
1717
from jupyterlab_chat.ychat import YChat
1818
from pycrdt import ArrayEvent
@@ -22,7 +22,6 @@
2222
from .chat_handlers.base import BaseChatHandler
2323
from .completions.handlers import DefaultInlineCompletionHandler
2424
from .config_manager import ConfigManager
25-
from .constants import BOT
2625
from .context_providers import BaseCommandContextProvider, FileContextProvider
2726
from .handlers import (
2827
ApiKeysHandler,
@@ -33,6 +32,10 @@
3332
SlashCommandsInfoHandler,
3433
)
3534
from .history import YChatHistory
35+
from .personas import PersonaManager
36+
37+
if TYPE_CHECKING:
38+
from asyncio import AbstractEventLoop
3639

3740
from jupyter_collaboration import ( # type:ignore[import-untyped] # isort:skip
3841
__version__ as jupyter_collaboration_version,
@@ -244,6 +247,13 @@ def initialize(self):
244247
schema_id=JUPYTER_COLLABORATION_EVENTS_URI, listener=self.connect_chat
245248
)
246249

250+
@property
251+
def event_loop(self) -> "AbstractEventLoop":
252+
"""
253+
Returns a reference to the asyncio event loop.
254+
"""
255+
return get_event_loop_policy().get_event_loop()
256+
247257
async def connect_chat(
248258
self, logger: EventLogger, schema_id: str, data: dict
249259
) -> None:
@@ -264,17 +274,19 @@ async def connect_chat(
264274
if ychat is None:
265275
return
266276

267-
# Add the bot user to the chat document awareness.
268-
BOT["avatar_url"] = url_path_join(
269-
self.settings.get("base_url", "/"), "api/ai/static/jupyternaut.svg"
270-
)
271-
if ychat.awareness is not None:
272-
ychat.awareness.set_local_state_field("user", BOT)
273-
274277
# initialize chat handlers for new chat
275278
self.chat_handlers_by_room[room_id] = self._init_chat_handlers(ychat)
276279

277-
callback = partial(self.on_change, room_id)
280+
# initialize persona manager
281+
persona_manager = self._init_persona_manager(ychat)
282+
if not persona_manager:
283+
self.log.error(
284+
"Jupyter AI was unable to initialize its AI personas. They are not available for use in chat until this error is resolved. "
285+
+ "Please verify your configuration and open a new issue on GitHub if this error persists."
286+
)
287+
return
288+
289+
callback = partial(self.on_change, room_id, persona_manager)
278290
ychat.ymessages.observe(callback)
279291

280292
async def get_chat(self, room_id: str) -> YChat:
@@ -301,21 +313,26 @@ async def get_chat(self, room_id: str) -> YChat:
301313
self.ychats_by_room[room_id] = document
302314
return document
303315

304-
def on_change(self, room_id: str, events: ArrayEvent) -> None:
316+
def on_change(
317+
self, room_id: str, persona_manager: PersonaManager, events: ArrayEvent
318+
) -> None:
305319
assert self.serverapp
306320

307321
for change in events.delta: # type:ignore[attr-defined]
308322
if not "insert" in change.keys():
309323
continue
310-
messages = change["insert"]
311-
for message_dict in messages:
312-
message = Message(**message_dict)
313-
if message.sender == BOT["username"] or message.raw_time:
314-
continue
315324

316-
self.serverapp.io_loop.asyncio_loop.create_task( # type:ignore[attr-defined]
317-
self.route_human_message(room_id, message)
318-
)
325+
# the "if not m['raw_time']" clause is necessary because every new
326+
# message triggers 2 events, one with `raw_time` set to `True` and
327+
# another with `raw_time` set to `False` milliseconds later.
328+
# we should explore fixing this quirk in Jupyter Chat.
329+
#
330+
# Ref: https://github.com/jupyterlab/jupyter-chat/issues/212
331+
new_messages = [
332+
Message(**m) for m in change["insert"] if not m.get("raw_time", False)
333+
]
334+
for new_message in new_messages:
335+
persona_manager.route_message(new_message)
319336

320337
async def route_human_message(self, room_id: str, message: Message):
321338
"""
@@ -400,18 +417,15 @@ def initialize_settings(self):
400417

401418
self.log.info(f"Registered {self.name} server extension")
402419

403-
# get reference to event loop
404-
# `asyncio.get_event_loop()` is deprecated in Python 3.11+, in favor of
405-
# the more readable `asyncio.get_event_loop_policy().get_event_loop()`.
406-
# it's easier to just reference the loop directly.
407-
loop = self.serverapp.io_loop.asyncio_loop
408-
self.settings["jai_event_loop"] = loop
420+
self.settings["jai_event_loop"] = self.event_loop
409421

410422
# We cannot instantiate the Dask client directly here because it
411423
# requires the event loop to be running on init. So instead we schedule
412424
# this as a task that is run as soon as the loop starts, and pass
413425
# consumers a Future that resolves to the Dask client when awaited.
414-
self.settings["dask_client_future"] = loop.create_task(self._get_dask_client())
426+
self.settings["dask_client_future"] = self.event_loop.create_task(
427+
self._get_dask_client()
428+
)
415429

416430
# Create empty context providers dict to be filled later.
417431
# This is created early to use as kwargs for chat handlers.
@@ -456,10 +470,7 @@ async def _stop_extension(self):
456470

457471
def _init_chat_handlers(self, ychat: YChat) -> Dict[str, BaseChatHandler]:
458472
"""
459-
Initializes a set of chat handlers. May accept a YChat instance for
460-
collaborative chats.
461-
462-
TODO: Make `ychat` required once Jupyter Chat migration is complete.
473+
Initializes a set of chat handlers for a given `YChat` instance.
463474
"""
464475
assert self.serverapp
465476

@@ -606,3 +617,32 @@ def _init_context_providers(self):
606617
**context_providers_kwargs
607618
)
608619
self.log.info(f"Registered context provider `{context_provider.id}`.")
620+
621+
def _init_persona_manager(self, ychat: YChat) -> Optional[PersonaManager]:
622+
"""
623+
Initializes a `PersonaManager` instance scoped to a `YChat`.
624+
625+
This method should not raise an exception. Upon encountering an
626+
exception, this method will catch it, log it, and return `None`.
627+
"""
628+
persona_manager: Optional[PersonaManager]
629+
630+
try:
631+
config_manager = self.settings.get("jai_config_manager", None)
632+
assert config_manager and isinstance(config_manager, ConfigManager)
633+
634+
persona_manager = PersonaManager(
635+
ychat=ychat,
636+
config_manager=config_manager,
637+
event_loop=self.event_loop,
638+
log=self.log,
639+
)
640+
except Exception as e:
641+
# TODO: how to stop the extension when this fails
642+
# also why do uncaught exceptions produce an empty error log in Jupyter Server?
643+
self.log.error(
644+
f"Unable to initialize PersonaManager in YChat with ID '{ychat.get_id()}' due to an exception printed below."
645+
)
646+
self.log.exception(e)
647+
finally:
648+
return persona_manager

packages/jupyter-ai/jupyter_ai/history.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import List, Optional
22

3-
from jupyter_ai.constants import BOT
43
from jupyterlab_chat.models import Message as JChatMessage
54
from jupyterlab_chat.ychat import YChat
65
from langchain_core.chat_history import BaseChatMessageHistory
@@ -46,7 +45,7 @@ def _convert_to_langchain_messages(self, jchat_messages: List[JChatMessage]):
4645
"""
4746
messages: List[BaseMessage] = []
4847
for jchat_message in jchat_messages:
49-
if jchat_message.sender == BOT["username"]:
48+
if jchat_message.sender.startswith("jupyter-ai-personas::"):
5049
messages.append(AIMessage(content=jchat_message.body))
5150
else:
5251
messages.append(HumanMessage(content=jchat_message.body))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .base_persona import BasePersona, PersonaDefaults
2+
from .persona_manager import PersonaManager

0 commit comments

Comments
 (0)