diff --git a/examples/tool-use-multiple.py b/examples/tool-use-multiple.py
index 3a1c351..64f69eb 100644
--- a/examples/tool-use-multiple.py
+++ b/examples/tool-use-multiple.py
@@ -18,9 +18,11 @@ def is_prime(n: int) -> bool:
             return False
     return True
 
-model = lms.llm("qwen2.5-7b-instruct")
+chat = lms.Chat()
+model = lms.llm("qwen2.5-7b-instruct-1m")
 model.act(
     "Is the result of 12345 + 45668 a prime? Think step by step.",
     [add, is_prime],
-    on_message=print,
+    on_message=chat.append,
 )
+print(chat)
diff --git a/examples/tool-use.py b/examples/tool-use.py
index f7a59d6..2c2bb36 100755
--- a/examples/tool-use.py
+++ b/examples/tool-use.py
@@ -7,9 +7,11 @@ def multiply(a: float, b: float) -> float:
     """Given two numbers a and b. Returns the product of them."""
     return a * b
 
-model = lms.llm("qwen2.5-7b-instruct")
+chat = lms.Chat()
+model = lms.llm("qwen2.5-7b-instruct-1m")
 model.act(
     "What is the result of 12345 multiplied by 54321?",
     [multiply],
-    on_message=print,
+    on_message=chat.append,
 )
+print(chat)
diff --git a/src/lmstudio/history.py b/src/lmstudio/history.py
index 7a1a6d7..cb77961 100644
--- a/src/lmstudio/history.py
+++ b/src/lmstudio/history.py
@@ -48,15 +48,15 @@
     ChatMessagePartFileDataDict as _FileHandleDict,
     ChatMessagePartTextData as TextData,
     ChatMessagePartTextDataDict as TextDataDict,
-    ChatMessagePartToolCallRequestData as _ToolCallRequestData,
-    ChatMessagePartToolCallRequestDataDict as _ToolCallRequestDataDict,
-    ChatMessagePartToolCallResultData as _ToolCallResultData,
-    ChatMessagePartToolCallResultDataDict as _ToolCallResultDataDict,
+    ChatMessagePartToolCallRequestData as ToolCallRequestData,
+    ChatMessagePartToolCallRequestDataDict as ToolCallRequestDataDict,
+    ChatMessagePartToolCallResultData as ToolCallResultData,
+    ChatMessagePartToolCallResultDataDict as ToolCallResultDataDict,
     # Private until LM Studio file handle support stabilizes
     # FileType,
     FilesRpcUploadFileBase64Parameter,
-    # Private until user level tool call request management is defined
-    ToolCallRequest as _ToolCallRequest,
+    ToolCallRequest as ToolCallRequest,
+    FunctionToolCallRequestDict as ToolCallRequestDict,
 )
 
 __all__ = [
@@ -81,8 +81,8 @@
     "TextData",
     "TextDataDict",
     # Private until user level tool call request management is defined
-    "_ToolCallRequest",  # Other modules need this to be exported
-    "_ToolCallResultData",  # Other modules need this to be exported
+    "ToolCallRequest",
+    "ToolCallResultData",
     # "ToolCallRequest",
     # "ToolCallResult",
     "UserMessageContent",
@@ -109,11 +109,11 @@
 SystemPromptContentDict = TextDataDict
 UserMessageContent = TextData | _FileHandle
 UserMessageContentDict = TextDataDict | _FileHandleDict
-AssistantResponseContent = TextData | _FileHandle | _ToolCallRequestData
-AssistantResponseContentDict = TextDataDict | _FileHandleDict | _ToolCallRequestDataDict
-ChatMessageContent = TextData | _FileHandle | _ToolCallRequestData | _ToolCallResultData
+AssistantResponseContent = TextData | _FileHandle
+AssistantResponseContentDict = TextDataDict | _FileHandleDict
+ChatMessageContent = TextData | _FileHandle | ToolCallRequestData | ToolCallResultData
 ChatMessageContentDict = (
-    TextDataDict | _FileHandleDict | _ToolCallRequestData | _ToolCallResultDataDict
+    TextDataDict | _FileHandleDict | ToolCallRequestData | ToolCallResultDataDict
 )
 
 
@@ -132,7 +132,13 @@ def _to_history_content(self) -> str:
 AnyUserMessageInput = UserMessageInput | UserMessageMultiPartInput
 AssistantResponseInput = str | AssistantResponseContent | AssistantResponseContentDict
 AnyAssistantResponseInput = AssistantResponseInput | _ServerAssistantResponse
-_ToolCallResultInput = _ToolCallResultData | _ToolCallResultDataDict
+ToolCallRequestInput = (
+    ToolCallRequest
+    | ToolCallRequestDict
+    | ToolCallRequestData
+    | ToolCallRequestDataDict
+)
+ToolCallResultInput = ToolCallResultData | ToolCallResultDataDict
 ChatMessageInput = str | ChatMessageContent | ChatMessageContentDict
 ChatMessageMultiPartInput = UserMessageMultiPartInput
 AnyChatMessageInput = ChatMessageInput | ChatMessageMultiPartInput
@@ -355,6 +361,21 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
         if role == "user":
             messages = cast(AnyUserMessageInput, content)
             return self.add_user_message(messages)
+        # Assistant responses consist of a text response with zero or more tool requests
+        if role == "assistant":
+            if _is_chat_message_input(content):
+                response = cast(AssistantResponseInput, content)
+                return self.add_assistant_response(response)
+            try:
+                (response_content, *tool_request_contents) = content
+            except ValueError:
+                raise LMStudioValueError(
+                    f"Unable to parse assistant response content: {content}"
+                ) from None
+            response = cast(AssistantResponseInput, response_content)
+            tool_requests = cast(Iterable[ToolCallRequest], tool_request_contents)
+            return self.add_assistant_response(response, tool_requests)
+
         # Other roles do not accept multi-part messages, so ensure there
         # is exactly one content item given. We still accept iterables because
         # that's how the wire format is defined and we want to accept that.
@@ -368,17 +389,13 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage:
         except ValueError:
             err_msg = f"{role!r} role does not support multi-part message content."
             raise LMStudioValueError(err_msg) from None
-
         match role:
             case "system":
                 prompt = cast(SystemPromptInput, content_item)
                 result = self.add_system_prompt(prompt)
-            case "assistant":
-                response = cast(AssistantResponseInput, content_item)
-                result = self.add_assistant_response(response)
             case "tool":
-                tool_result = cast(_ToolCallResultInput, content_item)
-                result = self._add_tool_result(tool_result)
+                tool_result = cast(ToolCallResultInput, content_item)
+                result = self.add_tool_result(tool_result)
             case _:
                 raise LMStudioValueError(f"Unknown history role: {role}")
         return result
@@ -556,11 +573,14 @@ def add_user_message(
     @classmethod
     def _parse_assistant_response(
         cls, response: AnyAssistantResponseInput
-    ) -> AssistantResponseContent:
+    ) -> TextData | _FileHandle:
+        # Note: tool call requests are NOT accepted here, as they're expected
+        # to follow an initial text response
+        # It's not clear if file handles should be accepted as it's not obvious
+        # how client applications should process those (even though the API
+        # format nominally permits them here)
         match response:
-            # Sadly, we can't use the union type aliases for matching,
-            # since the compiler needs visibility into every match target
-            case TextData() | _FileHandle() | _ToolCallRequestData():
+            case TextData() | _FileHandle():
                 return response
             case str():
                 return TextData(text=response)
@@ -575,59 +595,67 @@
             }:
                 # We accept snake_case here for consistency, but don't really expect it
                 return _FileHandle._from_any_dict(response)
-            case {"toolCallRequest": [*_]} | {"tool_call_request": [*_]}:
-                # We accept snake_case here for consistency, but don't really expect it
-                return _ToolCallRequestData._from_any_dict(response)
             case _:
                 raise LMStudioValueError(
                     f"Unable to parse assistant response content: {response}"
                 )
 
+    @classmethod
+    def _parse_tool_call_request(
+        cls, request: ToolCallRequestInput
+    ) -> ToolCallRequestData:
+        match request:
+            case ToolCallRequestData():
+                return request
+            case ToolCallRequest():
+                return ToolCallRequestData(tool_call_request=request)
+            case {"type": "toolCallRequest"}:
+                return ToolCallRequestData._from_any_dict(request)
+            case {"toolCallRequest": [*_]} | {"tool_call_request": [*_]}:
+                request_details = ToolCallRequest._from_any_dict(request)
+                return ToolCallRequestData(tool_call_request=request_details)
+            case _:
+                raise LMStudioValueError(
+                    f"Unable to parse tool call request content: {request}"
+                )
+
     @sdk_public_api()
     def add_assistant_response(
-        self, response: AnyAssistantResponseInput
+        self,
+        response: AnyAssistantResponseInput,
+        tool_call_requests: Iterable[ToolCallRequestInput] = (),
     ) -> AssistantResponse:
         """Add a new 'assistant' response to the chat history."""
-        self._raise_if_consecutive(AssistantResponse.role, "assistant responses")
-        message_data = self._parse_assistant_response(response)
-        message = AssistantResponse(content=[message_data])
-        self._messages.append(message)
-        return message
-
-    def _add_assistant_tool_requests(
-        self, response: _ServerAssistantResponse, requests: Iterable[_ToolCallRequest]
-    ) -> AssistantResponse:
         self._raise_if_consecutive(AssistantResponse.role, "assistant responses")
         message_text = self._parse_assistant_response(response)
         request_parts = [
-            _ToolCallRequestData(tool_call_request=req) for req in requests
+            self._parse_tool_call_request(req) for req in tool_call_requests
         ]
         message = AssistantResponse(content=[message_text, *request_parts])
         self._messages.append(message)
         return message
 
     @classmethod
-    def _parse_tool_result(cls, result: _ToolCallResultInput) -> _ToolCallResultData:
+    def _parse_tool_result(cls, result: ToolCallResultInput) -> ToolCallResultData:
         match result:
-            # Sadly, we can't use the union type aliases for matching,
-            # since the compiler needs visibility into every match target
-            case _ToolCallResultData():
+            case ToolCallResultData():
                 return result
             case {"toolCallId": _, "content": _} | {"tool_call_id": _, "content": _}:
                 # We accept snake_case here for consistency, but don't really expect it
-                return _ToolCallResultData.from_dict(result)
+                return ToolCallResultData.from_dict(result)
             case _:
                 raise LMStudioValueError(f"Unable to parse tool result: {result}")
 
-    def _add_tool_results(
-        self, results: Iterable[_ToolCallResultInput]
+    def add_tool_results(
+        self, results: Iterable[ToolCallResultInput]
     ) -> ToolResultMessage:
+        """Add multiple tool results to the chat history as a single message."""
         message_content = [self._parse_tool_result(result) for result in results]
         message = ToolResultMessage(content=message_content)
         self._messages.append(message)
         return message
 
-    def _add_tool_result(self, result: _ToolCallResultInput) -> ToolResultMessage:
+    def add_tool_result(self, result: ToolCallResultInput) -> ToolResultMessage:
         """Add a new tool result to the chat history."""
         # Consecutive tool result messages are allowed,
         # so skip checking if the last message was a tool result
diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py
index 448a8a2..82bba47 100644
--- a/src/lmstudio/json_api.py
+++ b/src/lmstudio/json_api.py
@@ -40,7 +40,7 @@
     sdk_public_type,
     _truncate_traceback,
 )
-from .history import AssistantResponse, Chat, _ToolCallRequest, _ToolCallResultData
+from .history import AssistantResponse, Chat, ToolCallRequest, ToolCallResultData
 from .schemas import (
     AnyLMStudioStruct,
     DictObject,
@@ -1067,7 +1067,7 @@ class PredictionFragmentEvent(ChannelRxEvent[LlmPredictionFragment]):
     pass
 
 
-class PredictionToolCallEvent(ChannelRxEvent[_ToolCallRequest]):
+class PredictionToolCallEvent(ChannelRxEvent[ToolCallRequest]):
     pass
 
 
@@ -1114,7 +1114,7 @@ def __init__(
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
         # The remaining options are only relevant for multi-round tool actions
         handle_invalid_tool_request: Callable[
-            [LMStudioPredictionError, _ToolCallRequest | None], str
+            [LMStudioPredictionError, ToolCallRequest | None], str
         ]
         | None = None,
         llm_tools: LlmToolUseSettingToolArray | None = None,
@@ -1224,7 +1224,7 @@ def iter_message_events(
                 "toolCallRequest": tool_call_request,
             }:
                 yield PredictionToolCallEvent(
-                    _ToolCallRequest._from_api_dict(tool_call_request)
+                    ToolCallRequest._from_api_dict(tool_call_request)
                 )
             case {
                 "type": "toolCallGenerationFailed",
@@ -1267,10 +1267,17 @@ def handle_rx_event(self, event: PredictionRxEvent) -> None:
                 self._report_prompt_processing_progress(progress)
             case PredictionFragmentEvent(_fragment):
                 if self._on_first_token is not None:
-                    self._on_first_token()
+                    self._logger.debug("Invoking on_first_token callback")
+                    err_msg = f"First token callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        self._on_first_token()
                     self._on_first_token = None
                 if self._on_prediction_fragment is not None:
-                    self._on_prediction_fragment(_fragment)
+                    # TODO: Define an even-spammier-than-debug trace logging level for this
+                    # self._logger.trace("Invoking on_prediction_fragment callback")
+                    err_msg = f"Prediction fragment callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        self._on_prediction_fragment(_fragment)
                 pass
             case PredictionToolCallEvent(_tool_call_request):
                 # Handled externally when iterating over events
@@ -1294,15 +1301,17 @@ def _report_prompt_processing_progress(self, progress: float) -> None:
         assert self._on_prompt_processing_progress is not None
         err_msg = f"Prediction progress callback failed for {self!r}"
         with sdk_callback_invocation(err_msg, self._logger):
+            self._logger.debug("Invoking on_prompt_processing_progress callback")
             self._on_prompt_processing_progress(progress)
 
     def _handle_invalid_tool_request(
-        self, err_msg: str, request: _ToolCallRequest | None = None
+        self, err_msg: str, request: ToolCallRequest | None = None
     ) -> str:
         exc = LMStudioPredictionError(err_msg)
         _on_handle_invalid_tool_request = self._on_handle_invalid_tool_request
         if _on_handle_invalid_tool_request is not None:
             # Allow users to override the error message, or force an exception
+            self._logger.debug("Invoking on_handle_invalid_tool_request callback")
             err_msg = _on_handle_invalid_tool_request(exc, request)
             if request is not None:
                 return err_msg
@@ -1310,8 +1319,8 @@ def _handle_invalid_tool_request(
         raise LMStudioPredictionError(err_msg)
 
     def request_tool_call(
-        self, request: _ToolCallRequest
-    ) -> Callable[[], _ToolCallResultData]:
+        self, request: ToolCallRequest
+    ) -> Callable[[], ToolCallResultData]:
         tool_name = request.name
         tool_call_id = request.id
         client_tool = self._client_tools.get(tool_name, None)
@@ -1319,7 +1328,7 @@ def request_tool_call(
             err_msg = self._handle_invalid_tool_request(
                 f"Cannot find tool with name {tool_name}.", request
             )
-            result = _ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
+            result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
             return lambda: result
         # Validate parameters against their specification
         params_struct, implementation = client_tool
@@ -1330,14 +1339,14 @@ def request_tool_call(
             err_msg = self._handle_invalid_tool_request(
                 f"Failed to parse arguments for tool {tool_name}: {exc}", request
             )
-            result = _ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
+            result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
             return lambda: result
         kwds = to_builtins(parsed_kwds)
 
         # Allow caller to schedule the tool call request for background execution
-        def _call_requested_tool() -> _ToolCallResultData:
+        def _call_requested_tool() -> ToolCallResultData:
             call_result = implementation(**kwds)
-            return _ToolCallResultData(
+            return ToolCallResultData(
                 content=json.dumps(call_result), tool_call_id=tool_call_id
             )
 
@@ -1980,6 +1989,8 @@ def __init__(self, model_identifier: str, session: TSession) -> None:
         """Initialize the LM Studio model reference."""
         self.identifier = model_identifier
         self._session = session
+        self._logger = logger = get_logger(type(self).__name__)
+        logger.update_context(model_identifier=model_identifier)
 
     def __repr__(self) -> str:
         return f"{type(self).__name__}(identifier={self.identifier!r})"
diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py
index a744b8f..ca7b8d3 100644
--- a/src/lmstudio/sync_api.py
+++ b/src/lmstudio/sync_api.py
@@ -40,7 +40,12 @@
 # Synchronous API still uses an async websocket (just in a background thread)
 from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException
 
-from .sdk_api import LMStudioRuntimeError, LMStudioValueError, sdk_public_api
+from .sdk_api import (
+    LMStudioRuntimeError,
+    LMStudioValueError,
+    sdk_callback_invocation,
+    sdk_public_api,
+)
 from .schemas import AnyLMStudioStruct, DictObject, DictSchema, ModelSchema
 from .history import (
     AssistantResponse,
@@ -50,7 +55,7 @@
     _FileHandle,
     _FileInputType,
     _LocalFileData,
-    _ToolCallRequest,
+    ToolCallRequest,
 )
 from .json_api import (
     ActResult,
@@ -1560,7 +1565,7 @@ def act(
         on_prediction_completed: Callable[[PredictionRoundResult], Any] | None = None,
         on_prompt_processing_progress: Callable[[float, int], Any] | None = None,
         handle_invalid_tool_request: Callable[
-            [LMStudioPredictionError, _ToolCallRequest | None], str
+            [LMStudioPredictionError, ToolCallRequest | None], str
         ]
         | None = None,
     ) -> ActResult:
@@ -1619,8 +1624,13 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
         # (or the maximum number of prediction rounds is reached)
         with ThreadPoolExecutor() as pool:
             for round_index in round_counter:
+                self._logger.debug(
+                    "Starting .act() prediction round", round_index=round_index
+                )
                 if on_round_start is not None:
-                    on_round_start(round_index)
+                    err_msg = f"Round start callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        on_round_start(round_index)
                 # Update the endpoint definition on each iteration in order to:
                 # * update the chat history with the previous round result
                 # * be able to disallow tool use when the rounds are limited
@@ -1644,7 +1654,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
                 )
                 channel_cm = self._session._create_channel(endpoint)
                 prediction_stream = PredictionStream(channel_cm, endpoint)
-                tool_call_requests: list[_ToolCallRequest] = []
+                tool_call_requests: list[ToolCallRequest] = []
                 pending_tool_calls: list[SyncFuture[Any]] = []
                 for event in prediction_stream._iter_events():
                     if isinstance(event, PredictionToolCallEvent):
@@ -1653,26 +1663,39 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
                         tool_call = endpoint.request_tool_call(tool_call_request)
                         pending_tool_calls.append(pool.submit(tool_call))
                 prediction = prediction_stream.result()
+                self._logger.debug(
+                    "Completed .act() prediction round", round_index=round_index
+                )
                 if on_prediction_completed:
                     round_result = PredictionRoundResult.from_result(
                         prediction, round_index
                     )
-                    on_prediction_completed(round_result)
+                    err_msg = f"Prediction completed callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        on_prediction_completed(round_result)
                 if pending_tool_calls:
                     tool_results = [
                         fut.result() for fut in as_completed(pending_tool_calls)
                     ]
-                    requests_message = agent_chat._add_assistant_tool_requests(
+                    requests_message = agent_chat.add_assistant_response(
                         prediction, tool_call_requests
                     )
-                    results_message = agent_chat._add_tool_results(tool_results)
+                    results_message = agent_chat.add_tool_results(tool_results)
                     if on_message is not None:
-                        on_message(requests_message)
-                        on_message(results_message)
+                        err_msg = f"Tool request message callback failed for {self!r}"
+                        with sdk_callback_invocation(err_msg, self._logger):
+                            on_message(requests_message)
+                        err_msg = f"Tool result message callback failed for {self!r}"
+                        with sdk_callback_invocation(err_msg, self._logger):
+                            on_message(results_message)
                 elif on_message is not None:
-                    on_message(agent_chat.add_assistant_response(prediction))
+                    err_msg = f"Final response message callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        on_message(agent_chat.add_assistant_response(prediction))
                 if on_round_end is not None:
-                    on_round_end(round_index)
+                    err_msg = f"Round end callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        on_round_end(round_index)
                 if not tool_call_requests:
                     # No tool call requests -> we're done here
                     break
diff --git a/tests/test_inference.py b/tests/test_inference.py
index c0544a8..73d05a0 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -10,7 +10,6 @@
 from pytest_subtests import SubTests
 
 from lmstudio import (
-    AnyChatMessage,
     AsyncClient,
     Chat,
     Client,
@@ -170,9 +169,7 @@ def test_tool_using_agent(caplog: LogCap) -> None:
     with Client() as client:
         llm = client.llm.model(model_id)
         chat = Chat()
-        chat.add_user_message(
-            "What is the sum of 123 and the largest prime smaller than 100?"
-        )
+        chat.add_user_message("What is the sum of 123 and 3210?")
         tools = [ADDITION_TOOL_SPEC]
         # Ensure ignoring the round index passes static type checks
         predictions: list[PredictionResult[str]] = []
@@ -180,7 +177,7 @@ def test_tool_using_agent(caplog: LogCap) -> None:
         act_result = llm.act(chat, tools, on_prediction_completed=predictions.append)
         assert len(predictions) > 1
         assert act_result.rounds == len(predictions)
-        assert "220" in predictions[-1].content
+        assert "3333" in predictions[-1].content
 
         for _logger_name, log_level, message in caplog.record_tuples:
             if log_level != logging.INFO:
@@ -190,7 +187,7 @@ def test_tool_using_agent(caplog: LogCap) -> None:
         else:
             assert False, "Failed to find tool call logging entry"
         assert "123" in message
-        assert "97" in message
+        assert "3210" in message
 
 
 @pytest.mark.lmstudio
@@ -202,14 +199,11 @@ def test_tool_using_agent_callbacks(caplog: LogCap) -> None:
     with Client() as client:
         llm = client.llm.model(model_id)
         chat = Chat()
-        chat.add_user_message(
-            "What is the sum of 123 and the largest prime smaller than 100?"
-        )
+        chat.add_user_message("What is the sum of 123 and 3210?")
         tools = [ADDITION_TOOL_SPEC]
         round_starts: list[int] = []
         round_ends: list[int] = []
         first_tokens: list[int] = []
-        messages: list[AnyChatMessage] = []
         predictions: list[PredictionRoundResult] = []
         fragments: list[LlmPredictionFragment] = []
         last_fragment_round_index = 0
@@ -227,7 +221,7 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
             tools,
             on_first_token=first_tokens.append,
            on_prediction_fragment=_append_fragment,
-            on_message=messages.append,
+            on_message=chat.append,
             on_round_start=round_starts.append,
             on_round_end=round_ends.append,
             on_prediction_completed=predictions.append,
@@ -236,8 +230,12 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
         sequential_round_indices = list(range(num_rounds))
         assert num_rounds > 1
         assert [p.round_index for p in predictions] == sequential_round_indices
-        assert first_tokens == sequential_round_indices
         assert round_starts == sequential_round_indices
         assert round_ends == sequential_round_indices
-        assert len(messages) == 2 * num_rounds - 1  # No tool results in last round
+        expected_token_indices = [p.round_index for p in predictions if p.content]
+        assert first_tokens == expected_token_indices
         assert last_fragment_round_index == num_rounds - 1
+        assert len(chat._messages) == 2 * num_rounds  # No tool results in last round
+
+        cloned_chat = chat.copy()
+        assert cloned_chat._messages == chat._messages
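Usage sketch (not part of the patch): the history changes above also let client code record an agent transcript by hand, rather than only via .act(..., on_message=chat.append). The snippet below is illustrative only. The ToolCallRequest keyword arguments (id, name, arguments) are assumptions inferred from how request.id and request.name are read in json_api.py, the import path assumes ToolCallRequest is only exposed from lmstudio.history, and the id/numeric values are made up.

    import lmstudio as lms
    from lmstudio.history import ToolCallRequest

    chat = lms.Chat()
    chat.add_user_message("What is the result of 12345 multiplied by 54321?")
    # Assistant turn: an initial text response plus the tool call it requested.
    # NOTE: the ToolCallRequest field names below are assumed, not taken from this diff.
    chat.add_assistant_response(
        "I'll call the multiply tool.",
        [ToolCallRequest(id="call-1", name="multiply", arguments={"a": 12345, "b": 54321})],
    )
    # Tool turn: the snake_case dict form is accepted by Chat._parse_tool_result()
    chat.add_tool_result({"tool_call_id": "call-1", "content": "670592745"})
    print(chat)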