
Commit 9ea1f60

Add initial tool error handling test case
The SDK doesn't currently manage unhandled exceptions in tool calls. Add an initial test case that suppresses the exception inside the tool call itself. A subsequent PR will make this exception handling the SDK's default behaviour and update the test case accordingly.
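For context, the workaround the new test relies on is catching exceptions inside the tool function itself and translating them into a response string for the model. A minimal, hypothetical sketch of the implicit translation a later PR might add (the helper name and placement are illustrative, not part of this commit):

import functools
from typing import Any, Callable

def suppress_tool_errors(tool: Callable[..., Any]) -> Callable[..., Any]:
    # Hypothetical helper: converts unhandled tool exceptions into
    # response strings, mirroring the manual translation performed by
    # the "divide" tool in the test diff below.
    @functools.wraps(tool)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return tool(*args, **kwargs)
        except Exception as exc:
            return f"Unhandled Python exception: {exc!r}"
    return wrapper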
1 parent 05a58ce · commit 9ea1f60

3 files changed: +56 -6 lines


src/lmstudio/json_api.py

Lines changed: 5 additions & 3 deletions

@@ -1131,7 +1131,7 @@ def __init__(
         on_prompt_processing_progress: PromptProcessingCallback | None = None,
         # The remaining options are only relevant for multi-round tool actions
         handle_invalid_tool_request: Callable[
-            [LMStudioPredictionError, ToolCallRequest | None], str
+            [LMStudioPredictionError, ToolCallRequest | None], str | None
         ]
         | None = None,
         llm_tools: LlmToolUseSettingToolArray | None = None,
@@ -1336,12 +1336,14 @@ def _report_prompt_processing_progress(self, progress: float) -> None:
     def _handle_invalid_tool_request(
         self, err_msg: str, request: ToolCallRequest | None = None
     ) -> str:
-        exc = LMStudioPredictionError(err_msg)
         _on_handle_invalid_tool_request = self._on_handle_invalid_tool_request
         if _on_handle_invalid_tool_request is not None:
             # Allow users to override the error message, or force an exception
             self._logger.debug("Invoking on_handle_invalid_tool_request callback")
-            err_msg = _on_handle_invalid_tool_request(exc, request)
+            exc = LMStudioPredictionError(err_msg)
+            user_err_msg = _on_handle_invalid_tool_request(exc, request)
+            if user_err_msg is not None:
+                err_msg = user_err_msg
         if request is not None:
             return err_msg
         # We don't allow users to prevent the exception when there's no request
src/lmstudio/sync_api.py

Lines changed: 1 addition & 1 deletion

@@ -1499,7 +1499,7 @@ def act(
         on_prediction_completed: Callable[[PredictionRoundResult], Any] | None = None,
         on_prompt_processing_progress: Callable[[float, int], Any] | None = None,
         handle_invalid_tool_request: Callable[
-            [LMStudioPredictionError, ToolCallRequest | None], str
+            [LMStudioPredictionError, ToolCallRequest | None], str | None
         ]
         | None = None,
     ) -> ActResult:

tests/test_inference.py

Lines changed: 50 additions & 2 deletions

@@ -15,9 +15,11 @@
     Client,
     LlmPredictionConfig,
     LlmPredictionFragment,
+    LMStudioPredictionError,
     LMStudioValueError,
     PredictionResult,
     PredictionRoundResult,
+    ToolCallRequest,
     ToolFunctionDef,
     ToolFunctionDefDict,
 )
@@ -162,7 +164,7 @@ def test_duplicate_tool_names_rejected() -> None:
 
 @pytest.mark.lmstudio
 def test_tool_using_agent(caplog: LogCap) -> None:
-    # This is currently a sync-only API (it will be refactored after 1.0.0)
+    # This is currently a sync-only API (it will be refactored in a future release)
 
     caplog.set_level(logging.DEBUG)
     model_id = TOOL_LLM_ID
@@ -192,7 +194,7 @@ def test_tool_using_agent(caplog: LogCap) -> None:
 
 @pytest.mark.lmstudio
 def test_tool_using_agent_callbacks(caplog: LogCap) -> None:
-    # This is currently a sync-only API (it will be refactored after 1.0.0)
+    # This is currently a sync-only API (it will be refactored in a future release)
 
     caplog.set_level(logging.DEBUG)
     model_id = TOOL_LLM_ID
@@ -241,3 +243,49 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
 
     cloned_chat = chat.copy()
     assert cloned_chat._messages == chat._messages
+
+
+def divide(numerator: float, denominator: float) -> float | str:
+    """Divide the given numerator by the given denominator. Return the result."""
+    try:
+        return numerator / denominator
+    except Exception as exc:
+        # TODO: Perform this exception-to-response-string translation implicitly
+        return f"Unhandled Python exception: {exc!r}"
+
+
+@pytest.mark.lmstudio
+def test_tool_using_agent_error_handling(caplog: LogCap) -> None:
+    # This is currently a sync-only API (it will be refactored in a future release)
+
+    caplog.set_level(logging.DEBUG)
+    model_id = TOOL_LLM_ID
+    with Client() as client:
+        llm = client.llm.model(model_id)
+        chat = Chat()
+        chat.add_user_message(
+            "Attempt to divide 1 by 0 using the tool. Explain the result."
+        )
+        tools = [divide]
+        predictions: list[PredictionRoundResult] = []
+        invalid_requests: list[tuple[LMStudioPredictionError, ToolCallRequest]] = []
+
+        def _handle_invalid_request(
+            exc: LMStudioPredictionError, request: ToolCallRequest | None
+        ) -> None:
+            if request is not None:
+                invalid_requests.append((exc, request))
+
+        act_result = llm.act(
+            chat,
+            tools,
+            handle_invalid_tool_request=_handle_invalid_request,
+            on_prediction_completed=predictions.append,
+        )
+        assert len(predictions) > 1
+        assert act_result.rounds == len(predictions)
+        # Test case is currently suppressing exceptions inside the tool call
+        assert invalid_requests == []
+        # If the content checks prove flaky in practice, they can be dropped
+        assert "divide" in predictions[-1].content
+        assert "zero" in predictions[-1].content
