Skip to content

Commit 4923218

Browse files
authored
Added a react agent with persistent memory (#17)
* Added langchain, ChatLiteLLM * Added memory store, react agent * Added tools * Added version floor for aiosqlite * Added filesearch, shell tools, print tool messages * Fixed notebook tools, arbitrary outputs to chat. * Updated dependencies * Fixed shutdown
1 parent d39a35e commit 4923218

File tree

10 files changed

+2758
-55
lines changed

10 files changed

+2758
-55
lines changed

jupyter_ai_jupyternaut/jupyternaut/chat_models.py

Lines changed: 622 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 149 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,96 @@
1-
from typing import Any, Optional
2-
3-
from jupyterlab_chat.models import Message
4-
from litellm import acompletion
1+
import os
2+
from typing import Any, Callable
53

4+
import aiosqlite
65
from jupyter_ai_persona_manager import BasePersona, PersonaDefaults
7-
from jupyter_ai_persona_manager.persona_manager import SYSTEM_USERNAME
6+
from jupyter_core.paths import jupyter_data_dir
7+
from jupyterlab_chat.models import Message
8+
from langchain.agents import create_agent
9+
from langchain.agents.middleware import AgentMiddleware
10+
from langchain.agents.middleware.file_search import FilesystemFileSearchMiddleware
11+
from langchain.agents.middleware.shell_tool import ShellToolMiddleware
12+
from langchain.messages import ToolMessage
13+
from langchain.tools.tool_node import ToolCallRequest
14+
from langchain_core.messages import ToolMessage
15+
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
16+
from langgraph.types import Command
17+
18+
from .chat_models import ChatLiteLLM
819
from .prompt_template import (
920
JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE,
1021
JupyternautSystemPromptArgs,
1122
)
23+
from .toolkits.notebook import toolkit as nb_toolkit
24+
from .toolkits.jupyterlab import toolkit as jlab_toolkit
25+
26+
MEMORY_STORE_PATH = os.path.join(jupyter_data_dir(), "jupyter_ai", "memory.sqlite")
27+
28+
29+
def format_tool_args_compact(args_dict: dict, threshold: int = 25) -> str:
    """
    Create a compact, human-readable string representation of tool call args.

    Each key-value pair is rendered on its own line; multi-line values are
    indented under their key, and overly long values are truncated.

    Args:
        args_dict: Dictionary of tool arguments.
        threshold: Maximum number of value lines shown before truncation.

    Returns:
        Formatted string representation of the arguments.
    """
    if not args_dict:
        return "{}"

    formatted_pairs = []

    for key, value in args_dict.items():
        value_str = str(value)
        lines = value_str.split("\n")

        if len(lines) > threshold:
            # Too many lines: show the first `threshold` and summarize the rest.
            shown = "\n    ".join([""] + lines[:threshold])
            remaining = len(lines) - threshold
            formatted_pairs.append(f"  {key}:{shown}\n    [+{remaining} more lines]")
        elif len(lines) == 1 and len(value_str) > 80:
            # Single long line: truncate so one value can't flood the log.
            formatted_pairs.append(f"  {key}: {value_str[:77]}...")
        elif len(lines) > 1:
            # Multi-line value: indent each line under its key.
            indented = "\n    ".join([""] + lines)
            formatted_pairs.append(f"  {key}:{indented}")
        else:
            formatted_pairs.append(f"  {key}: {value_str}")

    return "{\n" + ",\n".join(formatted_pairs) + "\n}"
70+
71+
72+
class ToolMonitoringMiddleware(AgentMiddleware):
    """Agent middleware that logs every tool call and converts tool
    exceptions into error ``ToolMessage``s instead of aborting the run."""

    def __init__(self, *, persona: BasePersona):
        # Initialize the AgentMiddleware base so any state it sets up is
        # present; the original code skipped this call.
        super().__init__()
        self.stream_message = persona.stream_message
        self.log = persona.log

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """Log the tool call, execute it, and report success or failure.

        On failure the exception is logged and returned to the agent as an
        error ``ToolMessage`` so the model can observe and react to it.
        """
        args = format_tool_args_compact(request.tool_call['args'])
        self.log.info(f"{request.tool_call['name']}({args})")

        try:
            result = await handler(request)
        except Exception as e:
            # Deliberately broad: any tool failure is surfaced to the model
            # as an error message rather than crashing the agent loop.
            self.log.info(f"{request.tool_call['name']} failed: {e}")
            return ToolMessage(
                tool_call_id=request.tool_call["id"], status="error", content=f"{e}"
            )
        self.log.info(f"{request.tool_call['name']} Done!")
        return result
1294

1395

1496
class JupyternautPersona(BasePersona):
@@ -28,11 +110,45 @@ def defaults(self):
28110
system_prompt="...",
29111
)
30112

