8 changes: 5 additions & 3 deletions src/lmstudio/json_api.py
@@ -1131,7 +1131,7 @@ def __init__(
on_prompt_processing_progress: PromptProcessingCallback | None = None,
# The remaining options are only relevant for multi-round tool actions
handle_invalid_tool_request: Callable[
[LMStudioPredictionError, ToolCallRequest | None], str
[LMStudioPredictionError, ToolCallRequest | None], str | None
]
| None = None,
llm_tools: LlmToolUseSettingToolArray | None = None,
@@ -1336,12 +1336,14 @@ def _report_prompt_processing_progress(self, progress: float) -> None:
def _handle_invalid_tool_request(
self, err_msg: str, request: ToolCallRequest | None = None
) -> str:
exc = LMStudioPredictionError(err_msg)
_on_handle_invalid_tool_request = self._on_handle_invalid_tool_request
if _on_handle_invalid_tool_request is not None:
# Allow users to override the error message, or force an exception
self._logger.debug("Invoking on_handle_invalid_tool_request callback")
err_msg = _on_handle_invalid_tool_request(exc, request)
exc = LMStudioPredictionError(err_msg)
user_err_msg = _on_handle_invalid_tool_request(exc, request)
if user_err_msg is not None:
err_msg = user_err_msg
if request is not None:
return err_msg
# We don't allow users to prevent the exception when there's no request
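With this change, a handle_invalid_tool_request callback may return None to keep the SDK's default error message, or a string to override it. A minimal sketch of the widened contract (the callback name and log text are illustrative, not part of the SDK):

import logging

from lmstudio import LMStudioPredictionError, ToolCallRequest

def describe_failure(
    exc: LMStudioPredictionError, request: ToolCallRequest | None
) -> str | None:
    # Log every invalid tool request for later inspection
    logging.warning("Invalid tool request: %r", exc)
    if request is not None:
        # Override the message reported back to the LLM for this round
        return f"Tool call failed: {exc}"
    # Returning None keeps the SDK's default error message
    return None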
2 changes: 1 addition & 1 deletion src/lmstudio/sync_api.py
@@ -1499,7 +1499,7 @@ def act(
on_prediction_completed: Callable[[PredictionRoundResult], Any] | None = None,
on_prompt_processing_progress: Callable[[float, int], Any] | None = None,
handle_invalid_tool_request: Callable[
[LMStudioPredictionError, ToolCallRequest | None], str
[LMStudioPredictionError, ToolCallRequest | None], str | None
]
| None = None,
) -> ActResult:
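Since act() in sync_api.py accepts the same widened callback type, a callback that only logs (and so always returns None) now type-checks. A hedged usage sketch, assuming a tool-capable model is available locally (the model id and the add tool are placeholders):

from lmstudio import Chat, Client, LMStudioPredictionError, ToolCallRequest

def add(a: int, b: int) -> int:
    """Add the given numbers together. Return the result."""
    return a + b

def log_only(exc: LMStudioPredictionError, request: ToolCallRequest | None) -> None:
    # None is now a valid return: the SDK keeps its default error message
    print(f"Invalid tool request: {exc!r}")

with Client() as client:
    llm = client.llm.model("tool-capable-model-id")  # placeholder model id
    chat = Chat()
    chat.add_user_message("Add 2 and 3 using the tool.")
    result = llm.act(
        chat,
        [add],
        handle_invalid_tool_request=log_only,
    )
    print(result.rounds)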
52 changes: 50 additions & 2 deletions tests/test_inference.py
@@ -15,9 +15,11 @@
Client,
LlmPredictionConfig,
LlmPredictionFragment,
LMStudioPredictionError,
LMStudioValueError,
PredictionResult,
PredictionRoundResult,
ToolCallRequest,
ToolFunctionDef,
ToolFunctionDefDict,
)
@@ -162,7 +164,7 @@ def test_duplicate_tool_names_rejected() -> None:

@pytest.mark.lmstudio
def test_tool_using_agent(caplog: LogCap) -> None:
# This is currently a sync-only API (it will be refactored after 1.0.0)
# This is currently a sync-only API (it will be refactored in a future release)

caplog.set_level(logging.DEBUG)
model_id = TOOL_LLM_ID
@@ -192,7 +194,7 @@ def test_tool_using_agent(caplog: LogCap) -> None:

@pytest.mark.lmstudio
def test_tool_using_agent_callbacks(caplog: LogCap) -> None:
# This is currently a sync-only API (it will be refactored after 1.0.0)
# This is currently a sync-only API (it will be refactored in a future release)

caplog.set_level(logging.DEBUG)
model_id = TOOL_LLM_ID
@@ -241,3 +243,49 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:

cloned_chat = chat.copy()
assert cloned_chat._messages == chat._messages


def divide(numerator: float, denominator: float) -> float | str:
"""Divide the given numerator by the given denominator. Return the result."""
try:
return numerator / denominator
except Exception as exc:
# TODO: Perform this exception-to-response-string translation implicitly
return f"Unhandled Python exception: {exc!r}"


@pytest.mark.lmstudio
def test_tool_using_agent_error_handling(caplog: LogCap) -> None:
# This is currently a sync-only API (it will be refactored in a future release)

caplog.set_level(logging.DEBUG)
model_id = TOOL_LLM_ID
with Client() as client:
llm = client.llm.model(model_id)
chat = Chat()
chat.add_user_message(
"Attempt to divide 1 by 0 using the tool. Explain the result."
)
tools = [divide]
predictions: list[PredictionRoundResult] = []
invalid_requests: list[tuple[LMStudioPredictionError, ToolCallRequest]] = []

def _handle_invalid_request(
exc: LMStudioPredictionError, request: ToolCallRequest | None
) -> None:
if request is not None:
invalid_requests.append((exc, request))

act_result = llm.act(
chat,
tools,
handle_invalid_tool_request=_handle_invalid_request,
on_prediction_completed=predictions.append,
)
assert len(predictions) > 1
assert act_result.rounds == len(predictions)
# Test case is currently suppressing exceptions inside the tool call
assert invalid_requests == []
# If the content checks prove flaky in practice, they can be dropped
assert "divide" in predictions[-1].content
assert "zero" in predictions[-1].content