
Commit 6660e52

Accept async callbacks in async .act() (#134)
Closes #133
1 parent 65a510b commit 6660e52
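In practical terms, this change lets coroutine functions be passed straight to the async API's .act() tool list. A minimal usage sketch, assuming a tool-capable model is already loaded locally; the tool (lookup_weather) is illustrative, and the client/session calls (AsyncClient, client.llm.model()) follow the general shape of the SDK's async client rather than quoting this commit:

import asyncio

import lmstudio as lms


async def lookup_weather(city: str) -> str:
    """Hypothetical async tool: non-blocking I/O runs on the event loop."""
    await asyncio.sleep(0.1)  # stand-in for a real async lookup
    return f"Sunny in {city}"


async def main() -> None:
    async with lms.AsyncClient() as client:
        model = await client.llm.model()  # assumes a tool-capable model is loaded
        # Coroutine functions are now accepted directly by the async .act()
        await model.act(
            "What is the weather in Paris?",
            [lookup_weather],
            on_message=print,
        )


asyncio.run(main())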

8 files changed: +125 −29 lines changed

src/lmstudio/async_api.py

Lines changed: 6 additions & 8 deletions
@@ -53,6 +53,7 @@
     ActResult,
     AnyLoadConfig,
     AnyModelSpecifier,
+    AsyncToolCall,
     AvailableModelBase,
     ChannelEndpoint,
     ChannelHandler,
@@ -1314,7 +1315,7 @@ async def act(
         # Do not force a final round when no limit is specified
         final_round_index = -1
         round_counter = itertools.count()
-        llm_tool_args = ChatResponseEndpoint.parse_tools(tools)
+        llm_tool_args = ChatResponseEndpoint.parse_tools(tools, allow_async=True)
         del tools
         # Supply the round index to any endpoint callbacks that expect one
         round_index: int
@@ -1345,6 +1346,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
 
             on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress
         # TODO: Implementation to this point is common between the sync and async APIs
+        # (aside from the allow_async flag when parsing the tool definitions)
         # Implementation past this point differs (as the sync API uses its own thread pool)
 
         # Request predictions until no more tool call requests are received in response
@@ -1378,13 +1380,12 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
             channel_cm = self._session._create_channel(endpoint)
             prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
             tool_call_requests: list[ToolCallRequest] = []
-            parsed_tool_calls: list[Callable[[], ToolCallResultData]] = []
+            parsed_tool_calls: list[AsyncToolCall] = []
             async for event in prediction_stream._iter_events():
                 if isinstance(event, PredictionToolCallEvent):
                     tool_call_request = event.arg
                     tool_call_requests.append(tool_call_request)
-                    # TODO: Also handle async tool calls here
-                    tool_call = endpoint.request_tool_call(tool_call_request)
+                    tool_call = endpoint.request_tool_call_async(tool_call_request)
                     parsed_tool_calls.append(tool_call)
             prediction = prediction_stream.result()
             self._logger.debug(
@@ -1409,10 +1410,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
                         tool_call_futures, return_when=asyncio.FIRST_COMPLETED
                     )
                     active_tool_calls = len(pending)
-                # TODO: Also handle async tool calls here
-                tool_call_futures.append(
-                    asyncio.ensure_future(asyncio.to_thread(tool_call))
-                )
+                tool_call_futures.append(asyncio.ensure_future(tool_call()))
                 active_tool_calls += 1
             tool_call_results: list[ToolCallResultData] = []
             for tool_call_request, tool_call_future in zip(
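The key structural change above: every parsed tool call is now a zero-argument coroutine function, so the dispatch loop can schedule sync and async tools uniformly via asyncio.ensure_future instead of special-casing asyncio.to_thread at the call site. A standalone sketch of that pattern (illustrative names only, not the SDK's internal API):

import asyncio


async def slow_async_tool() -> str:
    await asyncio.sleep(0.1)  # genuine async work stays on the event loop
    return "async result"


def blocking_sync_tool() -> str:
    return "sync result"


# request_tool_call_async() hands back zero-argument coroutine functions
# shaped like these, whichever flavour the underlying tool is:
async def call_async_tool() -> str:
    return await slow_async_tool()  # awaited directly


async def call_sync_tool() -> str:
    return await asyncio.to_thread(blocking_sync_tool)  # runs in a worker thread


async def main() -> None:
    futures = [asyncio.ensure_future(call()) for call in (call_async_tool, call_sync_tool)]
    print(await asyncio.gather(*futures))


asyncio.run(main())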

src/lmstudio/json_api.py

Lines changed: 59 additions & 9 deletions
@@ -1188,8 +1188,11 @@ class PredictionToolCallAbortedEvent(ChannelRxEvent[None]):
     | ChannelCommonRxEvent
 )
 
-ClientToolSpec: TypeAlias = tuple[type[Struct], Callable[..., Any]]
+ClientToolSpec: TypeAlias = tuple[type[Struct], Callable[..., Any], bool]
 ClientToolMap: TypeAlias = Mapping[str, ClientToolSpec]
+SyncToolCall: TypeAlias = Callable[[], ToolCallResultData]
+# Require a coroutine (not just any awaitable) for ensure_future compatibility
+AsyncToolCall: TypeAlias = Callable[[], Coroutine[None, None, ToolCallResultData]]
 
 PredictionMessageCallback: TypeAlias = Callable[[AssistantResponse], Any]
 PredictionFirstTokenCallback: TypeAlias = Callable[[], Any]
@@ -1450,9 +1453,9 @@ def _handle_failed_tool_request(
         return ToolCallResultData(content=json.dumps(err_msg), tool_call_id=request.id)
 
     # TODO: Reduce code duplication with the tools_provider plugin hook runner
-    def request_tool_call(
+    def _request_any_tool_call(
         self, request: ToolCallRequest
-    ) -> Callable[[], ToolCallResultData]:
+    ) -> tuple[DictObject | None, Callable[[], Any], bool]:
         tool_name = request.name
         tool_call_id = request.id
         client_tool = self._client_tools.get(tool_name, None)
@@ -1461,9 +1464,9 @@ def request_tool_call(
                 f"Cannot find tool with name {tool_name}.", request
             )
             result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
-            return lambda: result
+            return None, lambda: result, False
         # Validate parameters against their specification
-        params_struct, implementation = client_tool
+        params_struct, implementation, is_async = client_tool
         raw_kwds = request.arguments
         try:
             parsed_kwds = convert(raw_kwds, params_struct)
@@ -1472,10 +1475,20 @@ def request_tool_call(
                 f"Failed to parse arguments for tool {tool_name}: {exc}", request
             )
             result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
-            return lambda: result
-        kwds = to_builtins(parsed_kwds)
-
+            return None, lambda: result, False
+        return to_builtins(parsed_kwds), implementation, is_async
+
+    def request_tool_call(self, request: ToolCallRequest) -> SyncToolCall:
+        kwds, implementation, is_async = self._request_any_tool_call(request)
+        if kwds is None:
+            # Tool def parsing failed, implementation emits an error response
+            return implementation
+        if is_async:
+            msg = f"Asynchronous tool {request.name!r} is not supported in synchronous API"
+            raise LMStudioValueError(msg)
         # Allow caller to schedule the tool call request for background execution
+        tool_call_id = request.id
+
         def _call_requested_tool() -> ToolCallResultData:
             call_result = implementation(**kwds)
             return ToolCallResultData(
@@ -1485,6 +1498,35 @@ def _call_requested_tool() -> ToolCallResultData:
 
         return _call_requested_tool
 
+    def request_tool_call_async(self, request: ToolCallRequest) -> AsyncToolCall:
+        kwds, implementation, is_async = self._request_any_tool_call(request)
+        if kwds is None:
+            # Tool def parsing failed, implementation emits an error response
+            async def _awaitable_error_response() -> ToolCallResultData:
+                return cast(ToolCallResultData, implementation())
+
+            return _awaitable_error_response
+        # Allow caller to schedule the tool call request as a coroutine
+        tool_call_id = request.id
+        if is_async:
+            # Async tool implementation can be awaited directly
+            async def _call_requested_tool() -> ToolCallResultData:
+                call_result = await implementation(**kwds)
+                return ToolCallResultData(
+                    content=json.dumps(call_result, ensure_ascii=False),
+                    tool_call_id=tool_call_id,
+                )
+        else:
+            # Sync tool implementation needs to be called in a thread
+            async def _call_requested_tool() -> ToolCallResultData:
+                call_result = await asyncio.to_thread(implementation, **kwds)
+                return ToolCallResultData(
+                    content=json.dumps(call_result, ensure_ascii=False),
+                    tool_call_id=tool_call_id,
+                )
+
+        return _call_requested_tool
+
     def mark_cancelled(self) -> None:
         """Mark the prediction as cancelled and quietly drop incoming tokens."""
         self._is_cancelled = True
@@ -1541,8 +1583,11 @@ class ChatResponseEndpoint(PredictionEndpoint):
     @staticmethod
     def parse_tools(
         tools: Iterable[ToolDefinition],
+        allow_async: bool = False,
     ) -> tuple[LlmToolUseSettingToolArray, ClientToolMap]:
         """Split tool function definitions into server and client details."""
+        from inspect import iscoroutinefunction
+
         if not tools:
             raise LMStudioValueError(
                 "Tool using actions require at least one tool to be defined."
@@ -1561,7 +1606,12 @@ def parse_tools(
                     f"Duplicate tool names are not permitted ({tool_def.name!r} repeated)"
                 )
             params_struct, llm_tool_def = tool_def._to_llm_tool_def()
-            client_tool_map[tool_def.name] = (params_struct, tool_def.implementation)
+            tool_impl = tool_def.implementation
+            is_async = iscoroutinefunction(tool_impl)
+            if is_async and not allow_async:
+                msg = f"Asynchronous tool definition for {tool_def.name!r} is not supported in synchronous API"
+                raise LMStudioValueError(msg)
+            client_tool_map[tool_def.name] = (params_struct, tool_impl, is_async)
             llm_tool_defs.append(llm_tool_def)
         return LlmToolUseSettingToolArray(tools=llm_tool_defs), client_tool_map
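Two details in the json_api.py changes are worth calling out: a tool's flavour is detected once at parse time with inspect.iscoroutinefunction and recorded as the new bool field of ClientToolSpec, and AsyncToolCall deliberately requires a coroutine (not just any awaitable) so the result can be handed to asyncio.ensure_future. A minimal sketch of the detection step, with classify_tool as a hypothetical stand-in for the relevant slice of parse_tools:

from inspect import iscoroutinefunction
from typing import Any, Callable


def classify_tool(impl: Callable[..., Any], allow_async: bool = False) -> tuple[Callable[..., Any], bool]:
    is_async = iscoroutinefunction(impl)
    if is_async and not allow_async:
        raise ValueError(f"Asynchronous tool definition for {impl.__name__!r} is not supported in synchronous API")
    return impl, is_async


async def lookup(x: int) -> int:
    return x * 2


print(classify_tool(lookup, allow_async=True))  # (<function lookup ...>, True)
try:
    classify_tool(lookup)  # sync path: rejected before any prediction starts
except ValueError as exc:
    print(exc)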

src/lmstudio/plugin/hooks/tools_provider.py

Lines changed: 26 additions & 5 deletions
@@ -292,10 +292,27 @@ async def start_tool_call(self, tool_call: ProvideToolsCallTool) -> None:
         await self._queue.put(tool_call)
 
     # TODO: Reduce code duplication with the ChatResponseEndpoint definition
+    async def _call_async_tool(
+        self,
+        call_id: str,
+        implementation: Callable[..., Awaitable[Any]],
+        kwds: DictObject,
+        send_json: SendMessageAsync,
+    ) -> PluginToolCallCompleteDict:
+        assert _LMS_TOOL_CALL_SYNC.get(None) is None
+        call_context = AsyncToolCallContext(self.session_id, call_id, send_json)
+        _LMS_TOOL_CALL_ASYNC.set(call_context)
+        call_result = await implementation(**kwds)
+        return PluginToolCallComplete(
+            session_id=self.session_id,
+            call_id=call_id,
+            result=call_result,
+        ).to_dict()
+
     def _call_sync_tool(
         self,
         call_id: str,
-        sync_tool: Callable[..., Any],
+        implementation: Callable[..., Any],
         kwds: DictObject,
         send_json: SendMessageAsync,
     ) -> Awaitable[PluginToolCallCompleteDict]:
@@ -305,7 +322,7 @@ def _call_sync_tool(
         def _call_requested_tool() -> PluginToolCallCompleteDict:
             assert _LMS_TOOL_CALL_ASYNC.get(None) is None
             _LMS_TOOL_CALL_SYNC.set(call_context)
-            call_result = sync_tool(**kwds)
+            call_result = implementation(**kwds)
             return PluginToolCallComplete(
                 session_id=self.session_id,
                 call_id=call_id,
@@ -325,15 +342,18 @@ async def _call_tool_implementation(
                 f"Plugin does not provide a tool named {tool_name!r}."
             )
         # Validate parameters against their specification
-        params_struct, tool_impl = tool_details
+        params_struct, tool_impl, is_async = tool_details
         raw_kwds = tool_call.parameters
         try:
             parsed_kwds = convert(raw_kwds, params_struct)
         except Exception as exc:
             err_msg = f"Failed to parse arguments for tool {tool_name}: {exc}"
             raise ServerRequestError(err_msg)
         kwds = to_builtins(parsed_kwds)
-        # TODO: Also support async tool definitions and invocation
+        if is_async:
+            return await self._call_async_tool(
+                tool_call.call_id, tool_impl, kwds, send_json
+            )
         return await self._call_sync_tool(tool_call.call_id, tool_impl, kwds, send_json)
 
     # TODO: Reduce code duplication with the ChatResponseEndpoint definition
@@ -523,7 +543,8 @@ async def _invoke_hook(
         try:
             plugin_tools_list = await self.hook_impl(ctl)
             llm_tools_array, provided_tools = ChatResponseEndpoint.parse_tools(
-                plugin_tools_list
+                plugin_tools_list,
+                allow_async=True,
             )
             llm_tools_list = llm_tools_array.to_dict()["tools"]
             assert llm_tools_list is not None  # Ensured by the parse_tools method
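The hook runner now applies the same split as the chat endpoint: coroutine tools are awaited on the event loop, while sync tools continue to run in a worker thread. A self-contained sketch of that dispatch (illustrative names; the real methods also set up tool call contexts via contextvars and wrap results in PluginToolCallComplete):

import asyncio
from typing import Any, Callable


async def call_tool_implementation(
    impl: Callable[..., Any], kwds: dict[str, Any], is_async: bool
) -> Any:
    if is_async:
        # Coroutine tool: awaited in-loop, no thread hop needed
        return await impl(**kwds)
    # Sync tool: pushed to a worker thread so the loop stays responsive
    return await asyncio.to_thread(impl, **kwds)


async def main() -> None:
    async def twice(x: int) -> int:
        return 2 * x

    def thrice(x: int) -> int:
        return 3 * x

    print(await call_tool_implementation(twice, {"x": 21}, is_async=True))
    print(await call_tool_implementation(thrice, {"x": 14}, is_async=False))


asyncio.run(main())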

src/lmstudio/sync_api.py

Lines changed: 1 addition & 0 deletions
@@ -1395,6 +1395,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
 
             on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress
         # TODO: Implementation to this point is common between the sync and async APIs
+        # (aside from the allow_async flag when parsing the tool definitions)
         # Implementation past this point differs (as the async API uses the loop's executor)
 
         # Request predictions until no more tool call requests are received in response
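The sync .act() path deliberately leaves allow_async at its default of False, so a coroutine tool is rejected up front rather than failing mid-prediction. A sketch mirroring the new unit tests further below (import locations are assumptions; LMStudioValueError is part of the SDK's error hierarchy):

from lmstudio import LMStudioValueError
from lmstudio.json_api import ChatResponseEndpoint


async def example_async_tool() -> int:
    """Example asynchronous tool definition"""
    return 42


try:
    ChatResponseEndpoint.parse_tools([example_async_tool])
except LMStudioValueError as exc:
    print(exc)  # mentions 'example_async_tool' ... not supported in synchronous API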

tests/async/test_inference_async.py

Lines changed: 7 additions & 1 deletion
@@ -39,7 +39,6 @@
     SHORT_PREDICTION_CONFIG,
     TOOL_LLM_ID,
     check_sdk_error,
-    divide,
 )
 
 
@@ -479,6 +478,13 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
     assert cloned_chat._messages == chat._messages
 
 
+# Also check coroutine support in the asynchronous API
+# (this becomes a regular sync tool in the sync API tests)
+async def divide(numerator: float, denominator: float) -> float:
+    """Divide the given numerator by the given denominator. Return the result."""
+    return numerator / denominator
+
+
 @pytest.mark.asyncio
 @pytest.mark.lmstudio
 async def test_tool_using_agent_error_handling_async(caplog: LogCap) -> None:

tests/support/__init__.py

Lines changed: 0 additions & 5 deletions
@@ -301,11 +301,6 @@ def check_unfiltered_error(
 ####################################################
 
 
-def divide(numerator: float, denominator: float) -> float:
-    """Divide the given numerator by the given denominator. Return the result."""
-    return numerator / denominator
-
-
 def log_adding_two_integers(a: int, b: int) -> int:
     """Log adding two integers together."""
     logging.info(f"Tool call: Adding {a!r} to {b!r} as integers")

tests/sync/test_inference_sync.py

Lines changed: 7 additions & 1 deletion
@@ -46,7 +46,6 @@
     SHORT_PREDICTION_CONFIG,
     TOOL_LLM_ID,
     check_sdk_error,
-    divide,
 )
 
 
@@ -464,6 +463,13 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
     assert cloned_chat._messages == chat._messages
 
 
+# Also check coroutine support in the asynchronous API
+# (this becomes a regular sync tool in the sync API tests)
+def divide(numerator: float, denominator: float) -> float:
+    """Divide the given numerator by the given denominator. Return the result."""
+    return numerator / denominator
+
+
 @pytest.mark.lmstudio
 def test_tool_using_agent_error_handling_sync(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)

tests/test_inference.py

Lines changed: 19 additions & 0 deletions
@@ -138,3 +138,22 @@ def test_duplicate_tool_names_rejected() -> None:
         LMStudioValueError, match="Duplicate tool names are not permitted"
     ):
         ChatResponseEndpoint.parse_tools(tools)
+
+
+async def example_async_tool() -> int:
+    """Example asynchronous tool definition"""
+    return 42
+
+
+def test_async_tool_rejected() -> None:
+    tools: list[Any] = [example_async_tool]
+    with pytest.raises(LMStudioValueError, match=".*example_async_tool.*not supported"):
+        ChatResponseEndpoint.parse_tools(tools)
+
+
+def test_async_tool_accepted() -> None:
+    tools: list[Any] = [example_async_tool]
+    llm_tools, client_map = ChatResponseEndpoint.parse_tools(tools, allow_async=True)
+    assert llm_tools.tools is not None
+    assert len(llm_tools.tools) == 1
+    assert len(client_map) == 1
