@@ -1188,8 +1188,11 @@ class PredictionToolCallAbortedEvent(ChannelRxEvent[None]):
     ChannelCommonRxEvent
 )
 
-ClientToolSpec: TypeAlias = tuple[type[Struct], Callable[..., Any]]
+ClientToolSpec: TypeAlias = tuple[type[Struct], Callable[..., Any], bool]
 ClientToolMap: TypeAlias = Mapping[str, ClientToolSpec]
+SyncToolCall: TypeAlias = Callable[[], ToolCallResultData]
+# Require a coroutine (not just any awaitable) for ensure_future compatibility
+AsyncToolCall: TypeAlias = Callable[[], Coroutine[None, None, ToolCallResultData]]
 
 PredictionMessageCallback: TypeAlias = Callable[[AssistantResponse], Any]
 PredictionFirstTokenCallback: TypeAlias = Callable[[], Any]
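The new aliases separate the two shapes of client-side tool call the endpoint can hand back: a plain zero-argument callable versus a zero-argument coroutine function. A standalone sketch of what satisfies each alias (not part of the diff; the ToolCallResultData stand-in and the tool bodies are made up for illustration):

import asyncio
from dataclasses import dataclass
from typing import Callable, Coroutine, TypeAlias

@dataclass
class ToolCallResultData:  # stand-in for the SDK struct, illustration only
    content: str
    tool_call_id: str | None = None

SyncToolCall: TypeAlias = Callable[[], ToolCallResultData]
AsyncToolCall: TypeAlias = Callable[[], Coroutine[None, None, ToolCallResultData]]

def sync_tool_call() -> ToolCallResultData:  # satisfies SyncToolCall
    return ToolCallResultData(content='"3"')

async def async_tool_call() -> ToolCallResultData:  # satisfies AsyncToolCall
    return ToolCallResultData(content='"7"')

async def main() -> None:
    # Calling an AsyncToolCall produces a coroutine, which asyncio.ensure_future
    # can wrap in a Task for background execution.
    task = asyncio.ensure_future(async_tool_call())
    print(sync_tool_call().content, (await task).content)

asyncio.run(main())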
@@ -1450,9 +1453,9 @@ def _handle_failed_tool_request(
         return ToolCallResultData(content=json.dumps(err_msg), tool_call_id=request.id)
 
     # TODO: Reduce code duplication with the tools_provider plugin hook runner
-    def request_tool_call(
+    def _request_any_tool_call(
         self, request: ToolCallRequest
-    ) -> Callable[[], ToolCallResultData]:
+    ) -> tuple[DictObject | None, Callable[[], Any], bool]:
         tool_name = request.name
         tool_call_id = request.id
         client_tool = self._client_tools.get(tool_name, None)
@@ -1461,9 +1464,9 @@ def request_tool_call(
14611464 f"Cannot find tool with name { tool_name } ." , request
14621465 )
14631466 result = ToolCallResultData (content = err_msg , tool_call_id = tool_call_id )
1464- return lambda : result
1467+ return None , lambda : result , False
14651468 # Validate parameters against their specification
1466- params_struct , implementation = client_tool
1469+ params_struct , implementation , is_async = client_tool
14671470 raw_kwds = request .arguments
14681471 try :
14691472 parsed_kwds = convert (raw_kwds , params_struct )
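For context on the validation step shown above, convert and to_builtins come from msgspec: convert checks the raw arguments against the generated params struct, and to_builtins turns the validated struct back into plain keyword arguments. A standalone sketch under that assumption (the AddParams struct is made up; the SDK derives one per tool definition):

import msgspec
from msgspec import Struct, convert, to_builtins

class AddParams(Struct):  # made-up params struct for illustration
    a: int
    b: int

parsed = convert({"a": 1, "b": 2}, AddParams)  # validates and builds AddParams(a=1, b=2)
print(to_builtins(parsed))                     # back to plain builtins: {'a': 1, 'b': 2}

try:
    convert({"a": "not a number", "b": 2}, AddParams)
except msgspec.ValidationError as exc:
    print(f"rejected: {exc}")                  # mismatched argument types are reported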
@@ -1472,10 +1475,20 @@
                 f"Failed to parse arguments for tool {tool_name}: {exc}", request
             )
             result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
-            return lambda: result
-        kwds = to_builtins(parsed_kwds)
-
+            return None, lambda: result, False
+        return to_builtins(parsed_kwds), implementation, is_async
+
+    def request_tool_call(self, request: ToolCallRequest) -> SyncToolCall:
+        kwds, implementation, is_async = self._request_any_tool_call(request)
+        if kwds is None:
+            # Tool def parsing failed, implementation emits an error response
+            return implementation
+        if is_async:
+            msg = f"Asynchronous tool {request.name!r} is not supported in synchronous API"
+            raise LMStudioValueError(msg)
         # Allow caller to schedule the tool call request for background execution
+        tool_call_id = request.id
+
         def _call_requested_tool() -> ToolCallResultData:
             call_result = implementation(**kwds)
             return ToolCallResultData(
@@ -1485,6 +1498,35 @@ def _call_requested_tool() -> ToolCallResultData:
 
         return _call_requested_tool
 
+    def request_tool_call_async(self, request: ToolCallRequest) -> AsyncToolCall:
+        kwds, implementation, is_async = self._request_any_tool_call(request)
+        if kwds is None:
+            # Tool def parsing failed, implementation emits an error response
+            async def _awaitable_error_response() -> ToolCallResultData:
+                return cast(ToolCallResultData, implementation())
+
+            return _awaitable_error_response
+        # Allow caller to schedule the tool call request as a coroutine
+        tool_call_id = request.id
+        if is_async:
+            # Async tool implementation can be awaited directly
+            async def _call_requested_tool() -> ToolCallResultData:
+                call_result = await implementation(**kwds)
+                return ToolCallResultData(
+                    content=json.dumps(call_result, ensure_ascii=False),
+                    tool_call_id=tool_call_id,
+                )
+        else:
+            # Sync tool implementation needs to be called in a thread
+            async def _call_requested_tool() -> ToolCallResultData:
+                call_result = await asyncio.to_thread(implementation, **kwds)
+                return ToolCallResultData(
+                    content=json.dumps(call_result, ensure_ascii=False),
+                    tool_call_id=tool_call_id,
+                )
+
+        return _call_requested_tool
+
     def mark_cancelled(self) -> None:
         """Mark the prediction as cancelled and quietly drop incoming tokens."""
         self._is_cancelled = True
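The async variant dispatches on is_async: coroutine tool implementations are awaited directly, while sync implementations are pushed to a worker thread so they cannot block the event loop. A standalone sketch of that dispatch pattern (not the SDK's code; the tool functions and the call_tool helper are made up for illustration):

import asyncio
import json
from inspect import iscoroutinefunction

def sync_add(a: int, b: int) -> int:  # made-up sync tool
    return a + b

async def async_add(a: int, b: int) -> int:  # made-up async tool
    return a + b

async def call_tool(implementation, **kwds) -> str:
    if iscoroutinefunction(implementation):
        # Async tool implementations can be awaited directly
        result = await implementation(**kwds)
    else:
        # Sync tool implementations run in a worker thread so the loop stays responsive
        result = await asyncio.to_thread(implementation, **kwds)
    return json.dumps(result, ensure_ascii=False)

async def main() -> None:
    print(await call_tool(sync_add, a=1, b=2))   # 3
    print(await call_tool(async_add, a=3, b=4))  # 7

asyncio.run(main())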
@@ -1541,8 +1583,11 @@ class ChatResponseEndpoint(PredictionEndpoint):
     @staticmethod
     def parse_tools(
         tools: Iterable[ToolDefinition],
+        allow_async: bool = False,
     ) -> tuple[LlmToolUseSettingToolArray, ClientToolMap]:
         """Split tool function definitions into server and client details."""
+        from inspect import iscoroutinefunction
+
         if not tools:
             raise LMStudioValueError(
                 "Tool using actions require at least one tool to be defined."
@@ -1561,7 +1606,12 @@ def parse_tools(
                     f"Duplicate tool names are not permitted ({tool_def.name!r} repeated)"
                 )
             params_struct, llm_tool_def = tool_def._to_llm_tool_def()
-            client_tool_map[tool_def.name] = (params_struct, tool_def.implementation)
+            tool_impl = tool_def.implementation
+            is_async = iscoroutinefunction(tool_impl)
+            if is_async and not allow_async:
+                msg = f"Asynchronous tool definition for {tool_def.name!r} is not supported in synchronous API"
+                raise LMStudioValueError(msg)
+            client_tool_map[tool_def.name] = (params_struct, tool_impl, is_async)
             llm_tool_defs.append(llm_tool_def)
         return LlmToolUseSettingToolArray(tools=llm_tool_defs), client_tool_map
 
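parse_tools now records whether each implementation is a coroutine function so the synchronous API can reject async tools up front. A minimal standalone illustration of the check (the tool functions here are made up):

from inspect import iscoroutinefunction

def add(a: int, b: int) -> int:  # made-up sync tool implementation
    return a + b

async def fetch_weather(city: str) -> str:  # made-up async tool implementation
    return f"Sunny in {city}"

print(iscoroutinefunction(add))            # False -> stored with is_async=False
print(iscoroutinefunction(fetch_weather))  # True -> rejected unless allow_async=True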