Skip to content

Commit 5aa46bf

Browse files
committed
clean up tool calling flow & show in chat
1 parent d504b34 commit 5aa46bf

File tree

6 files changed

+151
-112
lines changed

6 files changed

+151
-112
lines changed
packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .toolcall_list import ToolCallList
2-
from .toolcall_types import *
1+
from .toolcall_list import *
2+
from .streaming_utils import *
packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from pydantic import BaseModel
2+
from .toolcall_list import ToolCallList
3+
4+
class StreamResult(BaseModel):
5+
id: str
6+
"""
7+
ID of the new message.
8+
"""
9+
10+
tool_calls: ToolCallList
11+
"""
12+
Tool calls requested by the LLM in its streamed response.
13+
"""

packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,61 @@
11
from litellm.utils import ChatCompletionDeltaToolCall, Function
22
import json
3+
from pydantic import BaseModel
4+
from typing import Any
35

4-
from .toolcall_types import ResolvedToolCall, ResolvedFunction
6+
class ResolvedFunction(BaseModel):
7+
"""
8+
A type-safe, parsed representation of `litellm.utils.Function`.
9+
"""
10+
11+
name: str
12+
"""
13+
Name of the tool function to be called.
14+
15+
TODO: Check if this attribute is defined for non-function tools, e.g. tools
16+
provided by a MCP server. The docstring on `litellm.utils.Function` implies
17+
that `name` may be `None`.
18+
"""
19+
20+
arguments: dict[str, Any]
21+
"""
22+
Arguments to the tool function, as a dictionary.
23+
"""
24+
25+
26+
class ResolvedToolCall(BaseModel):
27+
"""
28+
A type-safe, parsed representation of
29+
`litellm.utils.ChatCompletionDeltaToolCall`.
30+
"""
31+
32+
id: str | None
33+
"""
34+
The ID of the tool call. This should always be provided by LiteLLM, this
35+
type is left optional as we do not use this attribute.
36+
"""
37+
38+
type: str
39+
"""
40+
The 'type' of tool call. Usually 'function'.
541
6-
class ToolCallList():
42+
TODO: Make this a union of string literals to ensure we are handling every
43+
potential type of tool call.
44+
"""
45+
46+
function: ResolvedFunction
47+
"""
48+
The resolved function. See `ResolvedFunction` for more info.
49+
"""
50+
51+
index: int
52+
"""
53+
The index of this tool call.
54+
55+
This is usually 0 unless the LLM supports parallel tool calling.
56+
"""
57+
58+
class ToolCallList(BaseModel):
759
"""
860
A helper object that defines a custom `__iadd__()` method which accepts a
961
`tool_call_deltas: list[ChatCompletionDeltaToolCall]` argument. This class
@@ -27,14 +79,7 @@ class ToolCallList():
2779
```
2880
"""
2981

30-
_aggregate: list[ChatCompletionDeltaToolCall]
31-
32-
def __init__(self):
33-
self.size = None
34-
35-
# Initialize `_aggregate`
36-
self._aggregate = []
37-
82+
_aggregate: list[ChatCompletionDeltaToolCall] = []
3883

