diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py
index 25926a1..f5d8d28 100644
--- a/src/lmstudio/json_api.py
+++ b/src/lmstudio/json_api.py
@@ -1334,14 +1334,20 @@ def _report_prompt_processing_progress(self, progress: float) -> None:
         self._on_prompt_processing_progress(progress)
 
     def _handle_invalid_tool_request(
-        self, err_msg: str, request: ToolCallRequest | None = None
+        self,
+        err_msg: str,
+        request: ToolCallRequest | None = None,
+        *,
+        exc: Exception | None = None,
     ) -> str:
         _on_handle_invalid_tool_request = self._on_handle_invalid_tool_request
         if _on_handle_invalid_tool_request is not None:
             # Allow users to override the error message, or force an exception
             self._logger.debug("Invoking on_handle_invalid_tool_request callback")
-            exc = LMStudioPredictionError(err_msg)
-            user_err_msg = _on_handle_invalid_tool_request(exc, request)
+            callback_exc = LMStudioPredictionError(err_msg)
+            if exc is not None:
+                callback_exc.__cause__ = exc
+            user_err_msg = _on_handle_invalid_tool_request(callback_exc, request)
             if user_err_msg is not None:
                 err_msg = user_err_msg
         if request is not None:
@@ -1349,6 +1355,14 @@ def _handle_invalid_tool_request(
         # We don't allow users to prevent the exception when there's no request
         raise LMStudioPredictionError(err_msg)
 
+    def _handle_failed_tool_request(
+        self, exc: Exception, request: ToolCallRequest
+    ) -> ToolCallResultData:
+        err_msg = self._handle_invalid_tool_request(
+            f"Unhandled Python exception: {exc!r}", request, exc=exc
+        )
+        return ToolCallResultData(content=json.dumps(err_msg), tool_call_id=request.id)
+
     def request_tool_call(
         self, request: ToolCallRequest
     ) -> Callable[[], ToolCallResultData]:
diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py
index 735c029..b41e9b9 100644
--- a/src/lmstudio/sync_api.py
+++ b/src/lmstudio/sync_api.py
@@ -1591,13 +1591,13 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
                 channel_cm = self._session._create_channel(endpoint)
                 prediction_stream = PredictionStream(channel_cm, endpoint)
                 tool_call_requests: list[ToolCallRequest] = []
-                pending_tool_calls: list[SyncFuture[Any]] = []
+                pending_tool_calls: dict[SyncFuture[Any], ToolCallRequest] = {}
                 for event in prediction_stream._iter_events():
                     if isinstance(event, PredictionToolCallEvent):
                         tool_call_request = event.arg
                         tool_call_requests.append(tool_call_request)
                         tool_call = endpoint.request_tool_call(tool_call_request)
-                        pending_tool_calls.append(pool.submit(tool_call))
+                        pending_tool_calls[pool.submit(tool_call)] = tool_call_request
                 prediction = prediction_stream.result()
                 self._logger.debug(
                     "Completed .act() prediction round", round_index=round_index
@@ -1610,8 +1610,22 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
                     with sdk_callback_invocation(err_msg, self._logger):
                         on_prediction_completed(round_result)
                 if pending_tool_calls:
+
+                    def _finish_tool_call(fut: SyncFuture[Any]) -> Any:
+                        exc = fut.exception()
+                        if exc is not None:
+                            if not isinstance(exc, Exception):
+                                # Don't allow base exceptions to be suppressed
+                                raise exc
+                            failed_request = pending_tool_calls[fut]
+                            return endpoint._handle_failed_tool_request(
+                                exc, failed_request
+                            )
+                        return fut.result()
+
                     tool_results = [
-                        fut.result() for fut in as_completed(pending_tool_calls)
+                        _finish_tool_call(fut)
+                        for fut in as_completed(pending_tool_calls)
                     ]
                     requests_message = agent_chat.add_assistant_response(
                         prediction, tool_call_requests
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 71027fe..60cd91e 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -247,11 +247,7 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
 
 def divide(numerator: float, denominator: float) -> float | str:
     """Divide the given numerator by the given denominator. Return the result."""
-    try:
-        return numerator / denominator
-    except Exception as exc:
-        # TODO: Perform this exception-to-response-string translation implicitly
-        return f"Unhandled Python exception: {exc!r}"
+    return numerator / denominator
 
 
 @pytest.mark.lmstudio
@@ -268,13 +264,13 @@ def test_tool_using_agent_error_handling(caplog: LogCap) -> None:
         )
         tools = [divide]
         predictions: list[PredictionRoundResult] = []
-        invalid_requests: list[tuple[LMStudioPredictionError, ToolCallRequest]] = []
+        request_failures: list[LMStudioPredictionError] = []
 
         def _handle_invalid_request(
            exc: LMStudioPredictionError, request: ToolCallRequest | None
         ) -> None:
            if request is not None:
-               invalid_requests.append((exc, request))
+               request_failures.append(exc)
 
        act_result = llm.act(
            chat,
            tools,
@@ -284,8 +280,11 @@ def _handle_invalid_request(
        )
        assert len(predictions) > 1
        assert act_result.rounds == len(predictions)
-       # Test case is currently suppressing exceptions inside the tool call
-       assert invalid_requests == []
+       # Ensure the tool call failure was reported to the user callback
+       assert len(request_failures) == 1
+       tool_failure_exc = request_failures[0]
+       assert isinstance(tool_failure_exc, LMStudioPredictionError)
+       assert isinstance(tool_failure_exc.__cause__, ZeroDivisionError)
        # If the content checks prove flaky in practice, they can be dropped
        assert "divide" in predictions[-1].content
        assert "zero" in predictions[-1].content
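
Reviewer note (illustrative only, not part of the patch): a minimal sketch of how the chained exception now surfaces to user code. The `Client`/`Chat` imports, the `handle_invalid_tool_request` keyword argument, and the model identifier are assumptions drawn from the test suite above rather than anything introduced by this diff.

# Illustrative sketch only; Client, Chat, the handle_invalid_tool_request
# keyword, and the model id below are assumptions, not part of this patch.
from lmstudio import Chat, Client, LMStudioPredictionError, ToolCallRequest


def divide(numerator: float, denominator: float) -> float:
    """Divide the given numerator by the given denominator."""
    return numerator / denominator  # ZeroDivisionError is no longer caught here


def report_tool_failure(
    exc: LMStudioPredictionError, request: ToolCallRequest | None
) -> None:
    # With this patch, an exception raised inside a tool call reaches this
    # callback wrapped in LMStudioPredictionError, with the original error
    # attached as __cause__ (ZeroDivisionError in this example), instead of
    # the tool having to translate it into a response string itself.
    print(f"Tool call failed: {exc!r} (caused by {exc.__cause__!r})")


with Client() as client:
    llm = client.llm.model("some-tool-capable-model")  # placeholder model id
    chat = Chat("You are a task focused AI assistant")
    chat.add_user_message("Attempt to divide 1 by 0 and explain the result")
    llm.act(
        chat,
        [divide],
        handle_invalid_tool_request=report_tool_failure,
    )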