Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/lmstudio/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,13 @@ def _to_history_content(self) -> str:
| ToolCallRequestData
| ToolCallRequestDataDict
)
AssistantMultiPartInput = Iterable[AssistantResponseInput | ToolCallRequestInput]
ToolCallResultInput = ToolCallResultData | ToolCallResultDataDict
ToolCallResultMultiPartInput = Iterable[ToolCallResultInput]
ChatMessageInput = str | ChatMessageContent | ChatMessageContentDict
ChatMessageMultiPartInput = UserMessageMultiPartInput
ChatMessageMultiPartInput = (
UserMessageMultiPartInput | AssistantMultiPartInput | ToolCallResultMultiPartInput
)
AnyChatMessageInput = ChatMessageInput | ChatMessageMultiPartInput


Expand Down Expand Up @@ -251,9 +255,12 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
if role == "user":
messages = cast(AnyUserMessageInput, content)
return self.add_user_message(messages)
# Tool results accept multi-part content, so just forward it to that method
if role == "tool":
tool_results = cast(Iterable[ToolCallResultInput], content)
return self.add_tool_results(tool_results)
# Assistant responses consist of a text response with zero or more tool requests
if role == "assistant":
response: AssistantResponseInput
if _is_chat_message_input(content):
response = cast(AssistantResponseInput, content)
return self.add_assistant_response(response)
Expand All @@ -263,7 +270,7 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
raise LMStudioValueError(
f"Unable to parse assistant response content: {content}"
) from None
response = response_content
response = cast(AssistantResponseInput, response_content)
tool_requests = cast(Iterable[ToolCallRequest], tool_request_contents)
return self.add_assistant_response(response, tool_requests)

Expand All @@ -276,17 +283,14 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
content_item = content
else:
try:
(content_item,) = content
(content_item,) = cast(Iterable[ChatMessageInput], content)
except ValueError:
err_msg = f"{role!r} role does not support multi-part message content."
raise LMStudioValueError(err_msg) from None
match role:
case "system":
prompt = cast(SystemPromptInput, content_item)
result = self.add_system_prompt(prompt)
case "tool":
tool_result = cast(ToolCallResultInput, content_item)
result = self.add_tool_result(tool_result)
case _:
raise LMStudioValueError(f"Unknown history role: {role}")
return result
Expand Down
150 changes: 149 additions & 1 deletion tests/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from lmstudio.sdk_api import LMStudioOSError
from lmstudio.schemas import DictObject
from lmstudio.history import (
AnyChatMessageDict,
AnyChatMessageInput,
AssistantMultiPartInput,
Chat,
AnyChatMessageDict,
ChatHistoryData,
ChatHistoryDataDict,
LocalFileInput,
Expand All @@ -29,6 +30,10 @@
LlmPredictionStats,
PredictionResult,
)
from lmstudio._sdk_models import (
ToolCallRequestDataDict,
ToolCallResultDataDict,
)

from .support import IMAGE_FILEPATH, check_sdk_error

Expand Down Expand Up @@ -125,6 +130,51 @@
"role": "system",
"content": [{"type": "text", "text": "Structured text system prompt"}],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Example tool call request"},
{
"type": "toolCallRequest",
"toolCallRequest": {
"type": "function",
"id": "114663647",
"name": "example_tool_name",
"arguments": {
"n": 58013,
"t": "value",
},
},
},
{
"type": "toolCallRequest",
"toolCallRequest": {
"type": "function",
"id": "114663648",
"name": "another_example_tool_name",
"arguments": {
"n": 23,
"t": "some other value",
},
},
},
],
},
{
"role": "tool",
"content": [
{
"type": "toolCallResult",
"toolCallId": "114663647",
"content": "example tool call result",
},
{
"type": "toolCallResult",
"toolCallId": "114663648",
"content": "another example tool call result",
},
],
},
]

INPUT_HISTORY = {"messages": INPUT_ENTRIES}
Expand Down Expand Up @@ -214,6 +264,51 @@
"role": "system",
"content": [{"type": "text", "text": "Structured text system prompt"}],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Example tool call request"},
{
"type": "toolCallRequest",
"toolCallRequest": {
"type": "function",
"id": "114663647",
"name": "example_tool_name",
"arguments": {
"n": 58013,
"t": "value",
},
},
},
{
"type": "toolCallRequest",
"toolCallRequest": {
"type": "function",
"id": "114663648",
"name": "another_example_tool_name",
"arguments": {
"n": 23,
"t": "some other value",
},
},
},
],
},
{
"role": "tool",
"content": [
{
"type": "toolCallResult",
"toolCallId": "114663647",
"content": "example tool call result",
},
{
"type": "toolCallResult",
"toolCallId": "114663648",
"content": "another example tool call result",
},
],
},
]


Expand Down Expand Up @@ -271,6 +366,44 @@ def test_from_history_with_simple_text() -> None:
"sizeBytes": 100,
"fileType": "text/plain",
}
INPUT_TOOL_REQUESTS: list[ToolCallRequestDataDict] = [
{
"type": "toolCallRequest",
"toolCallRequest": {
"type": "function",
"id": "114663647",
"name": "example_tool_name",
"arguments": {
"n": 58013,
"t": "value",
},
},
},
{
"type": "toolCallRequest",
"toolCallRequest": {
"type": "function",
"id": "114663648",
"name": "another_example_tool_name",
"arguments": {
"n": 23,
"t": "some other value",
},
},
},
]
INPUT_TOOL_RESULTS: list[ToolCallResultDataDict] = [
{
"type": "toolCallResult",
"toolCallId": "114663647",
"content": "example tool call result",
},
{
"type": "toolCallResult",
"toolCallId": "114663648",
"content": "another example tool call result",
},
]


def test_get_history() -> None:
Expand All @@ -289,6 +422,8 @@ def test_get_history() -> None:
chat.add_user_message("Avoid consecutive responses")
chat.add_assistant_response(INPUT_FILE_HANDLE_DICT)
chat.add_system_prompt(TextData(text="Structured text system prompt"))
chat.add_assistant_response("Example tool call request", INPUT_TOOL_REQUESTS)
chat.add_tool_results(INPUT_TOOL_RESULTS)
assert chat._get_history_for_prediction() == EXPECTED_HISTORY


Expand All @@ -307,6 +442,19 @@ def test_add_entry() -> None:
chat.add_entry("user", "Avoid consecutive responses")
chat.add_entry("assistant", INPUT_FILE_HANDLE_DICT)
chat.add_entry("system", TextData(text="Structured text system prompt"))
tool_call_message_contents: AssistantMultiPartInput = [
"Example tool call request",
*INPUT_TOOL_REQUESTS,
]
chat.add_entry("assistant", tool_call_message_contents)
chat.add_entry("tool", INPUT_TOOL_RESULTS)
assert chat._get_history_for_prediction() == EXPECTED_HISTORY


def test_append() -> None:
chat = Chat()
for message in INPUT_ENTRIES:
chat.append(cast(AnyChatMessageDict, message))
assert chat._get_history_for_prediction() == EXPECTED_HISTORY


Expand Down