3984
def __iadd__(self, other: list[ChatCompletionDeltaToolCall] | None) -> 'ToolCallList':
4085
"""
@@ -116,6 +161,13 @@ def resolve(self) -> list[ResolvedToolCall]:
116161
resolved_toolcalls.append(resolved_toolcall)
117162

118163
return resolved_toolcalls
119-
120-
121-
164+
165+
def to_json(self) -> list[dict[str, Any]]:
166+
"""
167+
Returns the list of tool calls as a Python dictionary that can be
168+
JSON-serialized.
169+
"""
170+
return [
171+
model.model_dump() for model in self._aggregate
172+
]
173+

packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

packages/jupyter-ai/jupyter_ai/personas/base_persona.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import asdict
66
from logging import Logger
77
from time import time
8-
from typing import TYPE_CHECKING, Any, Optional, Tuple
8+
from typing import TYPE_CHECKING, Any, Optional
99

1010
from jupyter_ai.config_manager import ConfigManager
1111
from jupyterlab_chat.models import Message, NewMessage, User
@@ -17,7 +17,7 @@
1717
from traitlets.config import LoggingConfigurable
1818

1919
from .persona_awareness import PersonaAwareness
20-
from ..litellm_utils import ToolCallList, ResolvedToolCall
20+
from ..litellm_utils import ToolCallList, StreamResult, ResolvedToolCall
2121

2222
# Import toolkits
2323
from jupyter_ai_tools.toolkits.file_system import toolkit as fs_toolkit
@@ -247,7 +247,7 @@ def as_user_dict(self) -> dict[str, Any]:
247247

248248
async def stream_message(
249249
self, reply_stream: "AsyncIterator[ModelResponseStream | str]"
250-
) -> Tuple[ResolvedToolCall, ToolCallList]:
250+
) -> StreamResult:
251251
"""
252252
Takes an async iterator, dubbed the 'reply stream', and streams it to a
253253
new message by this persona in the YChat. The async iterator may yield
@@ -263,12 +263,21 @@ async def stream_message(
263263
"""
264264
stream_id: Optional[str] = None
265265
stream_interrupted = False
266+
tool_calls = ToolCallList()
266267
try:
267268
self.awareness.set_local_state_field("isWriting", True)
268-
toolcall_list = ToolCallList()
269-
resolved_toolcalls: list[ResolvedToolCall] = []
270269

271270
async for chunk in reply_stream:
271+
# Start the stream with an empty message on the initial reply.
272+
# Bind the new message ID to `stream_id`.
273+
if not stream_id:
274+
stream_id = self.ychat.add_message(
275+
NewMessage(body="", sender=self.id)
276+
)
277+
self.message_interrupted[stream_id] = asyncio.Event()
278+
self.awareness.set_local_state_field("isWriting", stream_id)
279+
assert stream_id
280+
272281
# Compute `content_delta` and `tool_calls_delta` based on the
273282
# type of object yielded by `reply_stream`.
274283
if isinstance(chunk, ModelResponseStream):
@@ -307,16 +316,6 @@ async def stream_message(
307316

308317
# Append `content_delta` to the existing message.
309318
if content_delta:
310-
# Start the stream with an empty message on the initial reply.
311-
# Bind the new message ID to `stream_id`.
312-
if not stream_id:
313-
stream_id = self.ychat.add_message(
314-
NewMessage(body="", sender=self.id)
315-
)
316-
self.message_interrupted[stream_id] = asyncio.Event()
317-
self.awareness.set_local_state_field("isWriting", stream_id)
318-
assert stream_id
319-
320319
self.ychat.update_message(
321320
Message(
322321
id=stream_id,
@@ -328,10 +327,8 @@ async def stream_message(
328327
append=True,
329328
)
330329
if toolcalls_delta:
331-
toolcall_list += toolcalls_delta
330+
tool_calls += toolcalls_delta
332331

333-
# After the reply stream is complete, resolve the list of tool calls.
334-
resolved_toolcalls = toolcall_list.resolve()
335332
except Exception as e:
336333
self.log.error(
337334
f"Persona '{self.name}' encountered an exception printed below when attempting to stream output."
@@ -358,12 +355,17 @@ async def stream_message(
358355
)
359356
return None
360357

361-
# Otherwise return the resolved list.
358+
# TODO: determine where this should live
359+
resolved_toolcalls = tool_calls.resolve()
362360
if len(resolved_toolcalls):
363361
count = len(resolved_toolcalls)
364362
names = sorted([tc.function.name for tc in resolved_toolcalls])
365363
self.log.info(f"AI response triggered {count} tool calls: {names}")
366-
return resolved_toolcalls, toolcall_list
364+
365+
return StreamResult(
366+
id=stream_id,
367+
tool_calls=tool_calls
368+
)
367369

368370

369371
def send_message(self, body: str) -> None:
@@ -552,7 +554,9 @@ async def run_tools(self, tools: list[ResolvedToolCall]) -> list[dict]:
552554
tool_defn = DEFAULT_TOOLKITS[toolkit_name].get_tool_unsafe(tool_name)
553555

554556
# Run tool and store its output
555-
output = await tool_defn.callable(**tool_call.function.arguments)
557+
output = tool_defn.callable(**tool_call.function.arguments)
558+
if asyncio.iscoroutine(output):
559+
output = await output
556560

557561
# Store the tool output in a dictionary accepted by LiteLLM
558562
output_dict = {

packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from typing import Any, Optional
2+
import time
3+
import json
24

35
from jupyterlab_chat.models import Message
46
from litellm import acompletion
@@ -9,7 +11,6 @@
911
JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE,
1012
JupyternautSystemPromptArgs,
1113
)
12-
from ...litellm_utils import ResolvedToolCall
1314

1415

1516
class JupyternautPersona(BasePersona):
@@ -39,34 +40,60 @@ async def process_message(self, message: Message) -> None:
3940

4041
model_id = self.config_manager.chat_model
4142

42-
# `True` on the first LLM invocation, `False` on all invocations after.
43-
initial_invocation = True
44-
# List of tool calls requested by the LLM in the previous invocation.
45-
tool_calls: list[ResolvedToolCall] = []
46-
tool_call_list = None
43+
# `True` before the first LLM response is sent, `False` afterwards.
44+
initial_response = True
4745
# List of tool call outputs computed in the previous invocation.
4846
tool_call_outputs: list[dict] = []
4947

50-
# Loop until the AI is complete running all its tools.
51-
while initial_invocation or len(tool_call_outputs):
52-
messages = self.get_context_as_messages(model_id, message)
53-
54-
# TODO: Find a better way to track tool calls
55-
if not initial_invocation and tool_calls:
56-
self.log.error(messages[-1])
57-
messages[-1]['tool_calls'] = tool_call_list._aggregate
58-
messages.extend(tool_call_outputs)
48+
# Initialize list of messages, including history and context
49+
messages: list[dict] = self.get_context_as_messages(model_id, message)
5950

60-
self.log.error(messages)
51+
# Loop until the AI is complete running all its tools.
52+
while initial_response or len(tool_call_outputs):
53+
# Stream message to the chat
6154
response_aiter = await acompletion(
6255
model=model_id,
6356
messages=messages,
6457
tools=self.get_tools(model_id),
6558
stream=True,
6659
)
67-
tool_calls, tool_call_list = await self.stream_message(response_aiter)
68-
initial_invocation = False
69-
tool_call_outputs = await self.run_tools(tool_calls)
60+
result = await self.stream_message(response_aiter)
61+
initial_response = False
62+
63+
# Append new reply to `messages`
64+
reply = self.ychat.get_message(result.id)
65+
tool_calls_json = result.tool_calls.to_json()
66+
messages.append({
67+
"role": "assistant",
68+
"content": reply.body,
69+
"tool_calls": tool_calls_json
70+
})
71+
72+
# Show tool call requests to YChat (not synced with `messages`)
73+
if len(tool_calls_json):
74+
self.ychat.update_message(Message(
75+
id=result.id,
76+
body=f"\n\n```\n{json.dumps(tool_calls_json, indent=2)}\n```\n",
77+
sender=self.id,
78+
time=time.time(),
79+
raw_time=False
80+
), append=True)
81+
82+
# Run tools and append outputs to `messages`
83+
tool_call_outputs = await self.run_tools(result.tool_calls.resolve())
84+
messages.extend(tool_call_outputs)
85+
86+
# Add tool call outputs to YChat (not synced with `messages`)
87+
if tool_call_outputs:
88+
self.ychat.update_message(Message(
89+
id=result.id,
90+
body=f"\n\n```\n{json.dumps(tool_call_outputs, indent=2)}\n```\n",
91+
sender=self.id,
92+
time=time.time(),
93+
raw_time=False
94+
), append=True)
95+
96+
7097

7198
def get_context_as_messages(
7299
self, model_id: str, message: Message

0 commit comments

Comments (0)