113+
async def get_memory_store(self):
114+
if not hasattr(self, "_memory_store"):
115+
conn = await aiosqlite.connect(MEMORY_STORE_PATH, check_same_thread=False)
116+
self._memory_store = AsyncSqliteSaver(conn)
117+
return self._memory_store
118+
119+
def get_tools(self):
120+
tools = nb_toolkit
121+
tools += jlab_toolkit
122+
return nb_toolkit
123+
124+
async def get_agent(self, model_id: str, model_args, system_prompt: str):
125+
model = ChatLiteLLM(**model_args, model_id=model_id, streaming=True)
126+
memory_store = await self.get_memory_store()
127+
128+
if not hasattr(self, "search_tool"):
129+
self.search_tool = FilesystemFileSearchMiddleware(
130+
root_path=self.parent.root_dir
131+
)
132+
if not hasattr(self, "shell_tool"):
133+
self.shell_tool = ShellToolMiddleware(workspace_root=self.parent.root_dir)
134+
if not hasattr(self, "tool_call_handler"):
135+
self.tool_call_handler = ToolMonitoringMiddleware(
136+
persona=self
137+
)
138+
139+
return create_agent(
140+
model,
141+
system_prompt=system_prompt,
142+
checkpointer=memory_store,
143+
tools=self.get_tools(), # notebook and jlab tools
144+
middleware=[self.shell_tool, self.tool_call_handler],
145+
)
146+
31147
async def process_message(self, message: Message) -> None:
32-
if not hasattr(self, 'config_manager'):
148+
if not hasattr(self, "config_manager"):
33149
self.send_message(
34150
"Jupyternaut requires the `jupyter_ai_jupyternaut` server extension package.\n\n",
35-
"Please make sure to first install that package in your environment & restart the server."
151+
"Please make sure to first install that package in your environment & restart the server.",
36152
)
37153
if not self.config_manager.chat_model:
38154
self.send_message(
@@ -43,65 +159,44 @@ async def process_message(self, message: Message) -> None:
43159

44160
model_id = self.config_manager.chat_model
45161
model_args = self.config_manager.chat_model_args
46-
context_as_messages = self.get_context_as_messages(model_id, message)
47-
response_aiter = await acompletion(
48-
**model_args,
49-
model=model_id,
50-
messages=[
51-
*context_as_messages,
52-
{
53-
"role": "user",
54-
"content": message.body,
55-
},
56-
],
57-
stream=True,
162+
system_prompt = self.get_system_prompt(model_id=model_id, message=message)
163+
agent = await self.get_agent(
164+
model_id=model_id, model_args=model_args, system_prompt=system_prompt
58165
)
59166

167+
async def create_aiter():
168+
async for token, metadata in agent.astream(
169+
{"messages": [{"role": "user", "content": message.body}]},
170+
{"configurable": {"thread_id": self.ychat.get_id()}},
171+
stream_mode="messages",
172+
):
173+
node = metadata["langgraph_node"]
174+
content_blocks = token.content_blocks
175+
if (
176+
node == "model"
177+
and content_blocks
178+
):
179+
if token.text:
180+
yield token.text
181+
182+
response_aiter = create_aiter()
60183
await self.stream_message(response_aiter)
61184

62-
def get_context_as_messages(
185+
def get_system_prompt(
63186
self, model_id: str, message: Message
64187
) -> list[dict[str, Any]]:
65188
"""
66-
Returns the current context, including attachments and recent messages,
67-
as a list of messages accepted by `litellm.acompletion()`.
189+
Returns the system prompt, including attachments as a string.
68190
"""
69191
system_msg_args = JupyternautSystemPromptArgs(
70192
model_id=model_id,
71193
persona_name=self.name,
72194
context=self.process_attachments(message),
73195
).model_dump()
74196

75-
system_msg = {
76-
"role": "system",
77-
"content": JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args),
78-
}
79-
80-
context_as_messages = [system_msg, *self._get_history_as_messages()]
81-
return context_as_messages
82-
83-
def _get_history_as_messages(self, k: Optional[int] = 2) -> list[dict[str, Any]]:
84-
"""
85-
Returns the current history as a list of messages accepted by
86-
`litellm.acompletion()`.
87-
"""
88-
# TODO: consider bounding history based on message size (e.g. total
89-
# char/token count) instead of message count.
90-
all_messages = self.ychat.get_messages()
91-
92-
# gather last k * 2 messages and return
93-
# we exclude the last message since that is the human message just
94-
# submitted by a user.
95-
start_idx = 0 if k is None else -2 * k - 1
96-
recent_messages: list[Message] = all_messages[start_idx:-1]
97-
98-
history: list[dict[str, Any]] = []
99-
for msg in recent_messages:
100-
role = (
101-
"assistant"
102-
if msg.sender.startswith("jupyter-ai-personas::")
103-
else "system" if msg.sender == SYSTEM_USERNAME else "user"
104-
)
105-
history.append({"role": role, "content": msg.body})
197+
return JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args)
106198

107-
return history
199+
def shutdown(self):
200+
if hasattr(self,"_memory_store"):
201+
self.parent.event_loop.create_task(self._memory_store.conn.close())
202+
super().shutdown()

jupyter_ai_jupyternaut/jupyternaut/prompt_template.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@
3030
3131
- Example of a correct response: `You have \\(\\$80\\) remaining.`
3232
33+
If the user's request involves writing to a file, don't use fenced code blocks; write the content directly.
34+
35+
If the request requires using the add_cell or edit_cell tools to add code to a notebook code cell, don't use fenced code blocks.
36+
37+
If the request requires adding markdown to a notebook markdown cell, don't use a markdown code block.
38+
39+
Don't echo file contents back to the user after reading files. Instead, use that information to fulfill the user's request.
40+
3341
You will receive any provided context and a relevant portion of the chat history.
3442
3543
The user's request is located at the last message. Please fulfill the user's request to the best of your ability.

jupyter_ai_jupyternaut/jupyternaut/toolkits/__init__.py

Whitespace-only changes.
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Tools that provide code execution features"""
2+
3+
import asyncio
4+
import shlex
5+
from typing import Optional
6+
7+
8+
async def bash(command: str, timeout: Optional[int] = None) -> str:
    """Execute a command and return its output.

    The command string is tokenized with ``shlex.split`` and executed
    directly (no shell), so shell features such as pipes, globbing, and
    redirection are NOT available.

    Args:
        command: The command line to execute
        timeout: Optional timeout in seconds

    Returns:
        The command's stdout on success, or an error description string
        (stderr / exit code / timeout) on failure.
    """
    if not command.strip():
        # shlex.split("") yields no program at all; fail gracefully.
        return "Error: empty command"

    proc = await asyncio.create_subprocess_exec(
        *shlex.split(command),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    try:
        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout)
    except asyncio.TimeoutError:
        # Kill the child AND reap it, otherwise it lingers as a zombie.
        proc.kill()
        await proc.wait()
        return f"Command timed out after {timeout} seconds"

    # errors="replace" so undecodable output never raises UnicodeDecodeError.
    output = stdout.decode("utf-8", errors="replace")
    error = stderr.decode("utf-8", errors="replace")

    if proc.returncode != 0:
        if error:
            return f"Error: {error}"
        return f"Command failed with exit code {proc.returncode}"

    return output if output else "Command executed successfully with no output."


toolkit = [bash]

0 commit comments

Comments
 (0)