622 changes: 622 additions & 0 deletions jupyter_ai_jupyternaut/jupyternaut/chat_models.py

Large diffs are not rendered by default.

203 changes: 149 additions & 54 deletions jupyter_ai_jupyternaut/jupyternaut/jupyternaut.py
@@ -1,14 +1,96 @@
-from typing import Any, Optional
-
-from jupyterlab_chat.models import Message
-from litellm import acompletion
+import os
+from typing import Any, Callable
 
+import aiosqlite
 from jupyter_ai_persona_manager import BasePersona, PersonaDefaults
 from jupyter_ai_persona_manager.persona_manager import SYSTEM_USERNAME
+from jupyter_core.paths import jupyter_data_dir
+from jupyterlab_chat.models import Message
+from langchain.agents import create_agent
+from langchain.agents.middleware import AgentMiddleware
+from langchain.agents.middleware.file_search import FilesystemFileSearchMiddleware
+from langchain.agents.middleware.shell_tool import ShellToolMiddleware
+from langchain.messages import ToolMessage
+from langchain.tools.tool_node import ToolCallRequest
+from langchain_core.messages import ToolMessage
+from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
+from langgraph.types import Command
+
+from .chat_models import ChatLiteLLM
 from .prompt_template import (
     JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE,
     JupyternautSystemPromptArgs,
 )
+from .toolkits.notebook import toolkit as nb_toolkit
+from .toolkits.jupyterlab import toolkit as jlab_toolkit

MEMORY_STORE_PATH = os.path.join(jupyter_data_dir(), "jupyter_ai", "memory.sqlite")


def format_tool_args_compact(args_dict, threshold=25):
    """
    Create a more compact string representation of tool call args.
    Each key-value pair is on its own line for better readability.

    Args:
        args_dict (dict): Dictionary of tool arguments
        threshold (int): Maximum number of lines before truncation (default: 25)

    Returns:
        str: Formatted string representation of arguments
    """
    if not args_dict:
        return "{}"

    formatted_pairs = []

    for key, value in args_dict.items():
        value_str = str(value)
        lines = value_str.split('\n')

        if len(lines) <= threshold:
            if len(lines) == 1 and len(value_str) > 80:
                # Single long line - truncate
                truncated = value_str[:77] + "..."
                formatted_pairs.append(f"  {key}: {truncated}")
            else:
                # Add indentation for multi-line values
                if len(lines) > 1:
                    indented_value = '\n    '.join([''] + lines)
                    formatted_pairs.append(f"  {key}:{indented_value}")
                else:
                    formatted_pairs.append(f"  {key}: {value_str}")
        else:
            # Truncate and add summary
            truncated_lines = lines[:threshold]
            remaining_lines = len(lines) - threshold
            indented_value = '\n    '.join([''] + truncated_lines)
            formatted_pairs.append(f"  {key}:{indented_value}\n  [+{remaining_lines} more lines]")

    return "{\n" + ",\n".join(formatted_pairs) + "\n}"


class ToolMonitoringMiddleware(AgentMiddleware):
    def __init__(self, *, persona: BasePersona):
        self.stream_message = persona.stream_message
        self.log = persona.log

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        args = format_tool_args_compact(request.tool_call['args'])
        self.log.info(f"{request.tool_call['name']}({args})")

        try:
            result = await handler(request)
            self.log.info(f"{request.tool_call['name']} Done!")
            return result
        except Exception as e:
            self.log.info(f"{request.tool_call['name']} failed: {e}")
            return ToolMessage(
                tool_call_id=request.tool_call["id"], status="error", content=f"{e}"
            )
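For illustration (hypothetical tool call; editor's sketch, not code from the PR), the middleware's logging behaves as follows:

# Given a request whose tool_call is:
#   {"id": "call_1", "name": "add_cell", "args": {"content": "print('hi')"}}
# the middleware logs on entry:
#   add_cell({
#     content: print('hi')
#   })
# logs "add_cell Done!" on success, and on an exception returns an
# error-status ToolMessage to the model instead of raising.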


class JupyternautPersona(BasePersona):
@@ -28,11 +110,45 @@ def defaults(self)
            system_prompt="...",
        )

    async def get_memory_store(self):
        if not hasattr(self, "_memory_store"):
            conn = await aiosqlite.connect(MEMORY_STORE_PATH, check_same_thread=False)
            self._memory_store = AsyncSqliteSaver(conn)
        return self._memory_store

    def get_tools(self):
        # Return a combined copy of the notebook and JupyterLab toolkits.
        # (Extending the imported list in place via `tools += jlab_toolkit`
        # would mutate the module-level toolkit on every call.)
        return [*nb_toolkit, *jlab_toolkit]

    async def get_agent(self, model_id: str, model_args, system_prompt: str):
        model = ChatLiteLLM(**model_args, model_id=model_id, streaming=True)
Contributor: I think that the correct parameter should be model=model_id (model instead of model_id), according to the ChatLiteLLM attribute.

When testing this PR, the backend complains about a missing OpenAI API key. Trying to debug it, it seems that the model set up in ChatLiteLLM is always the default one, gpt-3.5-turbo.

Contributor: I opened #19 to fix it.
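A minimal sketch of the reviewer's suggested correction (assuming ChatLiteLLM's field is named model, as the comment states; editor's illustration, not code from the PR):

# Reviewer-suggested fix (tracked in #19): pass the model name as `model`
# so ChatLiteLLM doesn't silently fall back to its default (gpt-3.5-turbo).
model = ChatLiteLLM(**model_args, model=model_id, streaming=True)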

        memory_store = await self.get_memory_store()

        if not hasattr(self, "search_tool"):
            self.search_tool = FilesystemFileSearchMiddleware(
                root_path=self.parent.root_dir
            )
        if not hasattr(self, "shell_tool"):
            self.shell_tool = ShellToolMiddleware(workspace_root=self.parent.root_dir)
        if not hasattr(self, "tool_call_handler"):
            self.tool_call_handler = ToolMonitoringMiddleware(persona=self)

        # NOTE: self.search_tool is constructed above but is not included in
        # this middleware list.
        return create_agent(
            model,
            system_prompt=system_prompt,
            checkpointer=memory_store,
            tools=self.get_tools(),  # notebook and jlab tools
            middleware=[self.shell_tool, self.tool_call_handler],
        )

     async def process_message(self, message: Message) -> None:
-        if not hasattr(self, 'config_manager'):
+        if not hasattr(self, "config_manager"):
             self.send_message(
                 "Jupyternaut requires the `jupyter_ai_jupyternaut` server extension package.\n\n",
-                "Please make sure to first install that package in your environment & restart the server."
+                "Please make sure to first install that package in your environment & restart the server.",
             )
         if not self.config_manager.chat_model:
             self.send_message(
@@ -43,65 +159,44 @@ async def process_message(self, message: Message) -> None:

        model_id = self.config_manager.chat_model
        model_args = self.config_manager.chat_model_args
-        context_as_messages = self.get_context_as_messages(model_id, message)
-        response_aiter = await acompletion(
-            **model_args,
-            model=model_id,
-            messages=[
-                *context_as_messages,
-                {
-                    "role": "user",
-                    "content": message.body,
-                },
-            ],
-            stream=True,
-        )
+        system_prompt = self.get_system_prompt(model_id=model_id, message=message)
+        agent = await self.get_agent(
+            model_id=model_id, model_args=model_args, system_prompt=system_prompt
+        )

        async def create_aiter():
            async for token, metadata in agent.astream(
                {"messages": [{"role": "user", "content": message.body}]},
                {"configurable": {"thread_id": self.ychat.get_id()}},
                stream_mode="messages",
Comment on lines +169 to +171

Contributor: (non-blocking) Since we're only adding to the SQLite checkpointer when this persona is called, does this mean that Jupyternaut will lack context on messages not routed to Jupyternaut?

For example, consider the following chat:

    User: Hello, what is the Riemann hypothesis?
    <SomePersona>: <complete nonsense>
    User: @Jupyternaut can you try to answer this?
    # does Jupyternaut have context on the 2 preceding messages?

This is fine for now, just checking to see if I understand the current behavior.

Contributor Author (@3coins, Nov 1, 2025): Correct, we need a shared memory manager (or a store) in the persona manager or base persona that enables personas to write messages for shared context, along with an API to load the shared context.
            ):
                node = metadata["langgraph_node"]
                content_blocks = token.content_blocks
                if node == "model" and content_blocks:
                    if token.text:
                        yield token.text

        response_aiter = create_aiter()
        await self.stream_message(response_aiter)

-    def get_context_as_messages(
+    def get_system_prompt(
         self, model_id: str, message: Message
     ) -> list[dict[str, Any]]:
         """
-        Returns the current context, including attachments and recent messages,
-        as a list of messages accepted by `litellm.acompletion()`.
+        Returns the system prompt, including attachments, as a string.
         """
         system_msg_args = JupyternautSystemPromptArgs(
             model_id=model_id,
             persona_name=self.name,
             context=self.process_attachments(message),
         ).model_dump()
 
-        system_msg = {
-            "role": "system",
-            "content": JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args),
-        }
-
-        context_as_messages = [system_msg, *self._get_history_as_messages()]
-        return context_as_messages
-
-    def _get_history_as_messages(self, k: Optional[int] = 2) -> list[dict[str, Any]]:
-        """
-        Returns the current history as a list of messages accepted by
-        `litellm.acompletion()`.
-        """
-        # TODO: consider bounding history based on message size (e.g. total
-        # char/token count) instead of message count.
-        all_messages = self.ychat.get_messages()
-
-        # gather last k * 2 messages and return
-        # we exclude the last message since that is the human message just
-        # submitted by a user.
-        start_idx = 0 if k is None else -2 * k - 1
-        recent_messages: list[Message] = all_messages[start_idx:-1]
-
-        history: list[dict[str, Any]] = []
-        for msg in recent_messages:
-            role = (
-                "assistant"
-                if msg.sender.startswith("jupyter-ai-personas::")
-                else "system" if msg.sender == SYSTEM_USERNAME else "user"
-            )
-            history.append({"role": role, "content": msg.body})
-        return history
+        return JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args)
+
+    def shutdown(self):
+        if hasattr(self, "_memory_store"):
+            self.parent.event_loop.create_task(self._memory_store.conn.close())
+        super().shutdown()
8 changes: 8 additions & 0 deletions jupyter_ai_jupyternaut/jupyternaut/prompt_template.py
@@ -30,6 +30,14 @@

- Example of a correct response: `You have \\(\\$80\\) remaining.`

If the user's request involves writing to a file, don't use fenced code blocks; write the content directly.

If the request requires using the add_cell or edit_cell tools to add code to a notebook code cell, don't use fenced code blocks.

If the request requires adding markdown to a notebook markdown cell, don't use a markdown code block.

Don't echo file contents back to the user after reading files; instead, use that information to fulfill the user's request.

You will receive any provided context and a relevant portion of the chat history.

The user's request is in the last message. Please fulfill it to the best of your ability.
41 changes: 41 additions & 0 deletions jupyter_ai_jupyternaut/jupyternaut/toolkits/code_execution.py
@@ -0,0 +1,41 @@
"""Tools that provide code execution features"""

import asyncio
import shlex
from typing import Optional


async def bash(command: str, timeout: Optional[int] = None) -> str:
"""Executes a bash command and returns the result

Args:
command: The bash command to execute
timeout: Optional timeout in seconds

Returns:
The command output (stdout and stderr combined)
"""

proc = await asyncio.create_subprocess_exec(
*shlex.split(command),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)

try:
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout)
output = stdout.decode("utf-8")
error = stderr.decode("utf-8")

if proc.returncode != 0:
if error:
return f"Error: {error}"
return f"Command failed with exit code {proc.returncode}"

return output if output else "Command executed successfully with no output."
except asyncio.TimeoutError:
proc.kill()
return f"Command timed out after {timeout} seconds"


toolkit = [bash]
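As a quick usage illustration (editor's sketch, not part of the PR) — note that despite the tool's name, create_subprocess_exec does not invoke a shell, so pipes, globs, and && are not interpreted:

import asyncio

from jupyter_ai_jupyternaut.jupyternaut.toolkits.code_execution import bash

async def main():
    print(await bash("echo hello", timeout=10))  # prints "hello"
    # bash("ls | wc -l") would pass "|" and "wc" as literal arguments to ls
    # rather than building a pipeline, since no shell is involved.

asyncio.run(main())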