1- from typing import Any , Optional
2-
3- from jupyterlab_chat .models import Message
4- from litellm import acompletion
1+ import os
2+ from typing import Any , Callable
53
4+ import aiosqlite
65from jupyter_ai_persona_manager import BasePersona , PersonaDefaults
7- from jupyter_ai_persona_manager .persona_manager import SYSTEM_USERNAME
6+ from jupyter_core .paths import jupyter_data_dir
7+ from jupyterlab_chat .models import Message
8+ from langchain .agents import create_agent
9+ from langchain .agents .middleware import AgentMiddleware
10+ from langchain .agents .middleware .file_search import FilesystemFileSearchMiddleware
11+ from langchain .agents .middleware .shell_tool import ShellToolMiddleware
12+ from langchain .messages import ToolMessage
13+ from langchain .tools .tool_node import ToolCallRequest
14+ from langchain_core .messages import ToolMessage
15+ from langgraph .checkpoint .sqlite .aio import AsyncSqliteSaver
16+ from langgraph .types import Command
17+
18+ from .chat_models import ChatLiteLLM
819from .prompt_template import (
920 JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE ,
1021 JupyternautSystemPromptArgs ,
1122)
23+ from .toolkits .notebook import toolkit as nb_toolkit
24+ from .toolkits .jupyterlab import toolkit as jlab_toolkit
25+
# SQLite file (under the Jupyter data dir) backing the agent's persistent
# conversation memory; opened lazily in JupyternautPersona.get_memory_store().
MEMORY_STORE_PATH = os .path .join (jupyter_data_dir (), "jupyter_ai" , "memory.sqlite" )
27+
28+
def format_tool_args_compact(args_dict, threshold=25, max_line_len=80):
    """
    Create a compact, human-readable string representation of tool call args.

    Each key-value pair is rendered on its own line. Long single-line values
    are truncated to ``max_line_len`` characters with an ellipsis; multi-line
    values longer than ``threshold`` lines are cut off with a
    ``[+N more lines]`` summary.

    Args:
        args_dict (dict): Dictionary of tool arguments.
        threshold (int): Maximum number of lines before truncation (default: 25).
        max_line_len (int): Maximum length of a single-line value before it is
            truncated (default: 80).

    Returns:
        str: Formatted string representation of arguments.
    """
    if not args_dict:
        return "{}"

    formatted_pairs = []

    for key, value in args_dict.items():
        value_str = str(value)
        lines = value_str.split('\n')

        if len(lines) > threshold:
            # Too many lines: keep the first `threshold` and summarize the rest.
            remaining_lines = len(lines) - threshold
            indented_value = '\n    '.join([''] + lines[:threshold])
            formatted_pairs.append(
                f"  {key}:{indented_value}\n    [+{remaining_lines} more lines]"
            )
        elif len(lines) == 1 and len(value_str) > max_line_len:
            # Single long line: truncate, reserving 3 chars for the ellipsis.
            truncated = value_str[:max_line_len - 3] + "..."
            formatted_pairs.append(f"  {key}: {truncated}")
        elif len(lines) > 1:
            # Multi-line value within the threshold: indent every line under
            # the key so the value is visually grouped.
            indented_value = '\n    '.join([''] + lines)
            formatted_pairs.append(f"  {key}:{indented_value}")
        else:
            formatted_pairs.append(f"  {key}: {value_str}")

    return "{\n" + ",\n".join(formatted_pairs) + "\n}"
70+
71+
class ToolMonitoringMiddleware(AgentMiddleware):
    """Agent middleware that logs every tool invocation and converts tool
    failures into error ``ToolMessage`` replies instead of raising."""

    def __init__(self, *, persona: BasePersona):
        # Borrow the persona's streaming and logging facilities so tool
        # activity shows up in the persona's own log.
        self.stream_message = persona.stream_message
        self.log = persona.log

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        tool_name = request.tool_call['name']
        rendered_args = format_tool_args_compact(request.tool_call['args'])
        self.log.info(f"{tool_name}({rendered_args})")

        try:
            result = await handler(request)
        except Exception as e:
            # Report the failure to the model as an error tool message rather
            # than letting the exception abort the agent loop.
            self.log.info(f"{tool_name} failed: {e}")
            return ToolMessage(
                tool_call_id=request.tool_call["id"], status="error", content=f"{e}"
            )

        self.log.info(f"{tool_name} Done!")
        return result
