Skip to content

Commit 9189ea5

Browse files
committed
Added filesearch, shell tools, print tool messages
1 parent 10e1146 commit 9189ea5

File tree

3 files changed

+83
-23
lines changed

3 files changed

+83
-23
lines changed

jupyter_ai_jupyternaut/jupyternaut/jupyternaut.py

Lines changed: 76 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,59 @@
11
import os
2-
import aiosqlite
3-
from typing import Any, Optional
2+
from typing import Any, Callable
43

4+
import aiosqlite
55
from jupyter_ai_persona_manager import BasePersona, PersonaDefaults
6-
from jupyter_ai_persona_manager.persona_manager import SYSTEM_USERNAME
76
from jupyter_core.paths import jupyter_data_dir
87
from jupyterlab_chat.models import Message
98
from langchain.agents import create_agent
9+
from langchain.agents.middleware import AgentMiddleware, wrap_tool_call
10+
from langchain.agents.middleware.file_search import FilesystemFileSearchMiddleware
11+
from langchain.agents.middleware.shell_tool import ShellToolMiddleware
12+
from langchain.messages import ToolMessage
13+
from langchain.tools.tool_node import ToolCallRequest
14+
from langchain_core.messages import ToolMessage
1015
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
16+
from langgraph.types import Command
1117

1218
from .chat_models import ChatLiteLLM
1319
from .prompt_template import (
1420
JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE,
1521
JupyternautSystemPromptArgs,
1622
)
17-
18-
from .toolkits.code_execution import toolkit as exec_toolkit
19-
from .toolkits.filesystem import toolkit as fs_toolkit
2023
from .toolkits.notebook import toolkit as nb_toolkit
2124

22-
2325
MEMORY_STORE_PATH = os.path.join(jupyter_data_dir(), "jupyter_ai", "memory.sqlite")
2426

2527

