Skip to content

Commit 055d7b8

Browse files
authored
Refactor Chat Handlers to Simplify Initialization (#1257)
* simplify-entrypoints-loading * fix-lint * fix-tests * add-retriever-typing * remove-retriever-from-base * fix-circular-import(ydoc-import) * fix-tests * fix-type-check-failure * refactor-retriever-init
1 parent 7b4586e commit 055d7b8

File tree

7 files changed

+34
-16
lines changed

7 files changed

+34
-16
lines changed

packages/jupyter-ai/jupyter_ai/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# The following import is to make sure jupyter_ydoc is imported before
2+
# jupyterlab_chat, otherwise it leads to circular import because of the
3+
# YChat relying on YBaseDoc, and jupyter_ydoc registering YChat from the entry point.
4+
import jupyter_ydoc
5+
16
# expose jupyter_ai_magics ipython extension
27
# DO NOT REMOVE.
38
from jupyter_ai_magics import load_ipython_extension, unload_ipython_extension

packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
# The following import is to make sure jupyter_ydoc is imported before
2-
# jupyterlab_chat, otherwise it leads to circular import because of the
3-
# YChat relying on YBaseDoc, and jupyter_ydoc registering YChat from the entry point.
4-
import jupyter_ydoc
5-
61
from .ask import AskChatHandler
72
from .base import BaseChatHandler, SlashCommandRoutingType
83
from .default import DefaultChatHandler

packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from langchain_core.prompts import PromptTemplate
99

1010
from .base import BaseChatHandler, SlashCommandRoutingType
11+
from .learn import LearnChatHandler, Retriever
1112

1213
PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
1314
@@ -18,6 +19,16 @@
1819
CONDENSE_PROMPT = PromptTemplate.from_template(PROMPT_TEMPLATE)
1920

2021

22+
class CustomLearnException(Exception):
23+
"""Exception raised when Jupyter AI's /ask command is used without the required /learn command."""
24+
25+
def __init__(self):
26+
super().__init__(
27+
"Jupyter AI's default /ask command requires the default /learn command. "
28+
"If you are overriding /learn via the entry points API, be sure to also override or disable /ask."
29+
)
30+
31+
2132
class AskChatHandler(BaseChatHandler):
2233
"""Processes messages prefixed with /ask. This actor will
2334
send the message as input to a RetrieverQA chain, that
@@ -33,12 +44,16 @@ class AskChatHandler(BaseChatHandler):
3344

3445
uses_llm = True
3546

36-
def __init__(self, retriever, *args, **kwargs):
47+
def __init__(self, *args, **kwargs):
3748
super().__init__(*args, **kwargs)
3849

39-
self._retriever = retriever
4050
self.parser.prog = "/ask"
4151
self.parser.add_argument("query", nargs=argparse.REMAINDER)
52+
learn_chat_handler = self.chat_handlers.get("/learn")
53+
if not isinstance(learn_chat_handler, LearnChatHandler):
54+
raise CustomLearnException()
55+
56+
self._retriever = Retriever(learn_chat_handler=learn_chat_handler)
4257

4358
def create_llm_chain(
4459
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
@@ -51,6 +66,7 @@ def create_llm_chain(
5166
memory = ConversationBufferWindowMemory(
5267
memory_key="chat_history", return_messages=True, k=2
5368
)
69+
5470
self.llm_chain = ConversationalRetrievalChain.from_llm(
5571
self.llm,
5672
self._retriever,

packages/jupyter-ai/jupyter_ai/chat_handlers/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import contextlib
44
import os
55
import traceback
6+
from pathlib import Path
67
from typing import (
78
TYPE_CHECKING,
9+
Any,
810
Awaitable,
911
ClassVar,
1012
Dict,
@@ -22,6 +24,7 @@
2224
from jupyter_ai_magics.providers import BaseProvider
2325
from jupyterlab_chat.models import Message, NewMessage, User
2426
from jupyterlab_chat.ychat import YChat
27+
from langchain.schema import BaseRetriever
2528
from langchain_core.messages import AIMessageChunk
2629
from langchain_core.runnables import Runnable
2730
from langchain_core.runnables.config import RunnableConfig
@@ -141,6 +144,7 @@ def __init__(
141144
context_providers: Dict[str, "BaseCommandContextProvider"],
142145
message_interrupted: Dict[str, asyncio.Event],
143146
ychat: YChat,
147+
log_dir: Optional[str],
144148
):
145149
self.log = log
146150
self.config_manager = config_manager
@@ -162,6 +166,7 @@ def __init__(
162166
self.context_providers = context_providers
163167
self.message_interrupted = message_interrupted
164168
self.ychat = ychat
169+
self.log_dir = Path(log_dir) if log_dir else None
165170

166171
self.llm: Optional[BaseProvider] = None
167172
self.llm_params: Optional[dict] = None

packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,8 @@ class GenerateChatHandler(BaseChatHandler):
253253

254254
uses_llm = True
255255

256-
def __init__(self, log_dir: Optional[str], *args, **kwargs):
256+
def __init__(self, *args, **kwargs):
257257
super().__init__(*args, **kwargs)
258-
self.log_dir = Path(log_dir) if log_dir else None
259258
self.llm: Optional[BaseProvider] = None
260259

261260
def create_llm_chain(

packages/jupyter-ai/jupyter_ai/extension.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import traitlets
99
from dask.distributed import Client as DaskClient
1010
from importlib_metadata import entry_points
11-
from jupyter_ai.chat_handlers.learn import Retriever
1211
from jupyter_ai_magics import BaseProvider, JupyternautPersona
1312
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
1413
from jupyter_events import EventLogger
@@ -477,15 +476,13 @@ def _init_chat_handlers(self, ychat: YChat) -> Dict[str, BaseChatHandler]:
477476
"context_providers": self.settings["jai_context_providers"],
478477
"message_interrupted": self.settings["jai_message_interrupted"],
479478
"ychat": ychat,
479+
"log_dir": self.error_logs_dir,
480480
}
481+
481482
default_chat_handler = DefaultChatHandler(**chat_handler_kwargs)
482-
generate_chat_handler = GenerateChatHandler(
483-
**chat_handler_kwargs,
484-
log_dir=self.error_logs_dir,
485-
)
483+
generate_chat_handler = GenerateChatHandler(**chat_handler_kwargs)
486484
learn_chat_handler = LearnChatHandler(**chat_handler_kwargs)
487-
retriever = Retriever(learn_chat_handler=learn_chat_handler)
488-
ask_chat_handler = AskChatHandler(**chat_handler_kwargs, retriever=retriever)
485+
ask_chat_handler = AskChatHandler(**chat_handler_kwargs)
489486

490487
chat_handlers["default"] = default_chat_handler
491488
chat_handlers["/ask"] = ask_chat_handler

packages/jupyter-ai/jupyter_ai/tests/test_handlers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(self, lm_provider=None, lm_provider_params=None):
6767
message_interrupted={},
6868
llm_chat_memory=self.ychat_history,
6969
ychat=self.ychat,
70+
log_dir="",
7071
)
7172

7273
@property

0 commit comments

Comments
 (0)