 import os
-import aiosqlite
-from typing import Any, Optional
+from typing import Any, Callable
 
+import aiosqlite
 from jupyter_ai_persona_manager import BasePersona, PersonaDefaults
-from jupyter_ai_persona_manager.persona_manager import SYSTEM_USERNAME
 from jupyter_core.paths import jupyter_data_dir
 from jupyterlab_chat.models import Message
 from langchain.agents import create_agent
+from langchain.agents.middleware import AgentMiddleware, wrap_tool_call
+from langchain.agents.middleware.file_search import FilesystemFileSearchMiddleware
+from langchain.agents.middleware.shell_tool import ShellToolMiddleware
+from langchain.tools.tool_node import ToolCallRequest
+from langchain_core.messages import ToolMessage
 from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
+from langgraph.types import Command
 
 from .chat_models import ChatLiteLLM
 from .prompt_template import (
     JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE,
     JupyternautSystemPromptArgs,
 )
-
-from .toolkits.code_execution import toolkit as exec_toolkit
-from .toolkits.filesystem import toolkit as fs_toolkit
 from .toolkits.notebook import toolkit as nb_toolkit
 
-
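+# Location of the on-disk SQLite database that backs persistent conversation memory.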
 MEMORY_STORE_PATH = os.path.join(jupyter_data_dir(), "jupyter_ai", "memory.sqlite")
 
 
+class ToolMonitoringMiddleware(AgentMiddleware):
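+    """Agent middleware that streams tool-call status updates into the chat."""
+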
+    def __init__(self, *, stream_message: BasePersona.stream_message):
+        self.stream_message = stream_message
+
+    async def awrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], ToolMessage | Command],
+    ) -> ToolMessage | Command:
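+        """Announce the tool call in the chat, run it, then stream a summary of the result or error."""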
+        running_tool_msg = f"Running **{request.tool_call['name']}** with *{request.tool_call['args']}*"
+        await self.stream_message(self._aiter(running_tool_msg))
+        try:
+            result = await handler(request)
+            if hasattr(result, "content") and result.content != "null":
+                completed_tool_msg = str(result.content)[:100]
+            else:
+                completed_tool_msg = "Done!"
+            await self.stream_message(self._aiter(completed_tool_msg))
+            return result
+        except Exception as e:
+            await self.stream_message(self._aiter(f"**{request.tool_call['name']}** failed: {e}"))
+            return ToolMessage(
+                tool_call_id=request.tool_call["id"], status="error", content=f"{e}"
+            )
+
+    async def _aiter(self, message: str):
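+        """Wrap a single string in an async iterator, matching the interface stream_message() expects."""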
+        yield message
+
+
 class JupyternautPersona(BasePersona):
     """
     The Jupyternaut persona, the main persona provided by Jupyter AI.
@@ -43,24 +74,48 @@ def defaults(self):
     async def get_memory_store(self):
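+        """Lazily open and cache the SQLite-backed checkpointer used for conversation memory."""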
         if not hasattr(self, "_memory_store"):
             conn = await aiosqlite.connect(MEMORY_STORE_PATH, check_same_thread=False)
-            self._memory_store = AsyncSqliteSaver(conn)
+            self._memory_store = AsyncSqliteSaver(conn)
         return self._memory_store
-
+
     def get_tools(self):
         tools = []
         tools += nb_toolkit
-        tools += fs_toolkit
 
         return tools
 
     async def get_agent(self, model_id: str, model_args, system_prompt: str):
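+        """Construct the LangChain agent with the chat model, persistent memory, tools, and middleware."""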
         model = ChatLiteLLM(**model_args, model_id=model_id, streaming=True)
         memory_store = await self.get_memory_store()
+
+        @wrap_tool_call
+        def handle_tool_errors(request, handler):
+            """Handle tool execution errors with custom messages."""
+            try:
+                return handler(request)
+            except Exception as e:
+                # Return a custom error message to the model
+                return ToolMessage(
+                    content=f"Error calling tool: ({str(e)})",
+                    tool_call_id=request.tool_call["id"],
+                )
+
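+        # Build the middleware once and cache the instances on the persona so they
+        # are reused across messages.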
+        if not hasattr(self, "search_tool"):
+            self.search_tool = FilesystemFileSearchMiddleware(
+                root_path=self.parent.root_dir
+            )
+        if not hasattr(self, "shell_tool"):
+            self.shell_tool = ShellToolMiddleware(workspace_root=self.parent.root_dir)
+        if not hasattr(self, "tool_call_handler"):
+            self.tool_call_handler = ToolMonitoringMiddleware(
+                stream_message=self.stream_message
+            )
+
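+        # Assemble the agent: the checkpointer provides per-thread memory, and the
+        # middleware adds file search, shell access, and tool-call monitoring.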
         return create_agent(
-            model,
-            system_prompt=system_prompt,
+            model,
+            system_prompt=system_prompt,
             checkpointer=memory_store,
-            tools=self.get_tools()
+            tools=self.get_tools(),
+            middleware=[self.search_tool, self.shell_tool, self.tool_call_handler],
         )
 
     async def process_message(self, message: Message) -> None:
@@ -80,19 +135,21 @@ async def process_message(self, message: Message) -> None:
         model_args = self.config_manager.chat_model_args
         system_prompt = self.get_system_prompt(model_id=model_id, message=message)
         agent = await self.get_agent(
-            model_id=model_id,
-            model_args=model_args,
-            system_prompt=system_prompt
+            model_id=model_id, model_args=model_args, system_prompt=system_prompt
         )
 
         async def create_aiter():
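+            # Only forward chunks that carry non-empty text content, skipping the literal "null" placeholder.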
-            async for chunk, metadata in agent.astream(
+            async for chunk, _ in agent.astream(
                 {"messages": [{"role": "user", "content": message.body}]},
                 {"configurable": {"thread_id": self.ychat.get_id()}},
                 stream_mode="messages",
             ):
-                if chunk.content:
-                    yield chunk.content
+                if (
+                    hasattr(chunk, "content")
+                    and (content := chunk.content)
+                    and content != "null"
+                ):
+                    yield content
 
         response_aiter = create_aiter()
         await self.stream_message(response_aiter)