Commit d504b34
WIP: first working copy of Jupyternaut as an agent
1 parent 3165e72 commit d504b34

3 files changed: +222 -57 lines changed

packages/jupyter-ai/jupyter_ai/personas/base_persona.py

Lines changed: 174 additions & 37 deletions
@@ -1,29 +1,39 @@
+from __future__ import annotations
 import asyncio
 import os
 from abc import ABC, ABCMeta, abstractmethod
 from dataclasses import asdict
 from logging import Logger
 from time import time
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Tuple
 
 from jupyter_ai.config_manager import ConfigManager
 from jupyterlab_chat.models import Message, NewMessage, User
 from jupyterlab_chat.ychat import YChat
+from litellm import ModelResponseStream, supports_function_calling
+from litellm.utils import function_to_dict
 from pydantic import BaseModel
 from traitlets import MetaHasTraits
 from traitlets.config import LoggingConfigurable
 
 from .persona_awareness import PersonaAwareness
+from ..litellm_utils import ToolCallList, ResolvedToolCall
+
+# Import toolkits
+from jupyter_ai_tools.toolkits.file_system import toolkit as fs_toolkit
+from jupyter_ai_tools.toolkits.code_execution import toolkit as codeexec_toolkit
+from jupyter_ai_tools.toolkits.git import toolkit as git_toolkit
 
-# prevents a circular import
-# types imported under this block have to be surrounded in single quotes on use
 if TYPE_CHECKING:
     from collections.abc import AsyncIterator
-
-    from litellm import ModelResponseStream
-
     from .persona_manager import PersonaManager
+    from ..tools import Toolkit
 
+DEFAULT_TOOLKITS: dict[str, Toolkit] = {
+    "fs": fs_toolkit,
+    "codeexec": codeexec_toolkit,
+    "git": git_toolkit,
+}
 
 class PersonaDefaults(BaseModel):
     """
@@ -237,7 +247,7 @@ def as_user_dict(self) -> dict[str, Any]:
 
     async def stream_message(
         self, reply_stream: "AsyncIterator[ModelResponseStream | str]"
-    ) -> None:
+    ) -> Optional[Tuple[list[ResolvedToolCall], ToolCallList]]:
         """
         Takes an async iterator, dubbed the 'reply stream', and streams it to a
         new message by this persona in the YChat. The async iterator may yield
@@ -247,21 +257,36 @@ async def stream_message(
         stream, then continuously updates it until the stream is closed.
 
         - Automatically manages its awareness state to show writing status.
+
+        Returns a tuple `(resolved_toolcalls, toolcall_list)`, or `None` if the
+        stream was interrupted; run the resolved tool calls if non-empty.
         """
         stream_id: Optional[str] = None
         stream_interrupted = False
         try:
             self.awareness.set_local_state_field("isWriting", True)
-            async for chunk in reply_stream:
-                # Coerce LiteLLM stream chunk to a string delta
-                if not isinstance(chunk, str):
-                    chunk = chunk.choices[0].delta.content
+            toolcall_list = ToolCallList()
+            resolved_toolcalls: list[ResolvedToolCall] = []
 
-                # LiteLLM streams always terminate with an empty chunk, so we
-                # ignore and continue when this occurs.
-                if not chunk:
+            async for chunk in reply_stream:
+                # Compute `content_delta` and `toolcalls_delta` based on the
+                # type of object yielded by `reply_stream`.
+                if isinstance(chunk, ModelResponseStream):
+                    delta = chunk.choices[0].delta
+                    content_delta = delta.content
+                    toolcalls_delta = delta.tool_calls
+                elif isinstance(chunk, str):
+                    content_delta = chunk
+                    toolcalls_delta = None
+                else:
+                    raise Exception(f"Unrecognized type in stream_message(): {type(chunk)}")
+
+                # LiteLLM streams always terminate with an empty chunk, so
+                # continue in this case.
+                if not (content_delta or toolcalls_delta):
                     continue
 
+                # Terminate the stream if the user requested it.
                 if (
                     stream_id
                     and stream_id in self.message_interrupted.keys()
@@ -280,34 +305,46 @@
                     stream_interrupted = True
                     break
 
-                if not stream_id:
-                    stream_id = self.ychat.add_message(
-                        NewMessage(body="", sender=self.id)
+                # Append `content_delta` to the existing message.
+                if content_delta:
+                    # Start the stream with an empty message on the initial reply.
+                    # Bind the new message ID to `stream_id`.
+                    if not stream_id:
+                        stream_id = self.ychat.add_message(
+                            NewMessage(body="", sender=self.id)
+                        )
+                        self.message_interrupted[stream_id] = asyncio.Event()
+                        self.awareness.set_local_state_field("isWriting", stream_id)
+                    assert stream_id
+
+                    self.ychat.update_message(
+                        Message(
+                            id=stream_id,
+                            body=content_delta,
+                            time=time(),
+                            sender=self.id,
+                            raw_time=False,
+                        ),
+                        append=True,
                     )
-                    self.message_interrupted[stream_id] = asyncio.Event()
-                    self.awareness.set_local_state_field("isWriting", stream_id)
-
-                assert stream_id
-                self.ychat.update_message(
-                    Message(
-                        id=stream_id,
-                        body=chunk,
-                        time=time(),
-                        sender=self.id,
-                        raw_time=False,
-                    ),
-                    append=True,
-                )
+                if toolcalls_delta:
+                    toolcall_list += toolcalls_delta
+
+            # After the reply stream is complete, resolve the list of tool calls.
+            resolved_toolcalls = toolcall_list.resolve()
         except Exception as e:
             self.log.error(
                 f"Persona '{self.name}' encountered an exception printed below when attempting to stream output."
             )
             self.log.exception(e)
         finally:
+            # Reset local state
             self.awareness.set_local_state_field("isWriting", False)
-            if stream_id:
-                # if stream was interrupted, add a tombstone
-                if stream_interrupted:
+            self.message_interrupted.pop(stream_id, None)
+
+            # If stream was interrupted, add a tombstone and return `None`,
+            # indicating that no tools should be run afterwards.
+            if stream_id and stream_interrupted:
                 stream_tombstone = "\n\n(AI response stopped by user)"
                 self.ychat.update_message(
                     Message(
@@ -319,8 +356,15 @@ async def stream_message(
                     ),
                     append=True,
                 )
-                if stream_id in self.message_interrupted.keys():
-                    del self.message_interrupted[stream_id]
+                return None
+
+            # Otherwise, return the resolved tool calls and the tool call list.
+            if len(resolved_toolcalls):
+                count = len(resolved_toolcalls)
+                names = sorted([tc.function.name for tc in resolved_toolcalls])
+                self.log.info(f"AI response triggered {count} tool calls: {names}")
+            return resolved_toolcalls, toolcall_list
+
 
     def send_message(self, body: str) -> None:
         """
@@ -361,7 +405,7 @@ def get_mcp_config(self) -> dict[str, Any]:
         Returns the MCP config for the current chat.
         """
         return self.parent.get_mcp_config()
-
+
     def process_attachments(self, message: Message) -> Optional[str]:
         """
         Process file attachments in the message and return their content as a string.
@@ -431,6 +475,99 @@ def resolve_attachment_to_path(self, attachment_id: str) -> Optional[str]:
             self.log.error(f"Failed to resolve attachment {attachment_id}: {e}")
             return None
 
+    def get_tools(self, model_id: str) -> list[dict]:
+        """
+        Returns the `tools` parameter which should be passed to
+        `litellm.acompletion()` for a given LiteLLM model ID.
+
+        If the model does not support tool-calling, this method returns an empty
+        list. Otherwise, it returns the list of tools available in the current
+        environment. These may include:
+
+        - The default set of tool functions in Jupyter AI, defined in the
+          `jupyter_ai_tools` package.
+
+        - (TODO) Tools provided by MCP server configuration, if any.
+
+        - (TODO) Web search.
+
+        - (TODO) File search using vector store IDs.
+
+        TODO: cache this
+
+        TODO: Implement some permissions system so users can control what tools
+        are allowed.
+
+        NOTE: The returned list is expected by LiteLLM to conform to the `tools`
+        parameter definition in the OpenAI API:
+        https://platform.openai.com/docs/guides/tools#available-tools
+
+        NOTE: This API is a WIP and is very likely to change.
+        """
+        # Return early if the model does not support tool calling
+        if not supports_function_calling(model=model_id):
+            return []
+
+        tool_descriptions = []
+
+        # Get all tools from `jupyter_ai_tools` and store their descriptions
+        for toolkit_name, toolkit in DEFAULT_TOOLKITS.items():
+            # TODO: make these tool permissions configurable.
+            for tool in toolkit.get_tools():
+                # Here, we are using a util function from LiteLLM to coerce
+                # each `Tool` struct into a tool description dictionary expected
+                # by LiteLLM.
+                desc = {
+                    "type": "function",
+                    "function": function_to_dict(tool.callable),
+                }
+
+                # Prepend the toolkit name to each function name, hopefully
+                # ensuring every tool function has a unique name.
+                # e.g. 'git_add' => 'git__git_add'
+                #
+                # TODO: Actually ensure this instead of hoping.
+                desc['function']['name'] = f"{toolkit_name}__{desc['function']['name']}"
+                tool_descriptions.append(desc)
+
+        # Finally, return the tool descriptions
+        return tool_descriptions
+
+
+    async def run_tools(self, tools: list[ResolvedToolCall]) -> list[dict]:
+        """
+        Runs the tools specified in the list of tool calls returned by
+        `self.stream_message()`. Returns a list of dictionaries
+        `toolcall_outputs: list[dict]`, which should be appended directly to the
+        message history on the next invocation of the LLM.
+        """
+        if not len(tools):
+            return []
+
+        tool_outputs: list[dict] = []
+        for tool_call in tools:
+            # Get the tool definition from the correct toolkit
+            toolkit_name, tool_name = tool_call.function.name.split("__")
+            assert toolkit_name in DEFAULT_TOOLKITS
+            tool_defn = DEFAULT_TOOLKITS[toolkit_name].get_tool_unsafe(tool_name)
+
+            # Run the tool and store its output
+            output = await tool_defn.callable(**tool_call.function.arguments)
+
+            # Store the tool output in a dictionary accepted by LiteLLM
+            output_dict = {
+                "tool_call_id": tool_call.id,
+                "role": "tool",
+                "name": tool_call.function.name,
+                "content": output,
+            }
+            tool_outputs.append(output_dict)
+
+        self.log.info(f"Ran {len(tools)} tool functions.")
+        return tool_outputs
+
+
+
     def shutdown(self) -> None:
         """
         Shuts the persona down. This method should:
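For reference, each entry returned by `get_tools()` is an OpenAI-style function description with the toolkit name namespaced into the function name. A minimal sketch of one such entry, using a hypothetical parameter schema for the git toolkit's `git_add` tool (the real schema is derived by `function_to_dict` from the tool's signature and docstring):

```python
# Hypothetical entry from get_tools() for a `git_add` tool in the "git"
# toolkit. The description and parameters below are illustrative only.
example_tool_description = {
    "type": "function",
    "function": {
        "name": "git__git_add",  # "<toolkit>__<function>" namespacing
        "description": "Stage files in the current Git repository.",
        "parameters": {
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "File to stage."},
            },
            "required": ["path"],
        },
    },
}
```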

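Correspondingly, each dictionary returned by `run_tools()` is a `role: "tool"` chat message keyed to the model's tool-call ID, which is how the next LLM invocation matches outputs to requests. A sketch with placeholder values:

```python
# Sketch of one tool-output message produced by run_tools(); the ID and
# content values are placeholders.
example_tool_output = {
    "tool_call_id": "call_abc123",  # must match the ID emitted by the model
    "role": "tool",                 # marks this message as a tool result
    "name": "git__git_add",         # namespaced tool function name
    "content": "Staged 1 file.",    # return value of the tool callable
}
```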
packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py

Lines changed: 34 additions & 19 deletions
@@ -9,6 +9,7 @@
     JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE,
     JupyternautSystemPromptArgs,
 )
+from ...litellm_utils import ResolvedToolCall
 
 
 class JupyternautPersona(BasePersona):
@@ -37,22 +38,35 @@ async def process_message(self, message: Message) -> None:
             return
 
         model_id = self.config_manager.chat_model
-        model_args = self.config_manager.chat_model_args
-        context_as_messages = self.get_context_as_messages(model_id, message)
-        response_aiter = await acompletion(
-            **model_args,
-            model=model_id,
-            messages=[
-                *context_as_messages,
-                {
-                    "role": "user",
-                    "content": message.body,
-                },
-            ],
-            stream=True,
-        )
 
-        await self.stream_message(response_aiter)
+        # `True` on the first LLM invocation, `False` on all invocations after.
+        initial_invocation = True
+        # List of tool calls requested by the LLM in the previous invocation.
+        tool_calls: list[ResolvedToolCall] = []
+        tool_call_list = None
+        # List of tool call outputs computed in the previous invocation.
+        tool_call_outputs: list[dict] = []
+
+        # Loop until the AI has finished running all of its tools.
+        while initial_invocation or len(tool_call_outputs):
+            messages = self.get_context_as_messages(model_id, message)
+
+            # TODO: Find a better way to track tool calls
+            if not initial_invocation and tool_calls:
+                self.log.error(messages[-1])
+                messages[-1]['tool_calls'] = tool_call_list._aggregate
+                messages.extend(tool_call_outputs)
+
+            self.log.error(messages)
+            response_aiter = await acompletion(
+                model=model_id,
+                messages=messages,
+                tools=self.get_tools(model_id),
+                stream=True,
+            )
+            tool_calls, tool_call_list = await self.stream_message(response_aiter)
+            initial_invocation = False
+            tool_call_outputs = await self.run_tools(tool_calls)
 
     def get_context_as_messages(
         self, model_id: str, message: Message
@@ -79,16 +93,17 @@ def _get_history_as_messages(self, k: Optional[int] = 2) -> list[dict[str, Any]]
         """
         Returns the current history as a list of messages accepted by
         `litellm.acompletion()`.
+
+        NOTE: You should usually call the public `get_context_as_messages()`
+        method instead.
         """
         # TODO: consider bounding history based on message size (e.g. total
         # char/token count) instead of message count.
         all_messages = self.ychat.get_messages()
 
         # gather last k * 2 messages and return
-        # we exclude the last message since that is the human message just
-        # submitted by a user.
-        start_idx = 0 if k is None else -2 * k - 1
-        recent_messages: list[Message] = all_messages[start_idx:-1]
+        start_idx = 0 if k is None else -2 * k
+        recent_messages: list[Message] = all_messages[start_idx:]
 
         history: list[dict[str, Any]] = []
         for msg in recent_messages:
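The `process_message()` loop above follows the standard OpenAI/LiteLLM tool-calling cycle: invoke the model with `tools`, run whatever tools it requests, append the assistant's `tool_calls` message plus one `role: "tool"` output per call, then invoke the model again until no more tools are requested. A distilled, non-streaming sketch of the same cycle; the model name and the `echo` tool are placeholders:

```python
import json
import litellm

def echo(text: str) -> str:
    """Toy tool: return the input text unchanged."""
    return text

TOOLS = [{
    "type": "function",
    "function": {
        "name": "echo",
        "description": "Echo the given text back to the user.",
        "parameters": {
            "type": "object",
            "properties": {"text": {"type": "string"}},
            "required": ["text"],
        },
    },
}]

messages = [{"role": "user", "content": "Use the echo tool on 'hi'."}]
while True:
    response = litellm.completion(model="gpt-4o-mini", messages=messages, tools=TOOLS)
    reply = response.choices[0].message
    if not reply.tool_calls:
        break  # no tool requests, so this is the final answer

    # Append the assistant turn (carrying its tool_calls), then one tool
    # result message per requested call.
    messages.append(reply.model_dump())
    for tc in reply.tool_calls:
        args = json.loads(tc.function.arguments)
        messages.append({
            "tool_call_id": tc.id,
            "role": "tool",
            "name": tc.function.name,
            "content": echo(**args),
        })

print(reply.content)
```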
