14 changes: 6 additions & 8 deletions src/lmstudio/async_api.py
@@ -53,6 +53,7 @@
ActResult,
AnyLoadConfig,
AnyModelSpecifier,
AsyncToolCall,
AvailableModelBase,
ChannelEndpoint,
ChannelHandler,
@@ -1314,7 +1315,7 @@ async def act(
# Do not force a final round when no limit is specified
final_round_index = -1
round_counter = itertools.count()
llm_tool_args = ChatResponseEndpoint.parse_tools(tools)
llm_tool_args = ChatResponseEndpoint.parse_tools(tools, allow_async=True)
del tools
# Supply the round index to any endpoint callbacks that expect one
round_index: int
@@ -1345,6 +1346,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:

on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress
# TODO: Implementation to this point is common between the sync and async APIs
# (aside from the allow_async flag when parsing the tool definitions)
# Implementation past this point differs (as the sync API uses its own thread pool)

# Request predictions until no more tool call requests are received in response
@@ -1378,13 +1380,12 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
channel_cm = self._session._create_channel(endpoint)
prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
tool_call_requests: list[ToolCallRequest] = []
parsed_tool_calls: list[Callable[[], ToolCallResultData]] = []
parsed_tool_calls: list[AsyncToolCall] = []
async for event in prediction_stream._iter_events():
if isinstance(event, PredictionToolCallEvent):
tool_call_request = event.arg
tool_call_requests.append(tool_call_request)
# TODO: Also handle async tool calls here
tool_call = endpoint.request_tool_call(tool_call_request)
tool_call = endpoint.request_tool_call_async(tool_call_request)
parsed_tool_calls.append(tool_call)
prediction = prediction_stream.result()
self._logger.debug(
@@ -1409,10 +1410,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
tool_call_futures, return_when=asyncio.FIRST_COMPLETED
)
active_tool_calls = len(pending)
# TODO: Also handle async tool calls here
tool_call_futures.append(
asyncio.ensure_future(asyncio.to_thread(tool_call))
)
tool_call_futures.append(asyncio.ensure_future(tool_call()))
active_tool_calls += 1
tool_call_results: list[ToolCallResultData] = []
for tool_call_request, tool_call_future in zip(
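With this change, the async agent loop schedules each parsed tool call as a coroutine via `asyncio.ensure_future(tool_call())` instead of pushing a synchronous callable onto a thread, while still bounding in-flight work with `asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED)`. A minimal standalone sketch of that scheduling pattern (the `run_tool_calls` helper and `max_parallel` limit are illustrative, not part of the SDK):

```python
import asyncio
from typing import Awaitable, Callable


async def run_tool_calls(
    tool_calls: list[Callable[[], Awaitable[str]]],
    max_parallel: int = 4,
) -> list[str]:
    """Schedule coroutine-returning tool calls while limiting in-flight work."""
    futures: list[asyncio.Future[str]] = []
    active = 0
    for tool_call in tool_calls:
        if active >= max_parallel:
            # Wait for at least one in-flight call to finish before adding more
            _done, pending = await asyncio.wait(
                futures, return_when=asyncio.FIRST_COMPLETED
            )
            active = len(pending)
        futures.append(asyncio.ensure_future(tool_call()))
        active += 1
    # gather preserves submission order in its results
    return list(await asyncio.gather(*futures))
```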
68 changes: 59 additions & 9 deletions src/lmstudio/json_api.py
@@ -1188,8 +1188,11 @@ class PredictionToolCallAbortedEvent(ChannelRxEvent[None]):
| ChannelCommonRxEvent
)

ClientToolSpec: TypeAlias = tuple[type[Struct], Callable[..., Any]]
ClientToolSpec: TypeAlias = tuple[type[Struct], Callable[..., Any], bool]
ClientToolMap: TypeAlias = Mapping[str, ClientToolSpec]
SyncToolCall: TypeAlias = Callable[[], ToolCallResultData]
# Require a coroutine (not just any awaitable) for ensure_future compatibility
AsyncToolCall: TypeAlias = Callable[[], Coroutine[None, None, ToolCallResultData]]

PredictionMessageCallback: TypeAlias = Callable[[AssistantResponse], Any]
PredictionFirstTokenCallback: TypeAlias = Callable[[], Any]
@@ -1450,9 +1453,9 @@ def _handle_failed_tool_request(
return ToolCallResultData(content=json.dumps(err_msg), tool_call_id=request.id)

# TODO: Reduce code duplication with the tools_provider plugin hook runner
def request_tool_call(
def _request_any_tool_call(
self, request: ToolCallRequest
) -> Callable[[], ToolCallResultData]:
) -> tuple[DictObject | None, Callable[[], Any], bool]:
tool_name = request.name
tool_call_id = request.id
client_tool = self._client_tools.get(tool_name, None)
@@ -1461,9 +1464,9 @@ def request_tool_call(
f"Cannot find tool with name {tool_name}.", request
)
result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
return lambda: result
return None, lambda: result, False
# Validate parameters against their specification
params_struct, implementation = client_tool
params_struct, implementation, is_async = client_tool
raw_kwds = request.arguments
try:
parsed_kwds = convert(raw_kwds, params_struct)
@@ -1472,10 +1475,20 @@ def request_tool_call(
f"Failed to parse arguments for tool {tool_name}: {exc}", request
)
result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
return lambda: result
kwds = to_builtins(parsed_kwds)

return None, lambda: result, False
return to_builtins(parsed_kwds), implementation, is_async

def request_tool_call(self, request: ToolCallRequest) -> SyncToolCall:
kwds, implementation, is_async = self._request_any_tool_call(request)
if kwds is None:
# Tool lookup or argument parsing failed, implementation emits an error response
return implementation
if is_async:
msg = f"Asynchronous tool {request.name!r} is not supported in synchronous API"
raise LMStudioValueError(msg)
# Allow caller to schedule the tool call request for background execution
tool_call_id = request.id

def _call_requested_tool() -> ToolCallResultData:
call_result = implementation(**kwds)
return ToolCallResultData(
@@ -1485,6 +1498,35 @@ def _call_requested_tool() -> ToolCallResultData:

return _call_requested_tool

def request_tool_call_async(self, request: ToolCallRequest) -> AsyncToolCall:
kwds, implementation, is_async = self._request_any_tool_call(request)
if kwds is None:
# Tool lookup or argument parsing failed, implementation emits an error response
async def _awaitable_error_response() -> ToolCallResultData:
return cast(ToolCallResultData, implementation())

return _awaitable_error_response
# Allow caller to schedule the tool call request as a coroutine
tool_call_id = request.id
if is_async:
# Async tool implementation can be awaited directly
async def _call_requested_tool() -> ToolCallResultData:
call_result = await implementation(**kwds)
return ToolCallResultData(
content=json.dumps(call_result, ensure_ascii=False),
tool_call_id=tool_call_id,
)
else:
# Sync tool implementation needs to be called in a thread
async def _call_requested_tool() -> ToolCallResultData:
call_result = await asyncio.to_thread(implementation, **kwds)
return ToolCallResultData(
content=json.dumps(call_result, ensure_ascii=False),
tool_call_id=tool_call_id,
)

return _call_requested_tool

def mark_cancelled(self) -> None:
"""Mark the prediction as cancelled and quietly drop incoming tokens."""
self._is_cancelled = True
@@ -1541,8 +1583,11 @@ class ChatResponseEndpoint(PredictionEndpoint):
@staticmethod
def parse_tools(
tools: Iterable[ToolDefinition],
allow_async: bool = False,
) -> tuple[LlmToolUseSettingToolArray, ClientToolMap]:
"""Split tool function definitions into server and client details."""
from inspect import iscoroutinefunction

if not tools:
raise LMStudioValueError(
"Tool using actions require at least one tool to be defined."
@@ -1561,7 +1606,12 @@ def parse_tools(
f"Duplicate tool names are not permitted ({tool_def.name!r} repeated)"
)
params_struct, llm_tool_def = tool_def._to_llm_tool_def()
client_tool_map[tool_def.name] = (params_struct, tool_def.implementation)
tool_impl = tool_def.implementation
is_async = iscoroutinefunction(tool_impl)
if is_async and not allow_async:
msg = f"Asynchronous tool definition for {tool_def.name!r} is not supported in synchronous API"
raise LMStudioValueError(msg)
client_tool_map[tool_def.name] = (params_struct, tool_impl, is_async)
llm_tool_defs.append(llm_tool_def)
return LlmToolUseSettingToolArray(tools=llm_tool_defs), client_tool_map

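The new `request_tool_call_async` path shown above awaits `async def` tool implementations directly and routes synchronous implementations through `asyncio.to_thread`, so both kinds end up as coroutine functions that the agent loop can hand to `asyncio.ensure_future`. A condensed sketch of that dispatch (the `wrap_tool_call` helper is hypothetical, not part of the SDK):

```python
import asyncio
from inspect import iscoroutinefunction
from typing import Any, Callable, Coroutine


def wrap_tool_call(
    implementation: Callable[..., Any], **kwds: Any
) -> Callable[[], Coroutine[None, None, Any]]:
    """Return a coroutine function regardless of how the tool is implemented."""
    if iscoroutinefunction(implementation):
        async def _call() -> Any:
            # Async tools are awaited directly on the event loop
            return await implementation(**kwds)
    else:
        async def _call() -> Any:
            # Sync tools run in a worker thread so they don't block the event loop
            return await asyncio.to_thread(implementation, **kwds)
    return _call
```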
31 changes: 26 additions & 5 deletions src/lmstudio/plugin/hooks/tools_provider.py
@@ -292,10 +292,27 @@ async def start_tool_call(self, tool_call: ProvideToolsCallTool) -> None:
await self._queue.put(tool_call)

# TODO: Reduce code duplication with the ChatResponseEndpoint definition
async def _call_async_tool(
self,
call_id: str,
implementation: Callable[..., Awaitable[Any]],
kwds: DictObject,
send_json: SendMessageAsync,
) -> PluginToolCallCompleteDict:
assert _LMS_TOOL_CALL_SYNC.get(None) is None
call_context = AsyncToolCallContext(self.session_id, call_id, send_json)
_LMS_TOOL_CALL_ASYNC.set(call_context)
call_result = await implementation(**kwds)
return PluginToolCallComplete(
session_id=self.session_id,
call_id=call_id,
result=call_result,
).to_dict()

def _call_sync_tool(
self,
call_id: str,
sync_tool: Callable[..., Any],
implementation: Callable[..., Any],
kwds: DictObject,
send_json: SendMessageAsync,
) -> Awaitable[PluginToolCallCompleteDict]:
@@ -305,7 +322,7 @@ def _call_sync_tool(
def _call_requested_tool() -> PluginToolCallCompleteDict:
assert _LMS_TOOL_CALL_ASYNC.get(None) is None
_LMS_TOOL_CALL_SYNC.set(call_context)
call_result = sync_tool(**kwds)
call_result = implementation(**kwds)
return PluginToolCallComplete(
session_id=self.session_id,
call_id=call_id,
@@ -325,15 +342,18 @@ async def _call_tool_implementation(
f"Plugin does not provide a tool named {tool_name!r}."
)
# Validate parameters against their specification
params_struct, tool_impl = tool_details
params_struct, tool_impl, is_async = tool_details
raw_kwds = tool_call.parameters
try:
parsed_kwds = convert(raw_kwds, params_struct)
except Exception as exc:
err_msg = f"Failed to parse arguments for tool {tool_name}: {exc}"
raise ServerRequestError(err_msg)
kwds = to_builtins(parsed_kwds)
# TODO: Also support async tool definitions and invocation
if is_async:
return await self._call_async_tool(
tool_call.call_id, tool_impl, kwds, send_json
)
return await self._call_sync_tool(tool_call.call_id, tool_impl, kwds, send_json)

# TODO: Reduce code duplication with the ChatResponseEndpoint definition
@@ -523,7 +543,8 @@ async def _invoke_hook(
try:
plugin_tools_list = await self.hook_impl(ctl)
llm_tools_array, provided_tools = ChatResponseEndpoint.parse_tools(
plugin_tools_list
plugin_tools_list,
allow_async=True,
)
llm_tools_list = llm_tools_array.to_dict()["tools"]
assert llm_tools_list is not None # Ensured by the parse_tools method
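In the plugin hook runner, `_call_async_tool` mirrors the existing `_call_sync_tool` helper: it stores the per-call context in a context variable before awaiting the implementation, so the running tool can reach its call context while it executes. A rough sketch of that contextvar pattern (the names below are illustrative only, not the SDK's internals):

```python
from contextvars import ContextVar
from typing import Any, Awaitable, Callable

# Hypothetical per-call context made visible to the running tool implementation
_TOOL_CALL_CONTEXT: ContextVar[str | None] = ContextVar(
    "_TOOL_CALL_CONTEXT", default=None
)


async def call_async_tool(
    call_id: str, implementation: Callable[..., Awaitable[Any]], **kwds: Any
) -> Any:
    """Record the call context, then await the async tool implementation."""
    assert _TOOL_CALL_CONTEXT.get() is None  # no nested tool calls expected
    _TOOL_CALL_CONTEXT.set(call_id)
    return await implementation(**kwds)
```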
1 change: 1 addition & 0 deletions src/lmstudio/sync_api.py
@@ -1395,6 +1395,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:

on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress
# TODO: Implementation to this point is common between the sync and async APIs
# (aside from the allow_async flag when parsing the tool definitions)
# Implementation past this point differs (as the async API uses the loop's executor)

# Request predictions until no more tool call requests are received in response
8 changes: 7 additions & 1 deletion tests/async/test_inference_async.py
@@ -39,7 +39,6 @@
SHORT_PREDICTION_CONFIG,
TOOL_LLM_ID,
check_sdk_error,
divide,
)


@@ -479,6 +478,13 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
assert cloned_chat._messages == chat._messages


# Also check coroutine support in the asynchronous API
# (this becomes a regular sync tool in the sync API tests)
async def divide(numerator: float, denominator: float) -> float:
"""Divide the given numerator by the given denominator. Return the result."""
return numerator / denominator


@pytest.mark.asyncio
@pytest.mark.lmstudio
async def test_tool_using_agent_error_handling_async(caplog: LogCap) -> None:
5 changes: 0 additions & 5 deletions tests/support/__init__.py
@@ -301,11 +301,6 @@ def check_unfiltered_error(
####################################################


def divide(numerator: float, denominator: float) -> float:
"""Divide the given numerator by the given denominator. Return the result."""
return numerator / denominator


def log_adding_two_integers(a: int, b: int) -> int:
"""Log adding two integers together."""
logging.info(f"Tool call: Adding {a!r} to {b!r} as integers")
8 changes: 7 additions & 1 deletion tests/sync/test_inference_sync.py
@@ -46,7 +46,6 @@
SHORT_PREDICTION_CONFIG,
TOOL_LLM_ID,
check_sdk_error,
divide,
)


@@ -464,6 +463,13 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
assert cloned_chat._messages == chat._messages


# Also check coroutine support in the asynchronous API
# (this becomes a regular sync tool in the sync API tests)
def divide(numerator: float, denominator: float) -> float:
"""Divide the given numerator by the given denominator. Return the result."""
return numerator / denominator


@pytest.mark.lmstudio
def test_tool_using_agent_error_handling_sync(caplog: LogCap) -> None:
caplog.set_level(logging.DEBUG)
19 changes: 19 additions & 0 deletions tests/test_inference.py
@@ -138,3 +138,22 @@ def test_duplicate_tool_names_rejected() -> None:
LMStudioValueError, match="Duplicate tool names are not permitted"
):
ChatResponseEndpoint.parse_tools(tools)


async def example_async_tool() -> int:
"""Example asynchronous tool definition"""
return 42


def test_async_tool_rejected() -> None:
tools: list[Any] = [example_async_tool]
with pytest.raises(LMStudioValueError, match=".*example_async_tool.*not supported"):
ChatResponseEndpoint.parse_tools(tools)


def test_async_tool_accepted() -> None:
tools: list[Any] = [example_async_tool]
llm_tools, client_map = ChatResponseEndpoint.parse_tools(tools, allow_async=True)
assert llm_tools.tools is not None
assert len(llm_tools.tools) == 1
assert len(client_map) == 1
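From the caller's perspective, the end result is that `async def` tools can now be passed to the asynchronous `act()` API, while the synchronous API still rejects them when the tool definitions are parsed. A hypothetical usage sketch under those assumptions (the `AsyncClient` entry point, session layout, and model identifier below are guesses about the public SDK surface, not taken from this diff; consult the SDK docs for the exact API):

```python
import asyncio

import lmstudio as lms


async def divide(numerator: float, denominator: float) -> float:
    """Divide the given numerator by the given denominator. Return the result."""
    return numerator / denominator


async def main() -> None:
    # Assumed async client entry point and model identifier (illustrative only)
    async with lms.AsyncClient() as client:
        model = await client.llm.model("qwen2.5-7b-instruct")
        # The async tool is awaited directly; a sync tool would run in a thread
        result = await model.act("What is 104 divided by 8?", tools=[divide])
        print(result)


asyncio.run(main())
```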