1294
1395
1496class JupyternautPersona (BasePersona ):
@@ -28,11 +110,45 @@ def defaults(self):
28110 system_prompt = "..." ,
29111 )
30112
113+ async def get_memory_store (self ):
114+ if not hasattr (self , "_memory_store" ):
115+ conn = await aiosqlite .connect (MEMORY_STORE_PATH , check_same_thread = False )
116+ self ._memory_store = AsyncSqliteSaver (conn )
117+ return self ._memory_store
118+
119+ def get_tools (self ):
120+ tools = nb_toolkit
121+ tools += jlab_toolkit
122+ return nb_toolkit
123+
124+ async def get_agent (self , model_id : str , model_args , system_prompt : str ):
125+ model = ChatLiteLLM (** model_args , model_id = model_id , streaming = True )
126+ memory_store = await self .get_memory_store ()
127+
128+ if not hasattr (self , "search_tool" ):
129+ self .search_tool = FilesystemFileSearchMiddleware (
130+ root_path = self .parent .root_dir
131+ )
132+ if not hasattr (self , "shell_tool" ):
133+ self .shell_tool = ShellToolMiddleware (workspace_root = self .parent .root_dir )
134+ if not hasattr (self , "tool_call_handler" ):
135+ self .tool_call_handler = ToolMonitoringMiddleware (
136+ persona = self
137+ )
138+
139+ return create_agent (
140+ model ,
141+ system_prompt = system_prompt ,
142+ checkpointer = memory_store ,
143+ tools = self .get_tools (), # notebook and jlab tools
144+ middleware = [self .shell_tool , self .tool_call_handler ],
145+ )
146+
31147 async def process_message (self , message : Message ) -> None :
32- if not hasattr (self , ' config_manager' ):
148+ if not hasattr (self , " config_manager" ):
33149 self .send_message (
34150 "Jupyternaut requires the `jupyter_ai_jupyternaut` server extension package.\n \n " ,
35- "Please make sure to first install that package in your environment & restart the server."
151+ "Please make sure to first install that package in your environment & restart the server." ,
36152 )
37153 if not self .config_manager .chat_model :
38154 self .send_message (
@@ -43,65 +159,44 @@ async def process_message(self, message: Message) -> None:
43159
44160 model_id = self .config_manager .chat_model
45161 model_args = self .config_manager .chat_model_args
46- context_as_messages = self .get_context_as_messages (model_id , message )
47- response_aiter = await acompletion (
48- ** model_args ,
49- model = model_id ,
50- messages = [
51- * context_as_messages ,
52- {
53- "role" : "user" ,
54- "content" : message .body ,
55- },
56- ],
57- stream = True ,
162+ system_prompt = self .get_system_prompt (model_id = model_id , message = message )
163+ agent = await self .get_agent (
164+ model_id = model_id , model_args = model_args , system_prompt = system_prompt
58165 )
59166
167+ async def create_aiter ():
168+ async for token , metadata in agent .astream (
169+ {"messages" : [{"role" : "user" , "content" : message .body }]},
170+ {"configurable" : {"thread_id" : self .ychat .get_id ()}},
171+ stream_mode = "messages" ,
172+ ):
173+ node = metadata ["langgraph_node" ]
174+ content_blocks = token .content_blocks
175+ if (
176+ node == "model"
177+ and content_blocks
178+ ):
179+ if token .text :
180+ yield token .text
181+
182+ response_aiter = create_aiter ()
60183 await self .stream_message (response_aiter )
61184
62- def get_context_as_messages (
185+ def get_system_prompt (
63186 self , model_id : str , message : Message
64187 ) -> list [dict [str , Any ]]:
65188 """
66- Returns the current context, including attachments and recent messages,
67- as a list of messages accepted by `litellm.acompletion()`.
189+ Returns the system prompt, including attachments as a string.
68190 """
69191 system_msg_args = JupyternautSystemPromptArgs (
70192 model_id = model_id ,
71193 persona_name = self .name ,
72194 context = self .process_attachments (message ),
73195 ).model_dump ()
74196
75- system_msg = {
76- "role" : "system" ,
77- "content" : JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE .render (** system_msg_args ),
78- }
79-
80- context_as_messages = [system_msg , * self ._get_history_as_messages ()]
81- return context_as_messages
82-
83- def _get_history_as_messages (self , k : Optional [int ] = 2 ) -> list [dict [str , Any ]]:
84- """
85- Returns the current history as a list of messages accepted by
86- `litellm.acompletion()`.
87- """
88- # TODO: consider bounding history based on message size (e.g. total
89- # char/token count) instead of message count.
90- all_messages = self .ychat .get_messages ()
91-
92- # gather last k * 2 messages and return
93- # we exclude the last message since that is the human message just
94- # submitted by a user.
95- start_idx = 0 if k is None else - 2 * k - 1
96- recent_messages : list [Message ] = all_messages [start_idx :- 1 ]
97-
98- history : list [dict [str , Any ]] = []
99- for msg in recent_messages :
100- role = (
101- "assistant"
102- if msg .sender .startswith ("jupyter-ai-personas::" )
103- else "system" if msg .sender == SYSTEM_USERNAME else "user"
104- )
105- history .append ({"role" : role , "content" : msg .body })
197+ return JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE .render (** system_msg_args )
106198
107- return history
199+ def shutdown (self ):
200+ if hasattr (self ,"_memory_store" ):
201+ self .parent .event_loop .create_task (self ._memory_store .conn .close ())
202+ super ().shutdown ()
0 commit comments