20 changes: 17 additions & 3 deletions src/lmstudio/json_api.py
@@ -1334,21 +1334,35 @@ def _report_prompt_processing_progress(self, progress: float) -> None:
         self._on_prompt_processing_progress(progress)
 
     def _handle_invalid_tool_request(
-        self, err_msg: str, request: ToolCallRequest | None = None
+        self,
+        err_msg: str,
+        request: ToolCallRequest | None = None,
+        *,
+        exc: Exception | None = None,
     ) -> str:
         _on_handle_invalid_tool_request = self._on_handle_invalid_tool_request
         if _on_handle_invalid_tool_request is not None:
             # Allow users to override the error message, or force an exception
             self._logger.debug("Invoking on_handle_invalid_tool_request callback")
-            exc = LMStudioPredictionError(err_msg)
-            user_err_msg = _on_handle_invalid_tool_request(exc, request)
+            callback_exc = LMStudioPredictionError(err_msg)
+            if exc is not None:
+                callback_exc.__cause__ = exc
+            user_err_msg = _on_handle_invalid_tool_request(callback_exc, request)
             if user_err_msg is not None:
                 err_msg = user_err_msg
         if request is not None:
             return err_msg
         # We don't allow users to prevent the exception when there's no request
         raise LMStudioPredictionError(err_msg)
 
+    def _handle_failed_tool_request(
+        self, exc: Exception, request: ToolCallRequest
+    ) -> ToolCallResultData:
+        err_msg = self._handle_invalid_tool_request(
+            f"Unhandled Python exception: {exc!r}", request, exc=exc
+        )
+        return ToolCallResultData(content=json.dumps(err_msg), tool_call_id=request.id)
+
     def request_tool_call(
         self, request: ToolCallRequest
     ) -> Callable[[], ToolCallResultData]:
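Note on the json_api.py change: `_handle_invalid_tool_request` now accepts an optional `exc` keyword and chains it via `__cause__` onto the `LMStudioPredictionError` passed to the user's `on_handle_invalid_tool_request` callback, and the new `_handle_failed_tool_request` helper routes tool-execution failures through that same path. A sketch of a callback that uses the chained cause follows; it assumes `LMStudioPredictionError` and `ToolCallRequest` are importable from the top-level `lmstudio` package (the test file below uses those names), and the callback body itself is illustrative, not part of this diff.

```python
# Hypothetical user callback: inspects the chained cause set by the SDK.
from lmstudio import LMStudioPredictionError, ToolCallRequest

def handle_invalid_tool_request(
    exc: LMStudioPredictionError, request: ToolCallRequest | None
) -> str | None:
    cause = exc.__cause__  # original tool exception, chained by the SDK
    if isinstance(cause, ZeroDivisionError):
        # Returning a string overrides the error message sent to the model
        return "Tool error: division by zero; retry with a nonzero denominator."
    # Returning None keeps the SDK's default error message
    return None
```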
20 changes: 17 additions & 3 deletions src/lmstudio/sync_api.py
@@ -1591,13 +1591,13 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
         channel_cm = self._session._create_channel(endpoint)
         prediction_stream = PredictionStream(channel_cm, endpoint)
         tool_call_requests: list[ToolCallRequest] = []
-        pending_tool_calls: list[SyncFuture[Any]] = []
+        pending_tool_calls: dict[SyncFuture[Any], ToolCallRequest] = {}
         for event in prediction_stream._iter_events():
             if isinstance(event, PredictionToolCallEvent):
                 tool_call_request = event.arg
                 tool_call_requests.append(tool_call_request)
                 tool_call = endpoint.request_tool_call(tool_call_request)
-                pending_tool_calls.append(pool.submit(tool_call))
+                pending_tool_calls[pool.submit(tool_call)] = tool_call_request
         prediction = prediction_stream.result()
         self._logger.debug(
             "Completed .act() prediction round", round_index=round_index
@@ -1610,8 +1610,22 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
             with sdk_callback_invocation(err_msg, self._logger):
                 on_prediction_completed(round_result)
         if pending_tool_calls:
+
+            def _finish_tool_call(fut: SyncFuture[Any]) -> Any:
+                exc = fut.exception()
+                if exc is not None:
+                    if not isinstance(exc, Exception):
+                        # Don't allow base exceptions to be suppressed
+                        raise exc
+                    failed_request = pending_tool_calls[fut]
+                    return endpoint._handle_failed_tool_request(
+                        exc, failed_request
+                    )
+                return fut.result()
+
             tool_results = [
-                fut.result() for fut in as_completed(pending_tool_calls)
+                _finish_tool_call(fut)
+                for fut in as_completed(pending_tool_calls)
             ]
             requests_message = agent_chat.add_assistant_response(
                 prediction, tool_call_requests
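The sync_api.py changes replace the list of pending futures with a dict keyed by future, so that when `as_completed` yields a failed future the originating `ToolCallRequest` can be recovered and routed to `_handle_failed_tool_request`. A minimal standalone sketch of the same future-to-request mapping pattern, using only the standard library (toy names throughout):

```python
# Sketch: map each submitted future back to its originating request so
# failures can be attributed after as_completed() yields them.
from concurrent.futures import ThreadPoolExecutor, as_completed

def flaky_tool(denominator: int) -> float:
    return 10 / denominator  # raises ZeroDivisionError for 0

requests = [2, 0, 5]
with ThreadPoolExecutor() as pool:
    pending = {pool.submit(flaky_tool, req): req for req in requests}
    for fut in as_completed(pending):  # iterating the dict yields its futures
        exc = fut.exception()
        if exc is not None:
            if not isinstance(exc, Exception):
                # Mirror the diff's guard: never suppress BaseExceptions
                # such as KeyboardInterrupt or SystemExit
                raise exc
            print(f"request {pending[fut]} failed: {exc!r}")
        else:
            print(f"request {pending[fut]} -> {fut.result()}")
```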
17 changes: 8 additions & 9 deletions tests/test_inference.py
@@ -247,11 +247,7 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
 
 def divide(numerator: float, denominator: float) -> float | str:
     """Divide the given numerator by the given denominator. Return the result."""
-    try:
-        return numerator / denominator
-    except Exception as exc:
-        # TODO: Perform this exception-to-response-string translation implicitly
-        return f"Unhandled Python exception: {exc!r}"
+    return numerator / denominator
 
 
 @pytest.mark.lmstudio
@@ -268,13 +264,13 @@ def test_tool_using_agent_error_handling(caplog: LogCap) -> None:
     )
     tools = [divide]
     predictions: list[PredictionRoundResult] = []
-    invalid_requests: list[tuple[LMStudioPredictionError, ToolCallRequest]] = []
+    request_failures: list[LMStudioPredictionError] = []
 
     def _handle_invalid_request(
         exc: LMStudioPredictionError, request: ToolCallRequest | None
     ) -> None:
         if request is not None:
-            invalid_requests.append((exc, request))
+            request_failures.append(exc)
 
     act_result = llm.act(
         chat,
@@ -284,8 +280,11 @@ def _handle_invalid_request(
     )
     assert len(predictions) > 1
     assert act_result.rounds == len(predictions)
-    # Test case is currently suppressing exceptions inside the tool call
-    assert invalid_requests == []
+    # Ensure the tool call failure was reported to the user callback
+    assert len(request_failures) == 1
+    tool_failure_exc = request_failures[0]
+    assert isinstance(tool_failure_exc, LMStudioPredictionError)
+    assert isinstance(tool_failure_exc.__cause__, ZeroDivisionError)
     # If the content checks prove flaky in practice, they can be dropped
     assert "divide" in predictions[-1].content
     assert "zero" in predictions[-1].content