6 changes: 4 additions & 2 deletions examples/tool-use-multiple.py
@@ -18,9 +18,11 @@ def is_prime(n: int) -> bool:
             return False
     return True
 
-model = lms.llm("qwen2.5-7b-instruct")
+chat = lms.Chat()
+model = lms.llm("qwen2.5-7b-instruct-1m")
 model.act(
     "Is the result of 12345 + 45668 a prime? Think step by step.",
     [add, is_prime],
-    on_message=print,
+    on_message=chat.append,
 )
+print(chat)
6 changes: 4 additions & 2 deletions examples/tool-use.py
@@ -7,9 +7,11 @@ def multiply(a: float, b: float) -> float:
     """Given two numbers a and b. Returns the product of them."""
     return a * b
 
-model = lms.llm("qwen2.5-7b-instruct")
+chat = lms.Chat()
+model = lms.llm("qwen2.5-7b-instruct-1m")
 model.act(
     "What is the result of 12345 multiplied by 54321?",
     [multiply],
-    on_message=print,
+    on_message=chat.append,
 )
+print(chat)
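
Note: both example scripts switch from printing each message as it arrives to recording the full tool-use round trip in a Chat object. Chat.append accepts the messages that act() emits, so the bound method can be passed directly as the on_message callback. A minimal sketch of what the recorded history then enables; the follow-up prompt and the respond() call are illustrative assumptions, not part of this diff:

import lmstudio as lms

def multiply(a: float, b: float) -> float:
    """Given two numbers a and b. Returns the product of them."""
    return a * b

chat = lms.Chat()
model = lms.llm("qwen2.5-7b-instruct-1m")
model.act(
    "What is the result of 12345 multiplied by 54321?",
    [multiply],
    on_message=chat.append,  # record assistant/tool messages as they complete
)
# The chat now contains the tool call request and its result, so a
# follow-up turn can build on the recorded exchange
chat.add_user_message("Now explain how you got that answer.")
print(model.respond(chat))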
118 changes: 73 additions & 45 deletions src/lmstudio/history.py
@@ -48,15 +48,15 @@
     ChatMessagePartFileDataDict as _FileHandleDict,
     ChatMessagePartTextData as TextData,
     ChatMessagePartTextDataDict as TextDataDict,
-    ChatMessagePartToolCallRequestData as _ToolCallRequestData,
-    ChatMessagePartToolCallRequestDataDict as _ToolCallRequestDataDict,
-    ChatMessagePartToolCallResultData as _ToolCallResultData,
-    ChatMessagePartToolCallResultDataDict as _ToolCallResultDataDict,
+    ChatMessagePartToolCallRequestData as ToolCallRequestData,
+    ChatMessagePartToolCallRequestDataDict as ToolCallRequestDataDict,
+    ChatMessagePartToolCallResultData as ToolCallResultData,
+    ChatMessagePartToolCallResultDataDict as ToolCallResultDataDict,
     # Private until LM Studio file handle support stabilizes
     # FileType,
     FilesRpcUploadFileBase64Parameter,
-    # Private until user level tool call request management is defined
-    ToolCallRequest as _ToolCallRequest,
+    ToolCallRequest as ToolCallRequest,
+    FunctionToolCallRequestDict as ToolCallRequestDict,
 )
 
 __all__ = [
@@ -81,8 +81,8 @@
     "TextData",
     "TextDataDict",
     # Private until user level tool call request management is defined
-    "_ToolCallRequest",  # Other modules need this to be exported
-    "_ToolCallResultData",  # Other modules need this to be exported
+    "ToolCallRequest",
+    "ToolCallResultData",
     # "ToolCallRequest",
     # "ToolCallResult",
     "UserMessageContent",
@@ -109,11 +109,11 @@
 SystemPromptContentDict = TextDataDict
 UserMessageContent = TextData | _FileHandle
 UserMessageContentDict = TextDataDict | _FileHandleDict
-AssistantResponseContent = TextData | _FileHandle | _ToolCallRequestData
-AssistantResponseContentDict = TextDataDict | _FileHandleDict | _ToolCallRequestDataDict
-ChatMessageContent = TextData | _FileHandle | _ToolCallRequestData | _ToolCallResultData
+AssistantResponseContent = TextData | _FileHandle
+AssistantResponseContentDict = TextDataDict | _FileHandleDict
+ChatMessageContent = TextData | _FileHandle | ToolCallRequestData | ToolCallResultData
 ChatMessageContentDict = (
-    TextDataDict | _FileHandleDict | _ToolCallRequestData | _ToolCallResultDataDict
+    TextDataDict | _FileHandleDict | ToolCallRequestDataDict | ToolCallResultDataDict
 )
 
 
@@ -132,7 +132,13 @@ def _to_history_content(self) -> str:
 AnyUserMessageInput = UserMessageInput | UserMessageMultiPartInput
 AssistantResponseInput = str | AssistantResponseContent | AssistantResponseContentDict
 AnyAssistantResponseInput = AssistantResponseInput | _ServerAssistantResponse
-_ToolCallResultInput = _ToolCallResultData | _ToolCallResultDataDict
+ToolCallRequestInput = (
+    ToolCallRequest
+    | ToolCallRequestDict
+    | ToolCallRequestData
+    | ToolCallRequestDataDict
+)
+ToolCallResultInput = ToolCallResultData | ToolCallResultDataDict
 ChatMessageInput = str | ChatMessageContent | ChatMessageContentDict
 ChatMessageMultiPartInput = UserMessageMultiPartInput
 AnyChatMessageInput = ChatMessageInput | ChatMessageMultiPartInput
@@ -355,6 +361,21 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
         if role == "user":
             messages = cast(AnyUserMessageInput, content)
             return self.add_user_message(messages)
+        # Assistant responses consist of a text response with zero or more tool requests
+        if role == "assistant":
+            if _is_chat_message_input(content):
+                response = cast(AssistantResponseInput, content)
+                return self.add_assistant_response(response)
+            try:
+                (response_content, *tool_request_contents) = content
+            except ValueError:
+                raise LMStudioValueError(
+                    f"Unable to parse assistant response content: {content}"
+                ) from None
+            response = cast(AssistantResponseInput, response_content)
+            tool_requests = cast(Iterable[ToolCallRequest], tool_request_contents)
+            return self.add_assistant_response(response, tool_requests)
+
         # Other roles do not accept multi-part messages, so ensure there
         # is exactly one content item given. We still accept iterables because
         # that's how the wire format is defined and we want to accept that.
@@ -368,17 +389,13 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
         except ValueError:
             err_msg = f"{role!r} role does not support multi-part message content."
             raise LMStudioValueError(err_msg) from None
 
         match role:
             case "system":
                 prompt = cast(SystemPromptInput, content_item)
                 result = self.add_system_prompt(prompt)
-            case "assistant":
-                response = cast(AssistantResponseInput, content_item)
-                result = self.add_assistant_response(response)
             case "tool":
-                tool_result = cast(_ToolCallResultInput, content_item)
-                result = self._add_tool_result(tool_result)
+                tool_result = cast(ToolCallResultInput, content_item)
+                result = self.add_tool_result(tool_result)
             case _:
                 raise LMStudioValueError(f"Unknown history role: {role}")
         return result
@@ -556,11 +573,14 @@ def add_user_message(
     @classmethod
     def _parse_assistant_response(
         cls, response: AnyAssistantResponseInput
-    ) -> AssistantResponseContent:
+    ) -> TextData | _FileHandle:
+        # Note: tool call requests are NOT accepted here, as they're expected
+        # to follow an initial text response
+        # It's not clear if file handles should be accepted as it's not obvious
+        # how client applications should process those (even though the API
+        # format nominally permits them here)
         match response:
-            # Sadly, we can't use the union type aliases for matching,
-            # since the compiler needs visibility into every match target
-            case TextData() | _FileHandle() | _ToolCallRequestData():
+            case TextData() | _FileHandle():
                 return response
             case str():
                 return TextData(text=response)
@@ -575,59 +595,67 @@ def _parse_assistant_response(
             }:
                 # We accept snake_case here for consistency, but don't really expect it
                 return _FileHandle._from_any_dict(response)
-            case {"toolCallRequest": [*_]} | {"tool_call_request": [*_]}:
-                # We accept snake_case here for consistency, but don't really expect it
-                return _ToolCallRequestData._from_any_dict(response)
             case _:
                 raise LMStudioValueError(
                     f"Unable to parse assistant response content: {response}"
                 )
 
+    @classmethod
+    def _parse_tool_call_request(
+        cls, request: ToolCallRequestInput
+    ) -> ToolCallRequestData:
+        match request:
+            case ToolCallRequestData():
+                return request
+            case ToolCallRequest():
+                return ToolCallRequestData(tool_call_request=request)
+            case {"type": "toolCallRequest"}:
+                return ToolCallRequestData._from_any_dict(request)
+            case {"toolCallRequest": [*_]} | {"tool_call_request": [*_]}:
+                request_details = ToolCallRequest._from_any_dict(request)
+                return ToolCallRequestData(tool_call_request=request_details)
+            case _:
+                raise LMStudioValueError(
+                    f"Unable to parse tool call request content: {request}"
+                )
+
     @sdk_public_api()
     def add_assistant_response(
-        self, response: AnyAssistantResponseInput
+        self,
+        response: AnyAssistantResponseInput,
+        tool_call_requests: Iterable[ToolCallRequestInput] = (),
     ) -> AssistantResponse:
         """Add a new 'assistant' response to the chat history."""
         self._raise_if_consecutive(AssistantResponse.role, "assistant responses")
-        message_data = self._parse_assistant_response(response)
-        message = AssistantResponse(content=[message_data])
-        self._messages.append(message)
-        return message
-
-    def _add_assistant_tool_requests(
-        self, response: _ServerAssistantResponse, requests: Iterable[_ToolCallRequest]
-    ) -> AssistantResponse:
-        self._raise_if_consecutive(AssistantResponse.role, "assistant responses")
         message_text = self._parse_assistant_response(response)
         request_parts = [
-            _ToolCallRequestData(tool_call_request=req) for req in requests
+            self._parse_tool_call_request(req) for req in tool_call_requests
         ]
         message = AssistantResponse(content=[message_text, *request_parts])
         self._messages.append(message)
         return message
 
     @classmethod
-    def _parse_tool_result(cls, result: _ToolCallResultInput) -> _ToolCallResultData:
+    def _parse_tool_result(cls, result: ToolCallResultInput) -> ToolCallResultData:
         match result:
             # Sadly, we can't use the union type aliases for matching,
             # since the compiler needs visibility into every match target
-            case _ToolCallResultData():
+            case ToolCallResultData():
                 return result
             case {"toolCallId": _, "content": _} | {"tool_call_id": _, "content": _}:
                 # We accept snake_case here for consistency, but don't really expect it
-                return _ToolCallResultData.from_dict(result)
+                return ToolCallResultData.from_dict(result)
             case _:
                 raise LMStudioValueError(f"Unable to parse tool result: {result}")
 
-    def _add_tool_results(
-        self, results: Iterable[_ToolCallResultInput]
+    def add_tool_results(
+        self, results: Iterable[ToolCallResultInput]
     ) -> ToolResultMessage:
         """Add multiple tool results to the chat history as a single message."""
         message_content = [self._parse_tool_result(result) for result in results]
         message = ToolResultMessage(content=message_content)
         self._messages.append(message)
         return message
 
-    def _add_tool_result(self, result: _ToolCallResultInput) -> ToolResultMessage:
+    def add_tool_result(self, result: ToolCallResultInput) -> ToolResultMessage:
         """Add a new tool result to the chat history."""
         # Consecutive tool result messages are allowed,
         # so skip checking if the last message was a tool result
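
Note: with the renames above, tool call handling becomes part of the public history API: add_assistant_response() gains an optional iterable of tool call requests, and _add_tool_result()/_add_tool_results() lose their leading underscores. A minimal sketch of replaying a tool-use exchange into a Chat by hand; the call id, tool name, and arguments are hypothetical values for illustration, and the ToolCallRequest field names are assumed from the diff's use of request.id, request.name, and request.arguments:

from lmstudio import Chat
from lmstudio.history import ToolCallRequest, ToolCallResultData

chat = Chat()
chat.add_user_message("What is 12345 multiplied by 54321?")
# An assistant turn may now carry zero or more tool call requests
# alongside its initial text content
request = ToolCallRequest(
    id="call_1",      # hypothetical tool call id
    name="multiply",  # hypothetical tool name
    arguments={"a": 12345, "b": 54321},
)
chat.add_assistant_response("Let me work that out.", [request])
# The matching result is recorded via the now-public add_tool_result()
chat.add_tool_result(
    ToolCallResultData(content="670592745", tool_call_id="call_1")
)

The same content can also flow through add_entry("assistant", ...) with the response text as the first item and the requests following it, per the new branch added to add_entry above.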
37 changes: 24 additions & 13 deletions src/lmstudio/json_api.py
@@ -40,7 +40,7 @@
     sdk_public_type,
     _truncate_traceback,
 )
-from .history import AssistantResponse, Chat, _ToolCallRequest, _ToolCallResultData
+from .history import AssistantResponse, Chat, ToolCallRequest, ToolCallResultData
 from .schemas import (
     AnyLMStudioStruct,
     DictObject,
@@ -1067,7 +1067,7 @@ class PredictionFragmentEvent(ChannelRxEvent[LlmPredictionFragment]):
     pass
 
 
-class PredictionToolCallEvent(ChannelRxEvent[_ToolCallRequest]):
+class PredictionToolCallEvent(ChannelRxEvent[ToolCallRequest]):
     pass
 
 
@@ -1114,7 +1114,7 @@ def __init__(
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
         # The remaining options are only relevant for multi-round tool actions
         handle_invalid_tool_request: Callable[
-            [LMStudioPredictionError, _ToolCallRequest | None], str
+            [LMStudioPredictionError, ToolCallRequest | None], str
         ]
         | None = None,
         llm_tools: LlmToolUseSettingToolArray | None = None,
@@ -1224,7 +1224,7 @@ def iter_message_events(
                 "toolCallRequest": tool_call_request,
             }:
                 yield PredictionToolCallEvent(
-                    _ToolCallRequest._from_api_dict(tool_call_request)
+                    ToolCallRequest._from_api_dict(tool_call_request)
                 )
             case {
                 "type": "toolCallGenerationFailed",
@@ -1267,10 +1267,17 @@ def handle_rx_event(self, event: PredictionRxEvent) -> None:
                 self._report_prompt_processing_progress(progress)
             case PredictionFragmentEvent(_fragment):
                 if self._on_first_token is not None:
-                    self._on_first_token()
+                    self._logger.debug("Invoking on_first_token callback")
+                    err_msg = f"First token callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        self._on_first_token()
                     self._on_first_token = None
                 if self._on_prediction_fragment is not None:
-                    self._on_prediction_fragment(_fragment)
+                    # TODO: Define an even-spammier-than-debug trace logging level for this
+                    # self._logger.trace("Invoking on_prediction_fragment callback")
+                    err_msg = f"Prediction fragment callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        self._on_prediction_fragment(_fragment)
                 pass
             case PredictionToolCallEvent(_tool_call_request):
                 # Handled externally when iterating over events
Expand All @@ -1294,32 +1301,34 @@ def _report_prompt_processing_progress(self, progress: float) -> None:
assert self._on_prompt_processing_progress is not None
err_msg = f"Prediction progress callback failed for {self!r}"
with sdk_callback_invocation(err_msg, self._logger):
self._logger.debug("Invoking on_prompt_processing_progress callback")
self._on_prompt_processing_progress(progress)

def _handle_invalid_tool_request(
self, err_msg: str, request: _ToolCallRequest | None = None
self, err_msg: str, request: ToolCallRequest | None = None
) -> str:
exc = LMStudioPredictionError(err_msg)
_on_handle_invalid_tool_request = self._on_handle_invalid_tool_request
if _on_handle_invalid_tool_request is not None:
# Allow users to override the error message, or force an exception
self._logger.debug("Invoking on_handle_invalid_tool_request callback")
err_msg = _on_handle_invalid_tool_request(exc, request)
if request is not None:
return err_msg
# We don't allow users to prevent the exception when there's no request
raise LMStudioPredictionError(err_msg)

def request_tool_call(
self, request: _ToolCallRequest
) -> Callable[[], _ToolCallResultData]:
self, request: ToolCallRequest
) -> Callable[[], ToolCallResultData]:
tool_name = request.name
tool_call_id = request.id
client_tool = self._client_tools.get(tool_name, None)
if client_tool is None:
err_msg = self._handle_invalid_tool_request(
f"Cannot find tool with name {tool_name}.", request
)
result = _ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
return lambda: result
# Validate parameters against their specification
params_struct, implementation = client_tool
@@ -1330,14 +1339,14 @@ def request_tool_call(
             err_msg = self._handle_invalid_tool_request(
                 f"Failed to parse arguments for tool {tool_name}: {exc}", request
             )
-            result = _ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
+            result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
             return lambda: result
         kwds = to_builtins(parsed_kwds)
 
         # Allow caller to schedule the tool call request for background execution
-        def _call_requested_tool() -> _ToolCallResultData:
+        def _call_requested_tool() -> ToolCallResultData:
             call_result = implementation(**kwds)
-            return _ToolCallResultData(
+            return ToolCallResultData(
                 content=json.dumps(call_result), tool_call_id=tool_call_id
             )
 
@@ -1980,6 +1989,8 @@ def __init__(self, model_identifier: str, session: TSession) -> None:
         """Initialize the LM Studio model reference."""
        self.identifier = model_identifier
         self._session = session
+        self._logger = logger = get_logger(type(self).__name__)
+        logger.update_context(model_identifier=model_identifier)
 
     def __repr__(self) -> str:
         return f"{type(self).__name__}(identifier={self.identifier!r})"
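
Note: the json_api.py changes also route the on_first_token and on_prediction_fragment callbacks through sdk_callback_invocation, so a user callback that raises no longer breaks prediction processing. A rough sketch of the pattern, assuming the helper behaves like a log-and-suppress context manager (the actual implementation lives elsewhere in the SDK and may differ):

from contextlib import contextmanager
from typing import Iterator

@contextmanager
def sdk_callback_invocation(err_msg: str, logger) -> Iterator[None]:
    # Log a failing user callback instead of propagating the exception
    # into the SDK's message processing loop
    try:
        yield
    except Exception:
        logger.exception(err_msg)

The per-model logger set up in __init__ (get_logger plus update_context) is what makes the added debug messages traceable to a specific model_identifier.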