Skip to content

Commit 3165e72

Browse files
committed
add 'litellm_utils' module
1 parent e4df0f1 commit 3165e72

File tree

4 files changed

+232
-0
lines changed

4 files changed

+232
-0
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .toolcall_list import ToolCallList
2+
from .toolcall_types import *
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from litellm.utils import ChatCompletionDeltaToolCall, Function
2+
from .toolcall_list import ToolCallList
3+
4+
class TestToolCallList():
5+
6+
def test_single_tool_stream(self):
7+
"""
8+
Asserts this class works against a sample response from Claude running a
9+
single tool.
10+
"""
11+
# Setup test
12+
ID = "toolu_01TzXi4nFJErYThcdhnixn7e"
13+
toolcall_list = ToolCallList()
14+
toolcall_list += [ChatCompletionDeltaToolCall(id=ID, function=Function(arguments='', name='ls'), type='function', index=0)]
15+
toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='', name=None), type='function', index=0)]
16+
toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='{"path', name=None), type='function', index=0)]
17+
toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='": "."}', name=None), type='function', index=0)]
18+
19+
# Verify the resolved list of calls
20+
resolved_toolcalls = toolcall_list.resolve()
21+
assert len(resolved_toolcalls) == 1
22+
assert resolved_toolcalls[0]
23+
24+
def test_two_tool_stream(self):
25+
"""
26+
Asserts this class works against a sample response from Claude running a
27+
two tools in parallel.
28+
"""
29+
# Setup test
30+
ID_0 = 'toolu_0141FrNfT2LJg6odqbrdmLM6'
31+
ID_1 = 'toolu_01DKqnaXVcyp1v1ABxhHC5Sg'
32+
toolcall_list = ToolCallList()
33+
toolcall_list += [ChatCompletionDeltaToolCall(id=ID_0, function=Function(arguments='', name='ls'), type='function', index=0)]
34+
toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='', name=None), type='function', index=0)]
35+
toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='{"path": ', name=None), type='function', index=0)]
36+
toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='"."}', name=None), type='function', index=0)]
37+
toolcall_list += [ChatCompletionDeltaToolCall(id=ID_1, function=Function(arguments='', name='bash'), type='function', index=1)]
38+
toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='', name=None), type='function', index=1)]
39+
toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='{"com', name=None), type='function', index=1)]
40+
toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='mand": "ech', name=None), type='function', index=1)]
41+
toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='o \'hello\'"}', name=None), type='function', index=1)]
42+
43+
# Verify the resolved list of calls
44+
resolved_toolcalls = toolcall_list.resolve()
45+
assert len(resolved_toolcalls) == 2
46+
assert resolved_toolcalls[0].id == ID_0
47+
assert resolved_toolcalls[0].function.name == "ls"
48+
assert resolved_toolcalls[0].function.arguments == { "path": "." }
49+
assert resolved_toolcalls[1].id == ID_1
50+
assert resolved_toolcalls[1].function.name == "bash"
51+
assert resolved_toolcalls[1].function.arguments == { "command": "echo \'hello\'" }
52+
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from litellm.utils import ChatCompletionDeltaToolCall, Function
2+
import json
3+
4+
from .toolcall_types import ResolvedToolCall, ResolvedFunction
5+
6+
class ToolCallList():
7+
"""
8+
A helper object that defines a custom `__iadd__()` method which accepts a
9+
`tool_call_deltas: list[ChatCompletionDeltaToolCall]` argument. This class
10+
is used to aggregate the tool call deltas yielded from a LiteLLM response
11+
stream and produce a list of tool calls.
12+
13+
After all tool call deltas are added, the `process()` method may be called
14+
to return a list of resolved tool calls.
15+
16+
Example usage:
17+
18+
```py
19+
tool_call_list = ToolCallList()
20+
reply_stream = await litellm.acompletion(..., stream=True)
21+
22+
async for chunk in reply_stream:
23+
tool_call_delta = chunk.choices[0].delta.tool_calls
24+
tool_call_list += tool_call_delta
25+
26+
tool_call_list.resolve()
27+
```
28+
"""
29+
30+
_aggregate: list[ChatCompletionDeltaToolCall]
31+
32+
def __init__(self):
33+
self.size = None
34+
35+
# Initialize `_aggregate`
36+
self._aggregate = []
37+
38+
39+
def __iadd__(self, other: list[ChatCompletionDeltaToolCall] | None) -> 'ToolCallList':
40+
"""
41+
Adds a list of tool call deltas to this instance.
42+
43+
NOTE: This assumes the 'index' attribute on each entry in this list to
44+
be accurate. If this assumption doesn't hold, we will need to rework the
45+
logic here.
46+
"""
47+
if other is None:
48+
return self
49+
50+
# Iterate through each delta
51+
for delta in other:
52+
# Ensure `self._aggregate` is at least of size `delta.index + 1`
53+
for i in range(len(self._aggregate), delta.index + 1):
54+
self._aggregate.append(ChatCompletionDeltaToolCall(
55+
function=Function(arguments=""),
56+
index=i,
57+
))
58+
59+
# Find the corresponding target in the `self._aggregate` and add the
60+
# delta on top of it. In most cases, the value of aggregate
61+
# attribute is set as soon as any delta sets it to a non-`None`
62+
# value. However, `delta.function.arguments` is a string that should
63+
# be appended to the aggregate value of that attribute.
64+
target = self._aggregate[delta.index]
65+
if delta.type:
66+
target.type = delta.type
67+
if delta.id:
68+
target.id = delta.id
69+
if delta.function.name:
70+
target.function.name = delta.function.name
71+
if delta.function.arguments:
72+
target.function.arguments += delta.function.arguments
73+
74+
return self
75+
76+
77+
def __add__(self, other: list[ChatCompletionDeltaToolCall] | None) -> 'ToolCallList':
78+
"""
79+
Alias for `__iadd__()`.
80+
"""
81+
return self.__iadd__(other)
82+
83+
84+
def resolve(self) -> list[ResolvedToolCall]:
85+
"""
86+
Resolve the aggregated tool call delta lists into a list of tool calls.
87+
"""
88+
resolved_toolcalls: list[ResolvedToolCall] = []
89+
for i, raw_toolcall in enumerate(self._aggregate):
90+
# Verify entries are at the correct index in the aggregated list
91+
assert raw_toolcall.index == i
92+
93+
# Verify each tool call specifies the name of the tool to run.
94+
#
95+
# TODO: Check if this may cause a runtime error. The docstring on
96+
# `litellm.utils.Function` implies that `name` may be `None`.
97+
assert raw_toolcall.function.name
98+
99+
# Verify each tool call defines the type of tool it is calling.
100+
assert raw_toolcall.type is not None
101+
102+
# Parse the function argument string into a dictionary
103+
resolved_fn_args = json.loads(raw_toolcall.function.arguments)
104+
105+
# Add to the returned list
106+
resolved_fn = ResolvedFunction(
107+
name=raw_toolcall.function.name,
108+
arguments=resolved_fn_args
109+
)
110+
resolved_toolcall = ResolvedToolCall(
111+
id=raw_toolcall.id,
112+
type=raw_toolcall.type,
113+
index=i,
114+
function=resolved_fn
115+
)
116+
resolved_toolcalls.append(resolved_toolcall)
117+
118+
return resolved_toolcalls
119+
120+
121+
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from __future__ import annotations
2+
from pydantic import BaseModel
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from typing import Any
7+
8+
class ResolvedFunction(BaseModel):
9+
"""
10+
A type-safe, parsed representation of `litellm.utils.Function`.
11+
"""
12+
13+
name: str
14+
"""
15+
Name of the tool function to be called.
16+
17+
TODO: Check if this attribute is defined for non-function tools, e.g. tools
18+
provided by a MCP server. The docstring on `litellm.utils.Function` implies
19+
that `name` may be `None`.
20+
"""
21+
22+
arguments: dict
23+
"""
24+
Arguments to the tool function, as a dictionary.
25+
"""
26+
27+
class ResolvedToolCall(BaseModel):
28+
"""
29+
A type-safe, parsed representation of
30+
`litellm.utils.ChatCompletionDeltaToolCall`.
31+
"""
32+
33+
id: str | None
34+
"""
35+
The ID of the tool call. This should always be provided by LiteLLM, this
36+
type is left optional as we do not use this attribute.
37+
"""
38+
39+
type: str
40+
"""
41+
The 'type' of tool call. Usually 'function'.
42+
43+
TODO: Make this a union of string literals to ensure we are handling every
44+
potential type of tool call.
45+
"""
46+
47+
function: ResolvedFunction
48+
"""
49+
The resolved function. See `ResolvedFunction` for more info.
50+
"""
51+
52+
index: int
53+
"""
54+
The index of this tool call.
55+
56+
This is usually 0 unless the LLM supports parallel tool calling.
57+
"""

0 commit comments

Comments
 (0)