# Added a react agent with persistent memory (#17)
The first file in the diff is the Jupyternaut persona module. The hunk at `@@ -1,14 +1,96 @@` replaces the previous direct `litellm.acompletion()` setup: the `Optional` and `acompletion` imports are removed, and the LangChain/LangGraph imports, `aiosqlite`, and a memory-store path are added.

```python
import os
from typing import Any, Callable

import aiosqlite
from jupyter_ai_persona_manager import BasePersona, PersonaDefaults
from jupyter_ai_persona_manager.persona_manager import SYSTEM_USERNAME
from jupyter_core.paths import jupyter_data_dir
from jupyterlab_chat.models import Message
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.file_search import FilesystemFileSearchMiddleware
from langchain.agents.middleware.shell_tool import ShellToolMiddleware
from langchain.messages import ToolMessage
from langchain.tools.tool_node import ToolCallRequest
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langgraph.types import Command

from .chat_models import ChatLiteLLM
from .prompt_template import (
    JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE,
    JupyternautSystemPromptArgs,
)
from .toolkits.notebook import toolkit as nb_toolkit
from .toolkits.jupyterlab import toolkit as jlab_toolkit

# All chat threads persist agent state to a single SQLite database under the
# Jupyter data directory.
MEMORY_STORE_PATH = os.path.join(jupyter_data_dir(), "jupyter_ai", "memory.sqlite")
```
A helper compacts tool-call arguments for logging:

```python
def format_tool_args_compact(args_dict, threshold=25):
    """
    Create a more compact string representation of tool call args.
    Each key-value pair is on its own line for better readability.

    Args:
        args_dict (dict): Dictionary of tool arguments
        threshold (int): Maximum number of lines before truncation (default: 25)

    Returns:
        str: Formatted string representation of arguments
    """
    if not args_dict:
        return "{}"

    formatted_pairs = []

    for key, value in args_dict.items():
        value_str = str(value)
        lines = value_str.split('\n')

        if len(lines) <= threshold:
            if len(lines) == 1 and len(value_str) > 80:
                # Single long line - truncate
                truncated = value_str[:77] + "..."
                formatted_pairs.append(f"  {key}: {truncated}")
            else:
                # Add indentation for multi-line values
                if len(lines) > 1:
                    indented_value = '\n    '.join([''] + lines)
                    formatted_pairs.append(f"  {key}:{indented_value}")
                else:
                    formatted_pairs.append(f"  {key}: {value_str}")
        else:
            # Truncate and add summary
            truncated_lines = lines[:threshold]
            remaining_lines = len(lines) - threshold
            indented_value = '\n    '.join([''] + truncated_lines)
            formatted_pairs.append(
                f"  {key}:{indented_value}\n    [+{remaining_lines} more lines]"
            )

    return "{\n" + ",\n".join(formatted_pairs) + "\n}"
```
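A quick sanity check of the formatter (illustrative values; assumes the two-space key indentation shown above, since the diff rendering collapsed whitespace):

```python
args = {
    "path": "notebooks/analysis.ipynb",
    "code": "x" * 100,  # single 100-char line, past the 80-char limit
}
print(format_tool_args_compact(args))
# {
#   path: notebooks/analysis.ipynb,
#   code: xxxxxxxx... (first 77 chars, then "...")
# }
```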
`ToolMonitoringMiddleware` logs each tool call before it runs and converts failures into error `ToolMessage`s instead of raising:

```python
class ToolMonitoringMiddleware(AgentMiddleware):
    def __init__(self, *, persona: BasePersona):
        self.stream_message = persona.stream_message
        self.log = persona.log

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        args = format_tool_args_compact(request.tool_call['args'])
        self.log.info(f"{request.tool_call['name']}({args})")

        try:
            result = await handler(request)
            self.log.info(f"{request.tool_call['name']} Done!")
            return result
        except Exception as e:
            self.log.info(f"{request.tool_call['name']} failed: {e}")
            return ToolMessage(
                tool_call_id=request.tool_call["id"], status="error", content=f"{e}"
            )
```
In `JupyternautPersona`, the hunk at `@@ -28,11 +110,45 @@` adds the memory store, the combined tool list, and the agent factory:

```python
class JupyternautPersona(BasePersona):
    # ... (unchanged lines elided by the diff) ...
            system_prompt="...",
        )

    async def get_memory_store(self):
        if not hasattr(self, "_memory_store"):
            conn = await aiosqlite.connect(MEMORY_STORE_PATH, check_same_thread=False)
            self._memory_store = AsyncSqliteSaver(conn)
        return self._memory_store

    def get_tools(self):
        # Build a new list so the module-level toolkits are not mutated.
        return [*nb_toolkit, *jlab_toolkit]

    async def get_agent(self, model_id: str, model_args, system_prompt: str):
        model = ChatLiteLLM(**model_args, model_id=model_id, streaming=True)
        memory_store = await self.get_memory_store()

        if not hasattr(self, "search_tool"):
            self.search_tool = FilesystemFileSearchMiddleware(
                root_path=self.parent.root_dir
            )
        if not hasattr(self, "shell_tool"):
            self.shell_tool = ShellToolMiddleware(workspace_root=self.parent.root_dir)
        if not hasattr(self, "tool_call_handler"):
            self.tool_call_handler = ToolMonitoringMiddleware(persona=self)

        return create_agent(
            model,
            system_prompt=system_prompt,
            checkpointer=memory_store,
            tools=self.get_tools(),  # notebook and jlab tools
            middleware=[self.shell_tool, self.tool_call_handler],
        )

    async def process_message(self, message: Message) -> None:
        if not hasattr(self, "config_manager"):
            self.send_message(
                "Jupyternaut requires the `jupyter_ai_jupyternaut` server extension package.\n\n",
                "Please make sure to first install that package in your environment & restart the server.",
            )
        if not self.config_manager.chat_model:
            self.send_message(
                # ... (error message elided by the diff) ...
            )
```
@@ -43,65 +159,44 @@ async def process_message(self, message: Message) -> None: | |
|
|
||
| model_id = self.config_manager.chat_model | ||
| model_args = self.config_manager.chat_model_args | ||
| context_as_messages = self.get_context_as_messages(model_id, message) | ||
| response_aiter = await acompletion( | ||
| **model_args, | ||
| model=model_id, | ||
| messages=[ | ||
| *context_as_messages, | ||
| { | ||
| "role": "user", | ||
| "content": message.body, | ||
| }, | ||
| ], | ||
| stream=True, | ||
| system_prompt = self.get_system_prompt(model_id=model_id, message=message) | ||
| agent = await self.get_agent( | ||
| model_id=model_id, model_args=model_args, system_prompt=system_prompt | ||
| ) | ||
|
|
||
| async def create_aiter(): | ||
| async for token, metadata in agent.astream( | ||
| {"messages": [{"role": "user", "content": message.body}]}, | ||
| {"configurable": {"thread_id": self.ychat.get_id()}}, | ||
| stream_mode="messages", | ||
|
Comment on lines
+169
to
+171
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (non-blocking) Since we're only adding to the SQLite checkpointer when this persona is called, does this mean that Jupyternaut will lack context on messages not routed to Jupyternaut? For example, consider the following chat: This is fine for now, just checking to see if I understand the current behavior.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct, we need a shared memory manager (or a store) in persona manager or base persona that enables personas to write messages for shared context along with an API to load the shared context. |
||
| ): | ||
| node = metadata["langgraph_node"] | ||
| content_blocks = token.content_blocks | ||
| if ( | ||
| node == "model" | ||
| and content_blocks | ||
| ): | ||
| if token.text: | ||
| yield token.text | ||
|
|
||
| response_aiter = create_aiter() | ||
| await self.stream_message(response_aiter) | ||
|
|
||
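To make the scoping the comment describes concrete, here is a minimal standalone sketch of how the `AsyncSqliteSaver` checkpointer behaves (the model string, database path, and messages are illustrative assumptions, not part of this PR):

```python
import asyncio

import aiosqlite
from langchain.agents import create_agent
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver


async def main():
    conn = await aiosqlite.connect("memory.sqlite")  # hypothetical path
    agent = create_agent(
        "openai:gpt-4o",  # any configured chat model works here
        tools=[],
        checkpointer=AsyncSqliteSaver(conn),
    )

    # Each thread_id keys an independent persisted conversation, mirroring
    # `thread_id=self.ychat.get_id()` above. Only turns routed through this
    # agent are written; messages handled by other personas never land here.
    config = {"configurable": {"thread_id": "chat-123"}}
    await agent.ainvoke(
        {"messages": [{"role": "user", "content": "Hi, I'm Ana."}]}, config
    )

    # A later call with the same thread_id resumes from the stored state,
    # even across server restarts.
    result = await agent.ainvoke(
        {"messages": [{"role": "user", "content": "What's my name?"}]}, config
    )
    print(result["messages"][-1].content)

    await conn.close()


asyncio.run(main())
```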
`get_context_as_messages()` becomes `get_system_prompt()`, and `_get_history_as_messages()` is deleted outright: recent messages no longer need to be re-sent on every request, since the checkpointer replays each thread's history. A `shutdown()` override closes the SQLite connection:

```python
    def get_system_prompt(self, model_id: str, message: Message) -> str:
        """
        Returns the system prompt, including attachments, as a string.
        """
        system_msg_args = JupyternautSystemPromptArgs(
            model_id=model_id,
            persona_name=self.name,
            context=self.process_attachments(message),
        ).model_dump()

        return JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args)

    def shutdown(self):
        # The memory store is created lazily, so guard with hasattr().
        if hasattr(self, "_memory_store"):
            self.parent.event_loop.create_task(self._memory_store.conn.close())
        super().shutdown()
```
The second file is new (`@@ -0,0 +1,41 @@`) and adds a minimal shell-execution toolkit:

```python
"""Tools that provide code execution features"""

import asyncio
import shlex
from typing import Optional


async def bash(command: str, timeout: Optional[int] = None) -> str:
    """Executes a bash command and returns the result

    Args:
        command: The bash command to execute
        timeout: Optional timeout in seconds

    Returns:
        The command output (stdout and stderr combined)
    """
    # Note: create_subprocess_exec() runs the tokenized command directly,
    # without a shell, so pipes, globs, and redirection are not interpreted.
    proc = await asyncio.create_subprocess_exec(
        *shlex.split(command),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    try:
        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout)
        output = stdout.decode("utf-8")
        error = stderr.decode("utf-8")

        if proc.returncode != 0:
            if error:
                return f"Error: {error}"
            return f"Command failed with exit code {proc.returncode}"

        return output if output else "Command executed successfully with no output."
    except asyncio.TimeoutError:
        proc.kill()
        return f"Command timed out after {timeout} seconds"


toolkit = [bash]
```
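A quick way to exercise the tool by itself (the commands are just examples, run alongside the toolkit module):

```python
import asyncio


async def demo():
    print(await bash("echo hello"))          # -> "hello\n"
    print(await bash("sleep 5", timeout=1))  # -> "Command timed out after 1 seconds"


asyncio.run(demo())
```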
**Review comment** (Contributor): I think that the correct parameter should be `model=model_id` (`model` instead of `model_id`), according to the `ChatLiteLLM` attribute. When testing this PR, the backend complains about a missing OpenAI API key. Debugging suggests that the model set up in `ChatLiteLLM` is always the default one, `gpt-3.5-turbo`.

**Follow-up:** I opened #19 to fix it.
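For reference, the suggested change in `get_agent()` would look like this (assuming `ChatLiteLLM` here mirrors the LangChain LiteLLM wrapper, whose `model` field defaults to `gpt-3.5-turbo`):

```python
# Before: `model_id` is not the wrapper's field name, so per the report above
# the model stays at its default ("gpt-3.5-turbo").
model = ChatLiteLLM(**model_args, model_id=model_id, streaming=True)

# After: pass the configured model under the field's actual name.
model = ChatLiteLLM(**model_args, model=model_id, streaming=True)
```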