Skip to content

Commit 4059489

Browse files
committed
update tool calling APIs
1 parent 7ba285d commit 4059489

File tree

6 files changed

+70
-42
lines changed

6 files changed

+70
-42
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .toolcall_list import *
22
from .streaming_utils import *
3+
from .run_tools import *
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import asyncio
2+
from pydantic import BaseModel
3+
from .toolcall_list import ToolCallList
4+
from ..tools import Toolkit
5+
6+
7+
class ToolCallOutput(BaseModel):
8+
tool_call_id: str
9+
role: str = "tool"
10+
name: str
11+
content: str
12+
13+
14+
async def run_tools(tool_call_list: ToolCallList, toolkit: Toolkit) -> list[dict]:
15+
"""
16+
Runs the tools specified in the list of tool calls returned by
17+
`self.stream_message()`.
18+
19+
Returns `list[ToolCallOutput]`. The outputs should be appended directly to
20+
the message history on the next request made to the LLM.
21+
"""
22+
tool_calls = tool_call_list.resolve()
23+
if not len(tool_calls):
24+
return []
25+
26+
tool_outputs: list[dict] = []
27+
for tool_call in tool_calls:
28+
# Get tool definition from the correct toolkit
29+
# TODO: validation?
30+
tool_name = tool_call.function.name
31+
tool_defn = toolkit.get_tool_unsafe(tool_name)
32+
33+
# Run tool and store its output
34+
try:
35+
output = tool_defn.callable(**tool_call.function.arguments)
36+
if asyncio.iscoroutine(output):
37+
output = await output
38+
except Exception as e:
39+
output = str(e)
40+
41+
# Store the tool output in a dictionary accepted by LiteLLM
42+
output_dict = {
43+
"tool_call_id": tool_call.id,
44+
"role": "tool",
45+
"name": tool_call.function.name,
46+
"content": output,
47+
}
48+
tool_outputs.append(output_dict)
49+
50+
return tool_outputs

packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ class StreamResult(BaseModel):
77
ID of the new message.
88
"""
99

10-
tool_calls: ToolCallList
10+
tool_call_list: ToolCallList
1111
"""
1212
Tool calls requested by the LLM in its streamed response.
1313
"""

packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,4 +170,8 @@ def to_json(self) -> list[dict[str, Any]]:
170170
return [
171171
model.model_dump() for model in self._aggregate
172172
]
173+
174+
175+
def __len__(self) -> int:
176+
return len(self._aggregate)
173177

packages/jupyter-ai/jupyter_ai/personas/base_persona.py

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from traitlets.config import LoggingConfigurable
1818

1919
from .persona_awareness import PersonaAwareness
20-
from ..litellm_utils import ToolCallList, StreamResult, ResolvedToolCall
20+
from ..litellm_utils import ToolCallList, StreamResult, run_tools, ToolCallOutput
2121

2222
# Import toolkits
2323
from ..tools.default_toolkit import DEFAULT_TOOLKIT
@@ -255,7 +255,7 @@ async def stream_message(
255255
"""
256256
stream_id: Optional[str] = None
257257
stream_interrupted = False
258-
tool_calls = ToolCallList()
258+
tool_call_list = ToolCallList()
259259
try:
260260
self.awareness.set_local_state_field("isWriting", True)
261261

@@ -319,7 +319,7 @@ async def stream_message(
319319
append=True,
320320
)
321321
if toolcalls_delta:
322-
tool_calls += toolcalls_delta
322+
tool_call_list += toolcalls_delta
323323

324324
except Exception as e:
325325
self.log.error(
@@ -348,15 +348,13 @@ async def stream_message(
348348
return None
349349

350350
# TODO: determine where this should live
351-
resolved_toolcalls = tool_calls.resolve()
352-
if len(resolved_toolcalls):
353-
count = len(resolved_toolcalls)
354-
names = sorted([tc.function.name for tc in resolved_toolcalls])
355-
self.log.info(f"AI response triggered {count} tool calls: {names}")
351+
count = len(tool_call_list)
352+
if count > 0:
353+
self.log.info(f"AI response triggered {count} tool calls.")
356354

357355
return StreamResult(
358356
id=stream_id,
359-
tool_calls=tool_calls
357+
tool_call_list=tool_call_list
360358
)
361359

362360

@@ -520,38 +518,13 @@ def get_tools(self, model_id: str) -> list[dict]:
520518
return tool_descriptions
521519

522520

523-
async def run_tools(self, tools: list[ResolvedToolCall]) -> list[dict]:
521+
async def run_tools(self, tool_call_list: ToolCallList) -> list[ToolCallOutput]:
524522
"""
525-
Runs the tools specified in the list of tool calls returned by
526-
`self.stream_message()`. Returns a list of dictionaries
527-
`toolcall_outputs: list[dict]`, which should be appended directly to the
528-
message history on the next invocation of the LLM.
523+
Runs the tools specified in a given tool call list using the default
524+
toolkit.
529525
"""
530-
if not len(tools):
531-
return []
532-
533-
tool_outputs: list[dict] = []
534-
for tool_call in tools:
535-
# Get tool definition from the correct toolkit
536-
# TODO: validation?
537-
tool_name = tool_call.function.name
538-
tool_defn = DEFAULT_TOOLKIT.get_tool_unsafe(tool_name)
539-
540-
# Run tool and store its output
541-
output = tool_defn.callable(**tool_call.function.arguments)
542-
if asyncio.iscoroutine(output):
543-
output = await output
544-
545-
# Store the tool output in a dictionary accepted by LiteLLM
546-
output_dict = {
547-
"tool_call_id": tool_call.id,
548-
"role": "tool",
549-
"name": tool_call.function.name,
550-
"content": output,
551-
}
552-
tool_outputs.append(output_dict)
553-
554-
self.log.info(f"Ran {len(tools)} tool functions.")
526+
tool_outputs = await run_tools(tool_call_list, toolkit=DEFAULT_TOOLKIT)
527+
self.log.info(f"Ran {len(tool_outputs)} tool functions.")
555528
return tool_outputs
556529

557530

packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def process_message(self, message: Message) -> None:
6262

6363
# Append new reply to `messages`
6464
reply = self.ychat.get_message(result.id)
65-
tool_calls_json = result.tool_calls.to_json()
65+
tool_calls_json = result.tool_call_list.to_json()
6666
messages.append({
6767
"role": "assistant",
6868
"content": reply.body,
@@ -80,7 +80,7 @@ async def process_message(self, message: Message) -> None:
8080
), append=True)
8181

8282
# Run tools and append outputs to `messages`
83-
tool_call_outputs = await self.run_tools(result.tool_calls.resolve())
83+
tool_call_outputs = await self.run_tools(result.tool_call_list)
8484
messages.extend(tool_call_outputs)
8585

8686
# Add tool call outputs to YChat (not synced with `messages`)

0 commit comments

Comments
 (0)