Skip to content

Commit 37140ff

Browse files
authored
Fix multi-part tool results in history (#112)
Chat.add_entry and Chat.append weren't updated when multi-part tool result support was added. Closes #109
1 parent eb8b8a5 commit 37140ff

File tree

2 files changed

+160
-8
lines changed

2 files changed

+160
-8
lines changed

src/lmstudio/history.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,13 @@ def _to_history_content(self) -> str:
135135
| ToolCallRequestData
136136
| ToolCallRequestDataDict
137137
)
138+
AssistantMultiPartInput = Iterable[AssistantResponseInput | ToolCallRequestInput]
138139
ToolCallResultInput = ToolCallResultData | ToolCallResultDataDict
140+
ToolCallResultMultiPartInput = Iterable[ToolCallResultInput]
139141
ChatMessageInput = str | ChatMessageContent | ChatMessageContentDict
140-
ChatMessageMultiPartInput = UserMessageMultiPartInput
142+
ChatMessageMultiPartInput = (
143+
UserMessageMultiPartInput | AssistantMultiPartInput | ToolCallResultMultiPartInput
144+
)
141145
AnyChatMessageInput = ChatMessageInput | ChatMessageMultiPartInput
142146

143147

@@ -251,9 +255,12 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
251255
if role == "user":
252256
messages = cast(AnyUserMessageInput, content)
253257
return self.add_user_message(messages)
258+
# Tool results accept multi-part content, so just forward it to that method
259+
if role == "tool":
260+
tool_results = cast(Iterable[ToolCallResultInput], content)
261+
return self.add_tool_results(tool_results)
254262
# Assistant responses consist of a text response with zero or more tool requests
255263
if role == "assistant":
256-
response: AssistantResponseInput
257264
if _is_chat_message_input(content):
258265
response = cast(AssistantResponseInput, content)
259266
return self.add_assistant_response(response)
@@ -263,7 +270,7 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
263270
raise LMStudioValueError(
264271
f"Unable to parse assistant response content: {content}"
265272
) from None
266-
response = response_content
273+
response = cast(AssistantResponseInput, response_content)
267274
tool_requests = cast(Iterable[ToolCallRequest], tool_request_contents)
268275
return self.add_assistant_response(response, tool_requests)
269276

@@ -276,17 +283,14 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
276283
content_item = content
277284
else:
278285
try:
279-
(content_item,) = content
286+
(content_item,) = cast(Iterable[ChatMessageInput], content)
280287
except ValueError:
281288
err_msg = f"{role!r} role does not support multi-part message content."
282289
raise LMStudioValueError(err_msg) from None
283290
match role:
284291
case "system":
285292
prompt = cast(SystemPromptInput, content_item)
286293
result = self.add_system_prompt(prompt)
287-
case "tool":
288-
tool_result = cast(ToolCallResultInput, content_item)
289-
result = self.add_tool_result(tool_result)
290294
case _:
291295
raise LMStudioValueError(f"Unknown history role: {role}")
292296
return result

tests/test_history.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
from lmstudio.sdk_api import LMStudioOSError
1111
from lmstudio.schemas import DictObject
1212
from lmstudio.history import (
13+
AnyChatMessageDict,
1314
AnyChatMessageInput,
15+
AssistantMultiPartInput,
1416
Chat,
15-
AnyChatMessageDict,
1617
ChatHistoryData,
1718
ChatHistoryDataDict,
1819
LocalFileInput,
@@ -29,6 +30,10 @@
2930
LlmPredictionStats,
3031
PredictionResult,
3132
)
33+
from lmstudio._sdk_models import (
34+
ToolCallRequestDataDict,
35+
ToolCallResultDataDict,
36+
)
3237

3338
from .support import IMAGE_FILEPATH, check_sdk_error
3439

@@ -125,6 +130,51 @@
125130
"role": "system",
126131
"content": [{"type": "text", "text": "Structured text system prompt"}],
127132
},
133+
{
134+
"role": "assistant",
135+
"content": [
136+
{"type": "text", "text": "Example tool call request"},
137+
{
138+
"type": "toolCallRequest",
139+
"toolCallRequest": {
140+
"type": "function",
141+
"id": "114663647",
142+
"name": "example_tool_name",
143+
"arguments": {
144+
"n": 58013,
145+
"t": "value",
146+
},
147+
},
148+
},
149+
{
150+
"type": "toolCallRequest",
151+
"toolCallRequest": {
152+
"type": "function",
153+
"id": "114663648",
154+
"name": "another_example_tool_name",
155+
"arguments": {
156+
"n": 23,
157+
"t": "some other value",
158+
},
159+
},
160+
},
161+
],
162+
},
163+
{
164+
"role": "tool",
165+
"content": [
166+
{
167+
"type": "toolCallResult",
168+
"toolCallId": "114663647",
169+
"content": "example tool call result",
170+
},
171+
{
172+
"type": "toolCallResult",
173+
"toolCallId": "114663648",
174+
"content": "another example tool call result",
175+
},
176+
],
177+
},
128178
]
129179

