@@ -1210,8 +1210,11 @@ class PredictionToolCallAbortedEvent(ChannelRxEvent[None]):
     | ChannelCommonRxEvent
 )
 
-ClientToolSpec: TypeAlias = tuple[type[Struct], Callable[..., Any]]
+ClientToolSpec: TypeAlias = tuple[type[Struct], Callable[..., Any], bool]
 ClientToolMap: TypeAlias = Mapping[str, ClientToolSpec]
+SyncToolCall: TypeAlias = Callable[[], ToolCallResultData]
+# Require a coroutine (not just any awaitable) for ensure_future compatibility
+AsyncToolCall: TypeAlias = Callable[[], Coroutine[None, None, ToolCallResultData]]
 
 PredictionMessageCallback: TypeAlias = Callable[[AssistantResponse], Any]
 PredictionFirstTokenCallback: TypeAlias = Callable[[], Any]
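
A minimal sketch of why the new AsyncToolCall alias requires a coroutine-returning callable (per the comment above): calling it yields a coroutine that can be handed straight to asyncio.ensure_future. The stand-in dict below is illustrative only, not the SDK's ToolCallResultData:

import asyncio

# Stand-in for an AsyncToolCall: a zero-argument callable returning a coroutine
async def _fake_async_tool_call() -> dict:
    # Illustrative payload in place of ToolCallResultData
    return {"content": "42", "tool_call_id": "call-1"}

async def main() -> None:
    # The coroutine produced by the call is wrapped in a Task and scheduled
    task = asyncio.ensure_future(_fake_async_tool_call())
    print(await task)

asyncio.run(main())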
@@ -1472,9 +1475,9 @@ def _handle_failed_tool_request(
         return ToolCallResultData(content=json.dumps(err_msg), tool_call_id=request.id)
 
     # TODO: Reduce code duplication with the tools_provider plugin hook runner
-    def request_tool_call(
+    def _request_any_tool_call(
         self, request: ToolCallRequest
-    ) -> Callable[[], ToolCallResultData]:
+    ) -> tuple[DictObject | None, Callable[[], Any], bool]:
         tool_name = request.name
         tool_call_id = request.id
         client_tool = self._client_tools.get(tool_name, None)
@@ -1483,9 +1486,9 @@ def request_tool_call(
                 f"Cannot find tool with name {tool_name}.", request
             )
             result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
-            return lambda: result
+            return None, lambda: result, False
         # Validate parameters against their specification
-        params_struct, implementation = client_tool
+        params_struct, implementation, is_async = client_tool
         raw_kwds = request.arguments
         try:
             parsed_kwds = convert(raw_kwds, params_struct)
@@ -1494,10 +1497,20 @@ def request_tool_call(
                 f"Failed to parse arguments for tool {tool_name}: {exc}", request
             )
             result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
-            return lambda: result
-        kwds = to_builtins(parsed_kwds)
-
+            return None, lambda: result, False
+        return to_builtins(parsed_kwds), implementation, is_async
+
+    def request_tool_call(self, request: ToolCallRequest) -> SyncToolCall:
+        kwds, implementation, is_async = self._request_any_tool_call(request)
+        if kwds is None:
+            # Tool def parsing failed, implementation emits an error response
+            return implementation
+        if is_async:
+            msg = f"Asynchronous tool {request.name!r} is not supported in synchronous API"
+            raise LMStudioValueError(msg)
         # Allow caller to schedule the tool call request for background execution
+        tool_call_id = request.id
+
         def _call_requested_tool() -> ToolCallResultData:
             call_result = implementation(**kwds)
             return ToolCallResultData(
@@ -1507,6 +1520,35 @@ def _call_requested_tool() -> ToolCallResultData:
 
         return _call_requested_tool
 
+    def request_tool_call_async(self, request: ToolCallRequest) -> AsyncToolCall:
+        kwds, implementation, is_async = self._request_any_tool_call(request)
+        if kwds is None:
+            # Tool def parsing failed, implementation emits an error response
+            async def _awaitable_error_response() -> ToolCallResultData:
+                return cast(ToolCallResultData, implementation())
+
+            return _awaitable_error_response
+        # Allow caller to schedule the tool call request as a coroutine
+        tool_call_id = request.id
+        if is_async:
+            # Async tool implementation can be awaited directly
+            async def _call_requested_tool() -> ToolCallResultData:
+                call_result = await implementation(**kwds)
+                return ToolCallResultData(
+                    content=json.dumps(call_result, ensure_ascii=False),
+                    tool_call_id=tool_call_id,
+                )
+        else:
+            # Sync tool implementation needs to be called in a thread
+            async def _call_requested_tool() -> ToolCallResultData:
+                call_result = await asyncio.to_thread(implementation, **kwds)
+                return ToolCallResultData(
+                    content=json.dumps(call_result, ensure_ascii=False),
+                    tool_call_id=tool_call_id,
+                )
+
+        return _call_requested_tool
+
     def mark_cancelled(self) -> None:
         """Mark the prediction as cancelled and quietly drop incoming tokens."""
         self._is_cancelled = True
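
Caller-side usage might look like the sketch below; endpoint and request are assumed placeholders for a prediction endpoint instance and a ToolCallRequest, not the SDK's actual channel-handling code:

import asyncio

def run_tool_sync(endpoint, request):
    # Synchronous API: request_tool_call returns a plain SyncToolCall,
    # which the caller can invoke directly or hand to a worker thread
    sync_call = endpoint.request_tool_call(request)
    return sync_call()

async def run_tool_async(endpoint, request):
    # Asynchronous API: request_tool_call_async returns an AsyncToolCall;
    # calling it yields a coroutine that ensure_future can schedule, and
    # sync tool implementations are routed through asyncio.to_thread()
    # inside that coroutine, so the event loop is never blocked
    task = asyncio.ensure_future(endpoint.request_tool_call_async(request)())
    return await task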
@@ -1563,8 +1605,11 @@ class ChatResponseEndpoint(PredictionEndpoint):
     @staticmethod
     def parse_tools(
         tools: Iterable[ToolDefinition],
+        allow_async: bool = False,
     ) -> tuple[LlmToolUseSettingToolArray, ClientToolMap]:
         """Split tool function definitions into server and client details."""
+        from inspect import iscoroutinefunction
+
         if not tools:
             raise LMStudioValueError(
                 "Tool using actions require at least one tool to be defined."
@@ -1583,7 +1628,12 @@ def parse_tools(
                     f"Duplicate tool names are not permitted ({tool_def.name!r} repeated)"
                 )
             params_struct, llm_tool_def = tool_def._to_llm_tool_def()
-            client_tool_map[tool_def.name] = (params_struct, tool_def.implementation)
+            tool_impl = tool_def.implementation
+            is_async = iscoroutinefunction(tool_impl)
+            if is_async and not allow_async:
+                msg = f"Asynchronous tool definition for {tool_def.name!r} is not supported in synchronous API"
+                raise LMStudioValueError(msg)
+            client_tool_map[tool_def.name] = (params_struct, tool_impl, is_async)
             llm_tool_defs.append(llm_tool_def)
         return LlmToolUseSettingToolArray(tools=llm_tool_defs), client_tool_map
 
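
The sync/async split in parse_tools hinges on inspect.iscoroutinefunction; here is a self-contained sketch of just that detection step (the tool bodies are illustrative stand-ins, not SDK tool definitions):

from inspect import iscoroutinefunction

def add(a: int, b: int) -> int:
    # Regular (synchronous) tool implementation
    return a + b

async def fetch_page(url: str) -> str:
    # Coroutine-based (asynchronous) tool implementation
    return f"<html>stub for {url}</html>"

for impl in (add, fetch_page):
    is_async = iscoroutinefunction(impl)
    # parse_tools() stores this flag as the third element of ClientToolSpec;
    # with allow_async=False (the sync API default) an async implementation
    # raises LMStudioValueError instead of being recorded
    print(impl.__name__, "is_async =", is_async)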