4848 ChatMessagePartFileDataDict as _FileHandleDict ,
4949 ChatMessagePartTextData as TextData ,
5050 ChatMessagePartTextDataDict as TextDataDict ,
51- ChatMessagePartToolCallRequestData as _ToolCallRequestData ,
52- ChatMessagePartToolCallRequestDataDict as _ToolCallRequestDataDict ,
53- ChatMessagePartToolCallResultData as _ToolCallResultData ,
54- ChatMessagePartToolCallResultDataDict as _ToolCallResultDataDict ,
51+ ChatMessagePartToolCallRequestData as ToolCallRequestData ,
52+ ChatMessagePartToolCallRequestDataDict as ToolCallRequestDataDict ,
53+ ChatMessagePartToolCallResultData as ToolCallResultData ,
54+ ChatMessagePartToolCallResultDataDict as ToolCallResultDataDict ,
5555 # Private until LM Studio file handle support stabilizes
5656 # FileType,
5757 FilesRpcUploadFileBase64Parameter ,
58- # Private until user level tool call request management is defined
59- ToolCallRequest as _ToolCallRequest ,
58+ ToolCallRequest as ToolCallRequest ,
59+ FunctionToolCallRequestDict as ToolCallRequestDict ,
6060)
6161
6262__all__ = [
8181 "TextData" ,
8282 "TextDataDict" ,
8383 # Private until user level tool call request management is defined
84- "_ToolCallRequest" , # Other modules need this to be exported
85- "_ToolCallResultData" , # Other modules need this to be exported
84+ "ToolCallRequest" ,
85+ "ToolCallResultData" ,
8686 # "ToolCallRequest",
8787 # "ToolCallResult",
8888 "UserMessageContent" ,
109109SystemPromptContentDict = TextDataDict
110110UserMessageContent = TextData | _FileHandle
111111UserMessageContentDict = TextDataDict | _FileHandleDict
112- AssistantResponseContent = TextData | _FileHandle | _ToolCallRequestData
113- AssistantResponseContentDict = TextDataDict | _FileHandleDict | _ToolCallRequestDataDict
114- ChatMessageContent = TextData | _FileHandle | _ToolCallRequestData | _ToolCallResultData
112+ AssistantResponseContent = TextData | _FileHandle
113+ AssistantResponseContentDict = TextDataDict | _FileHandleDict
114+ ChatMessageContent = TextData | _FileHandle | ToolCallRequestData | ToolCallResultData
115115ChatMessageContentDict = (
116- TextDataDict | _FileHandleDict | _ToolCallRequestData | _ToolCallResultDataDict
116+ TextDataDict | _FileHandleDict | ToolCallRequestData | ToolCallResultDataDict
117117)
118118
119119
@@ -132,7 +132,13 @@ def _to_history_content(self) -> str:
132132AnyUserMessageInput = UserMessageInput | UserMessageMultiPartInput
133133AssistantResponseInput = str | AssistantResponseContent | AssistantResponseContentDict
134134AnyAssistantResponseInput = AssistantResponseInput | _ServerAssistantResponse
135- _ToolCallResultInput = _ToolCallResultData | _ToolCallResultDataDict
135+ ToolCallRequestInput = (
136+ ToolCallRequest
137+ | ToolCallRequestDict
138+ | ToolCallRequestData
139+ | ToolCallRequestDataDict
140+ )
141+ ToolCallResultInput = ToolCallResultData | ToolCallResultDataDict
136142ChatMessageInput = str | ChatMessageContent | ChatMessageContentDict
137143ChatMessageMultiPartInput = UserMessageMultiPartInput
138144AnyChatMessageInput = ChatMessageInput | ChatMessageMultiPartInput
@@ -355,6 +361,21 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
355361 if role == "user" :
356362 messages = cast (AnyUserMessageInput , content )
357363 return self .add_user_message (messages )
364+ # Assistant responses consist of a text response with zero or more tool requests
365+ if role == "assistant" :
366+ if _is_chat_message_input (content ):
367+ response = cast (AssistantResponseInput , content )
368+ return self .add_assistant_response (response )
369+ try :
370+ (response_content , * tool_request_contents ) = content
371+ except ValueError :
372+ raise LMStudioValueError (
373+ f"Unable to parse assistant response content: { content } "
374+ ) from None
375+ response = cast (AssistantResponseInput , response_content )
376+ tool_requests = cast (Iterable [ToolCallRequest ], tool_request_contents )
377+ return self .add_assistant_response (response , tool_requests )
378+
358379 # Other roles do not accept multi-part messages, so ensure there
359380 # is exactly one content item given. We still accept iterables because
360381 # that's how the wire format is defined and we want to accept that.
@@ -368,17 +389,13 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
368389 except ValueError :
369390 err_msg = f"{ role !r} role does not support multi-part message content."
370391 raise LMStudioValueError (err_msg ) from None
371-
372392 match role :
373393 case "system" :
374394 prompt = cast (SystemPromptInput , content_item )
375395 result = self .add_system_prompt (prompt )
376- case "assistant" :
377- response = cast (AssistantResponseInput , content_item )
378- result = self .add_assistant_response (response )
379396 case "tool" :
380- tool_result = cast (_ToolCallResultInput , content_item )
381- result = self ._add_tool_result (tool_result )
397+ tool_result = cast (ToolCallResultInput , content_item )
398+ result = self .add_tool_result (tool_result )
382399 case _:
383400 raise LMStudioValueError (f"Unknown history role: { role } " )
384401 return result
@@ -556,11 +573,14 @@ def add_user_message(
556573 @classmethod
557574 def _parse_assistant_response (
558575 cls , response : AnyAssistantResponseInput
559- ) -> AssistantResponseContent :
576+ ) -> TextData | _FileHandle :
577+ # Note: tool call requests are NOT accepted here, as they're expected
578+ # to follow an initial text response
579+ # It's not clear if file handles should be accepted as it's not obvious
580+ # how client applications should process those (even though the API
581+ # format nominally permits them here)
560582 match response :
561- # Sadly, we can't use the union type aliases for matching,
562- # since the compiler needs visibility into every match target
563- case TextData () | _FileHandle () | _ToolCallRequestData ():
583+ case TextData () | _FileHandle ():
564584 return response
565585 case str ():
566586 return TextData (text = response )
@@ -575,59 +595,67 @@ def _parse_assistant_response(
575595 }:
576596 # We accept snake_case here for consistency, but don't really expect it
577597 return _FileHandle ._from_any_dict (response )
578- case {"toolCallRequest" : [* _]} | {"tool_call_request" : [* _]}:
579- # We accept snake_case here for consistency, but don't really expect it
580- return _ToolCallRequestData ._from_any_dict (response )
581598 case _:
582599 raise LMStudioValueError (
583600 f"Unable to parse assistant response content: { response } "
584601 )
585602
603+ @classmethod
604+ def _parse_tool_call_request (
605+ cls , request : ToolCallRequestInput
606+ ) -> ToolCallRequestData :
607+ match request :
608+ case ToolCallRequestData ():
609+ return request
610+ case ToolCallRequest ():
611+ return ToolCallRequestData (tool_call_request = request )
612+ case {"type" : "toolCallRequest" }:
613+ return ToolCallRequestData ._from_any_dict (request )
614+ case {"toolCallRequest" : [* _]} | {"tool_call_request" : [* _]}:
615+ request_details = ToolCallRequest ._from_any_dict (request )
616+ return ToolCallRequestData (tool_call_request = request_details )
617+ case _:
618+ raise LMStudioValueError (
619+ f"Unable to parse tool call request content: { request } "
620+ )
621+
586622 @sdk_public_api ()
587623 def add_assistant_response (
588- self , response : AnyAssistantResponseInput
624+ self ,
625+ response : AnyAssistantResponseInput ,
626+ tool_call_requests : Iterable [ToolCallRequestInput ] = (),
589627 ) -> AssistantResponse :
590628 """Add a new 'assistant' response to the chat history."""
591- self ._raise_if_consecutive (AssistantResponse .role , "assistant responses" )
592- message_data = self ._parse_assistant_response (response )
593- message = AssistantResponse (content = [message_data ])
594- self ._messages .append (message )
595- return message
596-
597- def _add_assistant_tool_requests (
598- self , response : _ServerAssistantResponse , requests : Iterable [_ToolCallRequest ]
599- ) -> AssistantResponse :
600629 self ._raise_if_consecutive (AssistantResponse .role , "assistant responses" )
601630 message_text = self ._parse_assistant_response (response )
602631 request_parts = [
603- _ToolCallRequestData ( tool_call_request = req ) for req in requests
632+ self . _parse_tool_call_request ( req ) for req in tool_call_requests
604633 ]
605634 message = AssistantResponse (content = [message_text , * request_parts ])
606635 self ._messages .append (message )
607636 return message
608637
609638 @classmethod
610- def _parse_tool_result (cls , result : _ToolCallResultInput ) -> _ToolCallResultData :
639+ def _parse_tool_result (cls , result : ToolCallResultInput ) -> ToolCallResultData :
611640 match result :
612- # Sadly, we can't use the union type aliases for matching,
613- # since the compiler needs visibility into every match target
614- case _ToolCallResultData ():
641+ case ToolCallResultData ():
615642 return result
616643 case {"toolCallId" : _, "content" : _} | {"tool_call_id" : _, "content" : _}:
617644 # We accept snake_case here for consistency, but don't really expect it
618- return _ToolCallResultData .from_dict (result )
645+ return ToolCallResultData .from_dict (result )
619646 case _:
620647 raise LMStudioValueError (f"Unable to parse tool result: { result } " )
621648
622- def _add_tool_results (
623- self , results : Iterable [_ToolCallResultInput ]
649+ def add_tool_results (
650+ self , results : Iterable [ToolCallResultInput ]
624651 ) -> ToolResultMessage :
652+ """Add multiple tool results to the chat history as a single message."""
625653 message_content = [self ._parse_tool_result (result ) for result in results ]
626654 message = ToolResultMessage (content = message_content )
627655 self ._messages .append (message )
628656 return message
629657
630- def _add_tool_result (self , result : _ToolCallResultInput ) -> ToolResultMessage :
658+ def add_tool_result (self , result : ToolCallResultInput ) -> ToolResultMessage :
631659 """Add a new tool result to the chat history."""
632660 # Consecutive tool result messages are allowed,
633661 # so skip checking if the last message was a tool result
0 commit comments