diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py
index 3bcd39f..c1f26ee 100644
--- a/src/lmstudio/async_api.py
+++ b/src/lmstudio/async_api.py
@@ -53,6 +53,7 @@
     ActResult,
     AnyLoadConfig,
     AnyModelSpecifier,
+    AsyncToolCall,
     AvailableModelBase,
     ChannelEndpoint,
     ChannelHandler,
@@ -1314,7 +1315,7 @@ async def act(
         # Do not force a final round when no limit is specified
         final_round_index = -1
         round_counter = itertools.count()
-        llm_tool_args = ChatResponseEndpoint.parse_tools(tools)
+        llm_tool_args = ChatResponseEndpoint.parse_tools(tools, allow_async=True)
         del tools
         # Supply the round index to any endpoint callbacks that expect one
         round_index: int
@@ -1345,6 +1346,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
             on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress
 
         # TODO: Implementation to this point is common between the sync and async APIs
+        # (aside from the allow_async flag when parsing the tool definitions)
         # Implementation past this point differs (as the sync API uses its own thread pool)
 
         # Request predictions until no more tool call requests are received in response
@@ -1378,13 +1380,12 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
             channel_cm = self._session._create_channel(endpoint)
             prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
             tool_call_requests: list[ToolCallRequest] = []
-            parsed_tool_calls: list[Callable[[], ToolCallResultData]] = []
+            parsed_tool_calls: list[AsyncToolCall] = []
             async for event in prediction_stream._iter_events():
                 if isinstance(event, PredictionToolCallEvent):
                     tool_call_request = event.arg
                     tool_call_requests.append(tool_call_request)
-                    # TODO: Also handle async tool calls here
-                    tool_call = endpoint.request_tool_call(tool_call_request)
+                    tool_call = endpoint.request_tool_call_async(tool_call_request)
                     parsed_tool_calls.append(tool_call)
             prediction = prediction_stream.result()
             self._logger.debug(
@@ -1409,10 +1410,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
                             tool_call_futures, return_when=asyncio.FIRST_COMPLETED
                         )
                         active_tool_calls = len(pending)
-                    # TODO: Also handle async tool calls here
-                    tool_call_futures.append(
-                        asyncio.ensure_future(asyncio.to_thread(tool_call))
-                    )
+                    tool_call_futures.append(asyncio.ensure_future(tool_call()))
                     active_tool_calls += 1
             tool_call_results: list[ToolCallResultData] = []
             for tool_call_request, tool_call_future in zip(
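
Note on the scheduling change: the loop above can hand each parsed tool call
straight to asyncio.ensure_future because request_tool_call_async (defined in
json_api.py below) returns a zero-argument callable that produces a coroutine.
A minimal, self-contained sketch of the same bounded fan-out pattern, with
illustrative tool functions and a hypothetical run_tool_calls helper (not SDK
APIs):

    import asyncio
    from typing import Callable, Coroutine

    async def fast_tool() -> str:
        return "fast"

    async def slow_tool() -> str:
        await asyncio.sleep(0.1)
        return "slow"

    async def run_tool_calls(
        tool_calls: list[Callable[[], Coroutine[None, None, str]]],
        max_concurrency: int = 2,
    ) -> list[str]:
        futures: list[asyncio.Future[str]] = []
        active = 0
        for tool_call in tool_calls:
            if active >= max_concurrency:
                # As in the endpoint loop: wait for an in-flight call to
                # finish before scheduling the next one
                _done, pending = await asyncio.wait(
                    futures, return_when=asyncio.FIRST_COMPLETED
                )
                active = len(pending)
            futures.append(asyncio.ensure_future(tool_call()))
            active += 1
        # Await in submission order so results line up with requests
        return [await f for f in futures]

    print(asyncio.run(run_tool_calls([fast_tool, slow_tool, fast_tool])))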
diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py
index 76662e2..610f953 100644
--- a/src/lmstudio/json_api.py
+++ b/src/lmstudio/json_api.py
@@ -1188,8 +1188,11 @@ class PredictionToolCallAbortedEvent(ChannelRxEvent[None]):
     | ChannelCommonRxEvent
 )
 
-ClientToolSpec: TypeAlias = tuple[type[Struct], Callable[..., Any]]
+ClientToolSpec: TypeAlias = tuple[type[Struct], Callable[..., Any], bool]
 ClientToolMap: TypeAlias = Mapping[str, ClientToolSpec]
+SyncToolCall: TypeAlias = Callable[[], ToolCallResultData]
+# Require a coroutine (not just any awaitable) for ensure_future compatibility
+AsyncToolCall: TypeAlias = Callable[[], Coroutine[None, None, ToolCallResultData]]
 
 PredictionMessageCallback: TypeAlias = Callable[[AssistantResponse], Any]
 PredictionFirstTokenCallback: TypeAlias = Callable[[], Any]
@@ -1450,9 +1453,9 @@ def _handle_failed_tool_request(
         return ToolCallResultData(content=json.dumps(err_msg), tool_call_id=request.id)
 
     # TODO: Reduce code duplication with the tools_provider plugin hook runner
-    def request_tool_call(
+    def _request_any_tool_call(
         self, request: ToolCallRequest
-    ) -> Callable[[], ToolCallResultData]:
+    ) -> tuple[DictObject | None, Callable[..., Any], bool]:
         tool_name = request.name
         tool_call_id = request.id
         client_tool = self._client_tools.get(tool_name, None)
@@ -1461,9 +1464,9 @@ def request_tool_call(
                 f"Cannot find tool with name {tool_name}.", request
             )
             result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
-            return lambda: result
+            return None, lambda: result, False
         # Validate parameters against their specification
-        params_struct, implementation = client_tool
+        params_struct, implementation, is_async = client_tool
         raw_kwds = request.arguments
         try:
             parsed_kwds = convert(raw_kwds, params_struct)
@@ -1472,10 +1475,20 @@ def request_tool_call(
                 f"Failed to parse arguments for tool {tool_name}: {exc}", request
             )
             result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
-            return lambda: result
-        kwds = to_builtins(parsed_kwds)
-
+            return None, lambda: result, False
+        return to_builtins(parsed_kwds), implementation, is_async
+
+    def request_tool_call(self, request: ToolCallRequest) -> SyncToolCall:
+        kwds, implementation, is_async = self._request_any_tool_call(request)
+        if kwds is None:
+            # Tool def parsing failed, implementation emits an error response
+            return implementation
+        if is_async:
+            msg = f"Asynchronous tool {request.name!r} is not supported in synchronous API"
+            raise LMStudioValueError(msg)
         # Allow caller to schedule the tool call request for background execution
+        tool_call_id = request.id
+
         def _call_requested_tool() -> ToolCallResultData:
             call_result = implementation(**kwds)
             return ToolCallResultData(
@@ -1485,6 +1498,35 @@ def _call_requested_tool() -> ToolCallResultData:
 
         return _call_requested_tool
 
+    def request_tool_call_async(self, request: ToolCallRequest) -> AsyncToolCall:
+        kwds, implementation, is_async = self._request_any_tool_call(request)
+        if kwds is None:
+            # Tool def parsing failed, implementation emits an error response
+            async def _awaitable_error_response() -> ToolCallResultData:
+                return cast(ToolCallResultData, implementation())
+
+            return _awaitable_error_response
+        # Allow caller to schedule the tool call request as a coroutine
+        tool_call_id = request.id
+        if is_async:
+            # Async tool implementation can be awaited directly
+            async def _call_requested_tool() -> ToolCallResultData:
+                call_result = await implementation(**kwds)
+                return ToolCallResultData(
+                    content=json.dumps(call_result, ensure_ascii=False),
+                    tool_call_id=tool_call_id,
+                )
+        else:
+            # Sync tool implementation needs to be called in a thread
+            async def _call_requested_tool() -> ToolCallResultData:
+                call_result = await asyncio.to_thread(implementation, **kwds)
+                return ToolCallResultData(
+                    content=json.dumps(call_result, ensure_ascii=False),
+                    tool_call_id=tool_call_id,
+                )
+
+        return _call_requested_tool
+
     def mark_cancelled(self) -> None:
         """Mark the prediction as cancelled and quietly drop incoming tokens."""
         self._is_cancelled = True
@@ -1541,8 +1583,11 @@ class ChatResponseEndpoint(PredictionEndpoint):
     @staticmethod
     def parse_tools(
         tools: Iterable[ToolDefinition],
+        allow_async: bool = False,
     ) -> tuple[LlmToolUseSettingToolArray, ClientToolMap]:
         """Split tool function definitions into server and client details."""
+        from inspect import iscoroutinefunction
+
         if not tools:
             raise LMStudioValueError(
                 "Tool using actions require at least one tool to be defined."
@@ -1561,7 +1606,12 @@ def parse_tools(
                     f"Duplicate tool names are not permitted ({tool_def.name!r} repeated)"
                 )
             params_struct, llm_tool_def = tool_def._to_llm_tool_def()
-            client_tool_map[tool_def.name] = (params_struct, tool_def.implementation)
+            tool_impl = tool_def.implementation
+            is_async = iscoroutinefunction(tool_impl)
+            if is_async and not allow_async:
+                msg = f"Asynchronous tool definition for {tool_def.name!r} is not supported in synchronous API"
+                raise LMStudioValueError(msg)
+            client_tool_map[tool_def.name] = (params_struct, tool_impl, is_async)
             llm_tool_defs.append(llm_tool_def)
 
         return LlmToolUseSettingToolArray(tools=llm_tool_defs), client_tool_map
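
The new request_tool_call_async normalizes both kinds of implementations to the
same AsyncToolCall shape: coroutine functions are awaited directly, while plain
functions are dispatched via asyncio.to_thread so they cannot block the event
loop. A self-contained sketch of that normalization technique (the
as_coroutine_factory helper and the sample tools are illustrative, not SDK
APIs):

    import asyncio
    from inspect import iscoroutinefunction
    from typing import Any, Callable, Coroutine

    def as_coroutine_factory(
        implementation: Callable[..., Any], **kwds: Any
    ) -> Callable[[], Coroutine[None, None, Any]]:
        """Wrap a sync or async tool implementation as a coroutine factory."""
        if iscoroutinefunction(implementation):
            # Coroutine functions can be awaited directly on the event loop
            async def _call() -> Any:
                return await implementation(**kwds)
        else:
            # Plain functions are run in a worker thread to avoid blocking
            async def _call() -> Any:
                return await asyncio.to_thread(implementation, **kwds)
        return _call

    def add(a: int, b: int) -> int:
        return a + b

    async def multiply(a: int, b: int) -> int:
        return a * b

    async def main() -> None:
        calls = [
            as_coroutine_factory(add, a=2, b=3),
            as_coroutine_factory(multiply, a=2, b=3),
        ]
        print(await asyncio.gather(*(call() for call in calls)))  # [5, 6]

    asyncio.run(main())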
diff --git a/src/lmstudio/plugin/hooks/tools_provider.py b/src/lmstudio/plugin/hooks/tools_provider.py
index 2fa08c4..77c16fa 100644
--- a/src/lmstudio/plugin/hooks/tools_provider.py
+++ b/src/lmstudio/plugin/hooks/tools_provider.py
@@ -292,10 +292,27 @@ async def start_tool_call(self, tool_call: ProvideToolsCallTool) -> None:
         await self._queue.put(tool_call)
 
     # TODO: Reduce code duplication with the ChatResponseEndpoint definition
+    async def _call_async_tool(
+        self,
+        call_id: str,
+        implementation: Callable[..., Awaitable[Any]],
+        kwds: DictObject,
+        send_json: SendMessageAsync,
+    ) -> PluginToolCallCompleteDict:
+        assert _LMS_TOOL_CALL_SYNC.get(None) is None
+        call_context = AsyncToolCallContext(self.session_id, call_id, send_json)
+        _LMS_TOOL_CALL_ASYNC.set(call_context)
+        call_result = await implementation(**kwds)
+        return PluginToolCallComplete(
+            session_id=self.session_id,
+            call_id=call_id,
+            result=call_result,
+        ).to_dict()
+
     def _call_sync_tool(
         self,
         call_id: str,
-        sync_tool: Callable[..., Any],
+        implementation: Callable[..., Any],
         kwds: DictObject,
         send_json: SendMessageAsync,
     ) -> Awaitable[PluginToolCallCompleteDict]:
@@ -305,7 +322,7 @@
         def _call_requested_tool() -> PluginToolCallCompleteDict:
             assert _LMS_TOOL_CALL_ASYNC.get(None) is None
             _LMS_TOOL_CALL_SYNC.set(call_context)
-            call_result = sync_tool(**kwds)
+            call_result = implementation(**kwds)
             return PluginToolCallComplete(
                 session_id=self.session_id,
                 call_id=call_id,
@@ -325,7 +342,7 @@ async def _call_tool_implementation(
                 f"Plugin does not provide a tool named {tool_name!r}."
             )
         # Validate parameters against their specification
-        params_struct, tool_impl = tool_details
+        params_struct, tool_impl, is_async = tool_details
         raw_kwds = tool_call.parameters
         try:
             parsed_kwds = convert(raw_kwds, params_struct)
@@ -333,7 +350,10 @@ async def _call_tool_implementation(
         except Exception as exc:
             err_msg = f"Failed to parse arguments for tool {tool_name}: {exc}"
             raise ServerRequestError(err_msg)
         kwds = to_builtins(parsed_kwds)
-        # TODO: Also support async tool definitions and invocation
+        if is_async:
+            return await self._call_async_tool(
+                tool_call.call_id, tool_impl, kwds, send_json
+            )
         return await self._call_sync_tool(tool_call.call_id, tool_impl, kwds, send_json)
 
     # TODO: Reduce code duplication with the ChatResponseEndpoint definition
@@ -523,7 +543,8 @@ async def _invoke_hook(
         try:
             plugin_tools_list = await self.hook_impl(ctl)
             llm_tools_array, provided_tools = ChatResponseEndpoint.parse_tools(
-                plugin_tools_list
+                plugin_tools_list,
+                allow_async=True,
             )
             llm_tools_list = llm_tools_array.to_dict()["tools"]
             assert llm_tools_list is not None  # Ensured by the parse_tools method
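
The sync/async split above also determines where the call context variables are
set: _call_async_tool sets its context variable directly before awaiting the
implementation, while the sync path sets its variable inside the worker thread,
where it only affects the context copy that asyncio.to_thread hands to that
thread. A runnable sketch of that isolation property (the variable and helper
names here are illustrative, not the SDK's):

    import asyncio
    import contextvars

    # Illustrative stand-in for the hook runner's call context variables
    CALL_CONTEXT: contextvars.ContextVar[str] = contextvars.ContextVar("call_context")

    def sync_tool() -> str:
        # Runs in a worker thread and sees the context copied into that thread
        return CALL_CONTEXT.get("unset")

    async def run_sync_tool(call_id: str) -> str:
        def _in_thread() -> str:
            # Setting the variable here only affects this thread's context copy
            CALL_CONTEXT.set(call_id)
            return sync_tool()

        return await asyncio.to_thread(_in_thread)

    async def main() -> None:
        print(await asyncio.gather(run_sync_tool("call-1"), run_sync_tool("call-2")))
        print(CALL_CONTEXT.get("unset"))  # still "unset": no leakage between calls

    asyncio.run(main())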
diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py
index d8c42ea..dfdaafa 100644
--- a/src/lmstudio/sync_api.py
+++ b/src/lmstudio/sync_api.py
@@ -1395,6 +1395,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
             on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress
 
         # TODO: Implementation to this point is common between the sync and async APIs
+        # (aside from the allow_async flag when parsing the tool definitions)
         # Implementation past this point differs (as the async API uses the loop's executor)
 
         # Request predictions until no more tool call requests are received in response
diff --git a/tests/async/test_inference_async.py b/tests/async/test_inference_async.py
index 76f506e..f8b16ac 100644
--- a/tests/async/test_inference_async.py
+++ b/tests/async/test_inference_async.py
@@ -39,7 +39,6 @@
     SHORT_PREDICTION_CONFIG,
     TOOL_LLM_ID,
     check_sdk_error,
-    divide,
 )
 
 
@@ -479,6 +478,13 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
     assert cloned_chat._messages == chat._messages
 
 
+# Also check coroutine support in the asynchronous API
+# (this becomes a regular sync tool in the sync API tests)
+async def divide(numerator: float, denominator: float) -> float:
+    """Divide the given numerator by the given denominator. Return the result."""
+    return numerator / denominator
+
+
 @pytest.mark.asyncio
 @pytest.mark.lmstudio
 async def test_tool_using_agent_error_handling_async(caplog: LogCap) -> None:
diff --git a/tests/support/__init__.py b/tests/support/__init__.py
index 0ec5c5d..5066542 100644
--- a/tests/support/__init__.py
+++ b/tests/support/__init__.py
@@ -301,11 +301,6 @@ def check_unfiltered_error(
 ####################################################
 
 
-def divide(numerator: float, denominator: float) -> float:
-    """Divide the given numerator by the given denominator. Return the result."""
-    return numerator / denominator
-
-
 def log_adding_two_integers(a: int, b: int) -> int:
     """Log adding two integers together."""
     logging.info(f"Tool call: Adding {a!r} to {b!r} as integers")
diff --git a/tests/sync/test_inference_sync.py b/tests/sync/test_inference_sync.py
index 92dba9e..b519baf 100644
--- a/tests/sync/test_inference_sync.py
+++ b/tests/sync/test_inference_sync.py
@@ -46,7 +46,6 @@
     SHORT_PREDICTION_CONFIG,
     TOOL_LLM_ID,
     check_sdk_error,
-    divide,
 )
 
 
@@ -464,6 +463,13 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
     assert cloned_chat._messages == chat._messages
 
 
+# Also check coroutine support in the asynchronous API
+# (this becomes a regular sync tool in the sync API tests)
+def divide(numerator: float, denominator: float) -> float:
+    """Divide the given numerator by the given denominator. Return the result."""
+    return numerator / denominator
+
+
 @pytest.mark.lmstudio
 def test_tool_using_agent_error_handling_sync(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 80ff704..0472b10 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -138,3 +138,22 @@ def test_duplicate_tool_names_rejected() -> None:
         LMStudioValueError, match="Duplicate tool names are not permitted"
     ):
         ChatResponseEndpoint.parse_tools(tools)
+
+
+async def example_async_tool() -> int:
+    """Example asynchronous tool definition"""
+    return 42
+
+
+def test_async_tool_rejected() -> None:
+    tools: list[Any] = [example_async_tool]
+    with pytest.raises(LMStudioValueError, match=".*example_async_tool.*not supported"):
+        ChatResponseEndpoint.parse_tools(tools)
+
+
+def test_async_tool_accepted() -> None:
+    tools: list[Any] = [example_async_tool]
+    llm_tools, client_map = ChatResponseEndpoint.parse_tools(tools, allow_async=True)
+    assert llm_tools.tools is not None
+    assert len(llm_tools.tools) == 1
+    assert len(client_map) == 1
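
Taken together, these changes let coroutine functions be passed to the
asynchronous act() API exactly like plain functions. A hedged end-to-end
sketch, assuming the SDK's AsyncClient entry point and a tool-capable model
already loaded (the prompt and client usage here are illustrative):

    import asyncio
    import lmstudio as lms

    async def divide(numerator: float, denominator: float) -> float:
        """Divide the given numerator by the given denominator. Return the result."""
        # As a coroutine, this runs on the event loop instead of a worker thread
        return numerator / denominator

    async def main() -> None:
        async with lms.AsyncClient() as client:
            model = await client.llm.model()  # default loaded model
            result = await model.act(
                "What is 104 divided by 7?",
                [divide],
                on_message=print,
            )
            print(result)

    asyncio.run(main())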