diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py
index e758a07..25926a1 100644
--- a/src/lmstudio/json_api.py
+++ b/src/lmstudio/json_api.py
@@ -1131,7 +1131,7 @@ def __init__(
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
         # The remaining options are only relevant for multi-round tool actions
         handle_invalid_tool_request: Callable[
-            [LMStudioPredictionError, ToolCallRequest | None], str
+            [LMStudioPredictionError, ToolCallRequest | None], str | None
         ]
         | None = None,
         llm_tools: LlmToolUseSettingToolArray | None = None,
@@ -1336,12 +1336,14 @@ def _report_prompt_processing_progress(self, progress: float) -> None:
     def _handle_invalid_tool_request(
         self, err_msg: str, request: ToolCallRequest | None = None
     ) -> str:
-        exc = LMStudioPredictionError(err_msg)
         _on_handle_invalid_tool_request = self._on_handle_invalid_tool_request
         if _on_handle_invalid_tool_request is not None:
             # Allow users to override the error message, or force an exception
             self._logger.debug("Invoking on_handle_invalid_tool_request callback")
-            err_msg = _on_handle_invalid_tool_request(exc, request)
+            exc = LMStudioPredictionError(err_msg)
+            user_err_msg = _on_handle_invalid_tool_request(exc, request)
+            if user_err_msg is not None:
+                err_msg = user_err_msg
         if request is not None:
             return err_msg
         # We don't allow users to prevent the exception when there's no request
diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py
index 720664b..735c029 100644
--- a/src/lmstudio/sync_api.py
+++ b/src/lmstudio/sync_api.py
@@ -1499,7 +1499,7 @@ def act(
         on_prediction_completed: Callable[[PredictionRoundResult], Any] | None = None,
         on_prompt_processing_progress: Callable[[float, int], Any] | None = None,
         handle_invalid_tool_request: Callable[
-            [LMStudioPredictionError, ToolCallRequest | None], str
+            [LMStudioPredictionError, ToolCallRequest | None], str | None
         ]
         | None = None,
     ) -> ActResult:
diff --git a/tests/test_inference.py b/tests/test_inference.py
index a523b9b..71027fe 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -15,9 +15,11 @@
     Client,
     LlmPredictionConfig,
     LlmPredictionFragment,
+    LMStudioPredictionError,
     LMStudioValueError,
     PredictionResult,
     PredictionRoundResult,
+    ToolCallRequest,
     ToolFunctionDef,
     ToolFunctionDefDict,
 )
@@ -162,7 +164,7 @@ def test_duplicate_tool_names_rejected() -> None:
 
 @pytest.mark.lmstudio
 def test_tool_using_agent(caplog: LogCap) -> None:
-    # This is currently a sync-only API (it will be refactored after 1.0.0)
+    # This is currently a sync-only API (it will be refactored in a future release)
 
     caplog.set_level(logging.DEBUG)
     model_id = TOOL_LLM_ID
@@ -192,7 +194,7 @@ def test_tool_using_agent(caplog: LogCap) -> None:
 
 @pytest.mark.lmstudio
 def test_tool_using_agent_callbacks(caplog: LogCap) -> None:
-    # This is currently a sync-only API (it will be refactored after 1.0.0)
+    # This is currently a sync-only API (it will be refactored in a future release)
 
     caplog.set_level(logging.DEBUG)
     model_id = TOOL_LLM_ID
@@ -241,3 +243,49 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
 
     cloned_chat = chat.copy()
     assert cloned_chat._messages == chat._messages
+
+
+def divide(numerator: float, denominator: float) -> float | str:
+    """Divide the given numerator by the given denominator. Return the result."""
+    try:
+        return numerator / denominator
+    except Exception as exc:
+        # TODO: Perform this exception-to-response-string translation implicitly
+        return f"Unhandled Python exception: {exc!r}"
+
+
+@pytest.mark.lmstudio
+def test_tool_using_agent_error_handling(caplog: LogCap) -> None:
+    # This is currently a sync-only API (it will be refactored in a future release)
+
+    caplog.set_level(logging.DEBUG)
+    model_id = TOOL_LLM_ID
+    with Client() as client:
+        llm = client.llm.model(model_id)
+        chat = Chat()
+        chat.add_user_message(
+            "Attempt to divide 1 by 0 using the tool. Explain the result."
+        )
+        tools = [divide]
+        predictions: list[PredictionRoundResult] = []
+        invalid_requests: list[tuple[LMStudioPredictionError, ToolCallRequest]] = []
+
+        def _handle_invalid_request(
+            exc: LMStudioPredictionError, request: ToolCallRequest | None
+        ) -> None:
+            if request is not None:
+                invalid_requests.append((exc, request))
+
+        act_result = llm.act(
+            chat,
+            tools,
+            handle_invalid_tool_request=_handle_invalid_request,
+            on_prediction_completed=predictions.append,
+        )
+        assert len(predictions) > 1
+        assert act_result.rounds == len(predictions)
+        # Test case is currently suppressing exceptions inside the tool call
+        assert invalid_requests == []
+        # If the content checks prove flaky in practice, they can be dropped
+        assert "divide" in predictions[-1].content
+        assert "zero" in predictions[-1].content
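
Usage note (not part of the diff above): a minimal sketch of how client code might use the widened str | None return type for handle_invalid_tool_request. Returning a string overrides the reported error message, returning None keeps the SDK's default message, and raising from the callback forces the exception to propagate when a tool call request is present. The add tool and the model identifier are illustrative placeholders; the other names are the public lmstudio exports used in the tests above.

from lmstudio import Chat, Client, LMStudioPredictionError, ToolCallRequest


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


def note_invalid_request(
    exc: LMStudioPredictionError, request: ToolCallRequest | None
) -> str | None:
    # Log the problem, then return None to keep the SDK's default error
    # message. Returning a str here would replace that message instead.
    print(f"Invalid tool request {request!r}: {exc}")
    return None


with Client() as client:
    llm = client.llm.model("some-tool-capable-model")  # placeholder model ID
    chat = Chat()
    chat.add_user_message("Use the tool to add 1234 and 5678.")
    result = llm.act(
        chat,
        [add],
        handle_invalid_tool_request=note_invalid_request,
    )
    print(result.rounds)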