28+
class ToolMonitoringMiddleware(AgentMiddleware):
29+
def __init__(self, *, stream_message: BasePersona.stream_message):
30+
self.stream_message = stream_message
31+
32+
async def awrap_tool_call(
33+
self,
34+
request: ToolCallRequest,
35+
handler: Callable[[ToolCallRequest], ToolMessage | Command],
36+
) -> ToolMessage | Command:
37+
running_tool_msg = f"Running **{request.tool_call['name']}** with *{request.tool_call['args']}*"
38+
await self.stream_message(self._aiter(running_tool_msg))
39+
try:
40+
result = await handler(request)
41+
if hasattr(result, "content") and result.content != "null":
42+
completed_tool_msg = str(result.content)[:100]
43+
else:
44+
completed_tool_msg = "Done!"
45+
await self.stream_message(self._aiter(completed_tool_msg))
46+
return result
47+
except Exception as e:
48+
await self.stream_message(f"**{request.tool_call['name']}** failed: {e}")
49+
return ToolMessage(
50+
tool_call_id=request.tool_call["id"], status="error", content=f"{e}"
51+
)
52+
53+
async def _aiter(self, message: str):
54+
yield message
55+
56+
2657
class JupyternautPersona(BasePersona):
2758
"""
2859
The Jupyternaut persona, the main persona provided by Jupyter AI.
@@ -43,24 +74,48 @@ def defaults(self):
4374
async def get_memory_store(self):
4475
if not hasattr(self, "_memory_store"):
4576
conn = await aiosqlite.connect(MEMORY_STORE_PATH, check_same_thread=False)
46-
self._memory_store = AsyncSqliteSaver(conn)
77+
self._memory_store = AsyncSqliteSaver(conn)
4778
return self._memory_store
48-
79+
4980
def get_tools(self):
5081
tools = []
5182
tools += nb_toolkit
52-
tools += fs_toolkit
5383

5484
return tools
5585

5686
async def get_agent(self, model_id: str, model_args, system_prompt: str):
5787
model = ChatLiteLLM(**model_args, model_id=model_id, streaming=True)
5888
memory_store = await self.get_memory_store()
89+
90+
@wrap_tool_call
91+
def handle_tool_errors(request, handler):
92+
"""Handle tool execution errors with custom messages."""
93+
try:
94+
return handler(request)
95+
except Exception as e:
96+
# Return a custom error message to the model
97+
return ToolMessage(
98+
content=f"Error calling tool: ({str(e)})",
99+
tool_call_id=request.tool_call["id"],
100+
)
101+
102+
if not hasattr(self, "search_tool"):
103+
self.search_tool = FilesystemFileSearchMiddleware(
104+
root_path=self.parent.root_dir
105+
)
106+
if not hasattr(self, "shell_tool"):
107+
self.shell_tool = ShellToolMiddleware(workspace_root=self.parent.root_dir)
108+
if not hasattr(self, "tool_call_handler"):
109+
self.tool_call_handler = ToolMonitoringMiddleware(
110+
stream_message=self.stream_message
111+
)
112+
59113
return create_agent(
60-
model,
61-
system_prompt=system_prompt,
114+
model,
115+
system_prompt=system_prompt,
62116
checkpointer=memory_store,
63-
tools=self.get_tools()
117+
tools=self.get_tools(),
118+
middleware=[self.search_tool, self.shell_tool, self.tool_call_handler],
64119
)
65120

66121
async def process_message(self, message: Message) -> None:
@@ -80,19 +135,21 @@ async def process_message(self, message: Message) -> None:
80135
model_args = self.config_manager.chat_model_args
81136
system_prompt = self.get_system_prompt(model_id=model_id, message=message)
82137
agent = await self.get_agent(
83-
model_id=model_id,
84-
model_args=model_args,
85-
system_prompt=system_prompt
138+
model_id=model_id, model_args=model_args, system_prompt=system_prompt
86139
)
87140

88141
async def create_aiter():
89-
async for chunk, metadata in agent.astream(
142+
async for chunk, _ in agent.astream(
90143
{"messages": [{"role": "user", "content": message.body}]},
91144
{"configurable": {"thread_id": self.ychat.get_id()}},
92145
stream_mode="messages",
93146
):
94-
if chunk.content:
95-
yield chunk.content
147+
if (
148+
hasattr(chunk, "content")
149+
and (content := chunk.content)
150+
and content != "null"
151+
):
152+
yield content
96153

97154
response_aiter = create_aiter()
98155
await self.stream_message(response_aiter)

jupyter_ai_jupyternaut/jupyternaut/prompt_template.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030
3131
- Example of a correct response: `You have \\(\\$80\\) remaining.`
3232
33+
When analyzing files, notebooks, or any file contents:
34+
- Do NOT echo back or repeat the full contents of files in your response
35+
- Instead, provide analysis, summaries, or specific insights about the code/content
36+
- Only quote small, relevant excerpts when necessary to illustrate a point
37+
- Focus on answering the user's specific questions rather than displaying file contents
38+
3339
You will receive any provided context and a relevant portion of the chat history.
3440
3541
The user's request is located at the last message. Please fulfill the user's request to the best of your ability.

jupyter_ai_jupyternaut/jupyternaut/toolkits/notebook.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -803,10 +803,7 @@ def _safe_set_cursor(
803803
async def edit_cell(file_path: str, cell_id: str, content: str) -> None:
804804
"""Edits the content of a notebook cell with the specified ID
805805
806-
This function modifies the content of a cell in a Jupyter notebook. It first attempts to use
807-
the in-memory YDoc representation if the notebook is currently active. If the
808-
notebook is not active, it falls back to using the filesystem to read, modify,
809-
and write the notebook file directly using nbformat.
806+
This function modifies the content of a cell in a Jupyter notebook.
810807
811808
Args:
812809
file_path:

0 commit comments

Comments
 (0)