130180
INPUT_HISTORY = {"messages": INPUT_ENTRIES}
@@ -214,6 +264,51 @@
214264
"role": "system",
215265
"content": [{"type": "text", "text": "Structured text system prompt"}],
216266
},
267+
{
268+
"role": "assistant",
269+
"content": [
270+
{"type": "text", "text": "Example tool call request"},
271+
{
272+
"type": "toolCallRequest",
273+
"toolCallRequest": {
274+
"type": "function",
275+
"id": "114663647",
276+
"name": "example_tool_name",
277+
"arguments": {
278+
"n": 58013,
279+
"t": "value",
280+
},
281+
},
282+
},
283+
{
284+
"type": "toolCallRequest",
285+
"toolCallRequest": {
286+
"type": "function",
287+
"id": "114663648",
288+
"name": "another_example_tool_name",
289+
"arguments": {
290+
"n": 23,
291+
"t": "some other value",
292+
},
293+
},
294+
},
295+
],
296+
},
297+
{
298+
"role": "tool",
299+
"content": [
300+
{
301+
"type": "toolCallResult",
302+
"toolCallId": "114663647",
303+
"content": "example tool call result",
304+
},
305+
{
306+
"type": "toolCallResult",
307+
"toolCallId": "114663648",
308+
"content": "another example tool call result",
309+
},
310+
],
311+
},
217312
]
218313

219314

@@ -271,6 +366,44 @@ def test_from_history_with_simple_text() -> None:
271366
"sizeBytes": 100,
272367
"fileType": "text/plain",
273368
}
369+
INPUT_TOOL_REQUESTS: list[ToolCallRequestDataDict] = [
370+
{
371+
"type": "toolCallRequest",
372+
"toolCallRequest": {
373+
"type": "function",
374+
"id": "114663647",
375+
"name": "example_tool_name",
376+
"arguments": {
377+
"n": 58013,
378+
"t": "value",
379+
},
380+
},
381+
},
382+
{
383+
"type": "toolCallRequest",
384+
"toolCallRequest": {
385+
"type": "function",
386+
"id": "114663648",
387+
"name": "another_example_tool_name",
388+
"arguments": {
389+
"n": 23,
390+
"t": "some other value",
391+
},
392+
},
393+
},
394+
]
395+
INPUT_TOOL_RESULTS: list[ToolCallResultDataDict] = [
396+
{
397+
"type": "toolCallResult",
398+
"toolCallId": "114663647",
399+
"content": "example tool call result",
400+
},
401+
{
402+
"type": "toolCallResult",
403+
"toolCallId": "114663648",
404+
"content": "another example tool call result",
405+
},
406+
]
274407

275408

276409
def test_get_history() -> None:
@@ -289,6 +422,8 @@ def test_get_history() -> None:
289422
chat.add_user_message("Avoid consecutive responses")
290423
chat.add_assistant_response(INPUT_FILE_HANDLE_DICT)
291424
chat.add_system_prompt(TextData(text="Structured text system prompt"))
425+
chat.add_assistant_response("Example tool call request", INPUT_TOOL_REQUESTS)
426+
chat.add_tool_results(INPUT_TOOL_RESULTS)
292427
assert chat._get_history_for_prediction() == EXPECTED_HISTORY
293428

294429

@@ -307,6 +442,19 @@ def test_add_entry() -> None:
307442
chat.add_entry("user", "Avoid consecutive responses")
308443
chat.add_entry("assistant", INPUT_FILE_HANDLE_DICT)
309444
chat.add_entry("system", TextData(text="Structured text system prompt"))
445+
tool_call_message_contents: AssistantMultiPartInput = [
446+
"Example tool call request",
447+
*INPUT_TOOL_REQUESTS,
448+
]
449+
chat.add_entry("assistant", tool_call_message_contents)
450+
chat.add_entry("tool", INPUT_TOOL_RESULTS)
451+
assert chat._get_history_for_prediction() == EXPECTED_HISTORY
452+
453+
454+
def test_append() -> None:
455+
chat = Chat()
456+
for message in INPUT_ENTRIES:
457+
chat.append(cast(AnyChatMessageDict, message))
310458
assert chat._get_history_for_prediction() == EXPECTED_HISTORY
311459

312460

0 commit comments

Comments
 (0)