8
8
from langchain_core .prompts import PromptTemplate
9
9
10
10
from .base import BaseChatHandler , SlashCommandRoutingType
11
+ from .learn import LearnChatHandler , Retriever
11
12
12
13
PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
13
14
18
19
CONDENSE_PROMPT = PromptTemplate .from_template (PROMPT_TEMPLATE )
19
20
20
21
22
+ class CustomLearnException (Exception ):
23
+ """Exception raised when Jupyter AI's /ask command is used without the required /learn command."""
24
+
25
+ def __init__ (self ):
26
+ super ().__init__ (
27
+ "Jupyter AI's default /ask command requires the default /learn command. "
28
+ "If you are overriding /learn via the entry points API, be sure to also override or disable /ask."
29
+ )
30
+
31
+
21
32
class AskChatHandler (BaseChatHandler ):
22
33
"""Processes messages prefixed with /ask. This actor will
23
34
send the message as input to a RetrieverQA chain, that
@@ -33,12 +44,16 @@ class AskChatHandler(BaseChatHandler):
33
44
34
45
uses_llm = True
35
46
36
- def __init__ (self , retriever , * args , ** kwargs ):
47
+ def __init__ (self , * args , ** kwargs ):
37
48
super ().__init__ (* args , ** kwargs )
38
49
39
- self ._retriever = retriever
40
50
self .parser .prog = "/ask"
41
51
self .parser .add_argument ("query" , nargs = argparse .REMAINDER )
52
+ learn_chat_handler = self .chat_handlers .get ("/learn" )
53
+ if not isinstance (learn_chat_handler , LearnChatHandler ):
54
+ raise CustomLearnException ()
55
+
56
+ self ._retriever = Retriever (learn_chat_handler = learn_chat_handler )
42
57
43
58
def create_llm_chain (
44
59
self , provider : Type [BaseProvider ], provider_params : Dict [str , str ]
@@ -51,6 +66,7 @@ def create_llm_chain(
51
66
memory = ConversationBufferWindowMemory (
52
67
memory_key = "chat_history" , return_messages = True , k = 2
53
68
)
69
+
54
70
self .llm_chain = ConversationalRetrievalChain .from_llm (
55
71
self .llm ,
56
72
self ._retriever ,
0 commit comments