
Commit 0f2bacd

Merge branch 'main' into main
2 parents: 6a5e242 + 6660e52

File tree: 8 files changed (+125 −29 lines)


src/lmstudio/async_api.py

Lines changed: 6 additions & 8 deletions

@@ -53,6 +53,7 @@
     ActResult,
     AnyLoadConfig,
     AnyModelSpecifier,
+    AsyncToolCall,
     AvailableModelBase,
     ChannelEndpoint,
     ChannelHandler,
@@ -1314,7 +1315,7 @@ async def act(
         # Do not force a final round when no limit is specified
         final_round_index = -1
         round_counter = itertools.count()
-        llm_tool_args = ChatResponseEndpoint.parse_tools(tools)
+        llm_tool_args = ChatResponseEndpoint.parse_tools(tools, allow_async=True)
         del tools
         # Supply the round index to any endpoint callbacks that expect one
         round_index: int
@@ -1345,6 +1346,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:

             on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress
         # TODO: Implementation to this point is common between the sync and async APIs
+        # (aside from the allow_async flag when parsing the tool definitions)
         # Implementation past this point differs (as the sync API uses its own thread pool)

         # Request predictions until no more tool call requests are received in response
@@ -1378,13 +1380,12 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
             channel_cm = self._session._create_channel(endpoint)
             prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
             tool_call_requests: list[ToolCallRequest] = []
-            parsed_tool_calls: list[Callable[[], ToolCallResultData]] = []
+            parsed_tool_calls: list[AsyncToolCall] = []
             async for event in prediction_stream._iter_events():
                 if isinstance(event, PredictionToolCallEvent):
                     tool_call_request = event.arg
                     tool_call_requests.append(tool_call_request)
-                    # TODO: Also handle async tool calls here
-                    tool_call = endpoint.request_tool_call(tool_call_request)
+                    tool_call = endpoint.request_tool_call_async(tool_call_request)
                     parsed_tool_calls.append(tool_call)
             prediction = prediction_stream.result()
             self._logger.debug(
@@ -1409,10 +1410,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
                         tool_call_futures, return_when=asyncio.FIRST_COMPLETED
                     )
                     active_tool_calls = len(pending)
-                # TODO: Also handle async tool calls here
-                tool_call_futures.append(
-                    asyncio.ensure_future(asyncio.to_thread(tool_call))
-                )
+                tool_call_futures.append(asyncio.ensure_future(tool_call()))
                 active_tool_calls += 1
             tool_call_results: list[ToolCallResultData] = []
             for tool_call_request, tool_call_future in zip(
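
The key change in the last hunk is that asyncio.ensure_future is now handed the coroutine produced by calling the parsed tool call, rather than a sync callable wrapped in asyncio.to_thread. A minimal standalone sketch (illustrative only, not SDK code) contrasting the two scheduling patterns:

import asyncio

def sync_tool() -> int:
    # Blocking implementation: must run in a worker thread
    return 42

async def async_tool() -> int:
    # Coroutine implementation: can be scheduled directly
    return 42

async def main() -> None:
    # Old pattern: wrap the sync callable so it runs in a thread
    old_style = asyncio.ensure_future(asyncio.to_thread(sync_tool))
    # New pattern: calling an AsyncToolCall yields a coroutine for ensure_future
    new_style = asyncio.ensure_future(async_tool())
    print(await asyncio.gather(old_style, new_style))  # [42, 42]

asyncio.run(main())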

src/lmstudio/json_api.py

Lines changed: 59 additions & 9 deletions

@@ -1210,8 +1210,11 @@ class PredictionToolCallAbortedEvent(ChannelRxEvent[None]):
     | ChannelCommonRxEvent
 )

-ClientToolSpec: TypeAlias = tuple[type[Struct], Callable[..., Any]]
+ClientToolSpec: TypeAlias = tuple[type[Struct], Callable[..., Any], bool]
 ClientToolMap: TypeAlias = Mapping[str, ClientToolSpec]
+SyncToolCall: TypeAlias = Callable[[], ToolCallResultData]
+# Require a coroutine (not just any awaitable) for ensure_future compatibility
+AsyncToolCall: TypeAlias = Callable[[], Coroutine[None, None, ToolCallResultData]]

 PredictionMessageCallback: TypeAlias = Callable[[AssistantResponse], Any]
 PredictionFirstTokenCallback: TypeAlias = Callable[[], Any]
@@ -1472,9 +1475,9 @@ def _handle_failed_tool_request(
         return ToolCallResultData(content=json.dumps(err_msg), tool_call_id=request.id)

     # TODO: Reduce code duplication with the tools_provider plugin hook runner
-    def request_tool_call(
+    def _request_any_tool_call(
         self, request: ToolCallRequest
-    ) -> Callable[[], ToolCallResultData]:
+    ) -> tuple[DictObject | None, Callable[[], Any], bool]:
         tool_name = request.name
         tool_call_id = request.id
         client_tool = self._client_tools.get(tool_name, None)
@@ -1483,9 +1486,9 @@ def request_tool_call(
                 f"Cannot find tool with name {tool_name}.", request
             )
             result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
-            return lambda: result
+            return None, lambda: result, False
         # Validate parameters against their specification
-        params_struct, implementation = client_tool
+        params_struct, implementation, is_async = client_tool
         raw_kwds = request.arguments
         try:
             parsed_kwds = convert(raw_kwds, params_struct)
@@ -1494,10 +1497,20 @@ def request_tool_call(
                 f"Failed to parse arguments for tool {tool_name}: {exc}", request
             )
             result = ToolCallResultData(content=err_msg, tool_call_id=tool_call_id)
-            return lambda: result
-        kwds = to_builtins(parsed_kwds)
-
+            return None, lambda: result, False
+        return to_builtins(parsed_kwds), implementation, is_async
+
+    def request_tool_call(self, request: ToolCallRequest) -> SyncToolCall:
+        kwds, implementation, is_async = self._request_any_tool_call(request)
+        if kwds is None:
+            # Tool def parsing failed, implementation emits an error response
+            return implementation
+        if is_async:
+            msg = f"Asynchronous tool {request.name!r} is not supported in synchronous API"
+            raise LMStudioValueError(msg)
         # Allow caller to schedule the tool call request for background execution
+        tool_call_id = request.id
+
         def _call_requested_tool() -> ToolCallResultData:
             call_result = implementation(**kwds)
             return ToolCallResultData(
@@ -1507,6 +1520,35 @@ def _call_requested_tool() -> ToolCallResultData:

         return _call_requested_tool

+    def request_tool_call_async(self, request: ToolCallRequest) -> AsyncToolCall:
+        kwds, implementation, is_async = self._request_any_tool_call(request)
+        if kwds is None:
+            # Tool def parsing failed, implementation emits an error response
+            async def _awaitable_error_response() -> ToolCallResultData:
+                return cast(ToolCallResultData, implementation())
+
+            return _awaitable_error_response
+        # Allow caller to schedule the tool call request as a coroutine
+        tool_call_id = request.id
+        if is_async:
+            # Async tool implementation can be awaited directly
+            async def _call_requested_tool() -> ToolCallResultData:
+                call_result = await implementation(**kwds)
+                return ToolCallResultData(
+                    content=json.dumps(call_result, ensure_ascii=False),
+                    tool_call_id=tool_call_id,
+                )
+        else:
+            # Sync tool implementation needs to be called in a thread
+            async def _call_requested_tool() -> ToolCallResultData:
+                call_result = await asyncio.to_thread(implementation, **kwds)
+                return ToolCallResultData(
+                    content=json.dumps(call_result, ensure_ascii=False),
+                    tool_call_id=tool_call_id,
+                )
+
+        return _call_requested_tool
+
     def mark_cancelled(self) -> None:
         """Mark the prediction as cancelled and quietly drop incoming tokens."""
         self._is_cancelled = True
@@ -1563,8 +1605,11 @@ class ChatResponseEndpoint(PredictionEndpoint):
     @staticmethod
     def parse_tools(
         tools: Iterable[ToolDefinition],
+        allow_async: bool = False,
     ) -> tuple[LlmToolUseSettingToolArray, ClientToolMap]:
         """Split tool function definitions into server and client details."""
+        from inspect import iscoroutinefunction
+
         if not tools:
             raise LMStudioValueError(
                 "Tool using actions require at least one tool to be defined."
@@ -1583,7 +1628,12 @@ def parse_tools(
                 f"Duplicate tool names are not permitted ({tool_def.name!r} repeated)"
             )
             params_struct, llm_tool_def = tool_def._to_llm_tool_def()
-            client_tool_map[tool_def.name] = (params_struct, tool_def.implementation)
+            tool_impl = tool_def.implementation
+            is_async = iscoroutinefunction(tool_impl)
+            if is_async and not allow_async:
+                msg = f"Asynchronous tool definition for {tool_def.name!r} is not supported in synchronous API"
+                raise LMStudioValueError(msg)
+            client_tool_map[tool_def.name] = (params_struct, tool_impl, is_async)
             llm_tool_defs.append(llm_tool_def)
         return LlmToolUseSettingToolArray(tools=llm_tool_defs), client_tool_map
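
The net effect of the json_api.py changes: ClientToolSpec grows an is_async flag, computed with inspect.iscoroutinefunction during parse_tools, and sync callers that hit an async tool get an LMStudioValueError. A simplified sketch of that classification logic (the classify_tool helper and plain ValueError are illustrative stand-ins, not SDK API):

from inspect import iscoroutinefunction
from typing import Any, Callable

def classify_tool(
    name: str, impl: Callable[..., Any], allow_async: bool
) -> tuple[Callable[..., Any], bool]:
    # Record whether the implementation is a coroutine function
    is_async = iscoroutinefunction(impl)
    if is_async and not allow_async:
        # Mirrors the SDK's rejection message for the synchronous API
        raise ValueError(
            f"Asynchronous tool definition for {name!r} is not supported in synchronous API"
        )
    return impl, is_async

async def divide(a: float, b: float) -> float:
    return a / b

print(classify_tool("divide", divide, allow_async=True))  # (<function divide ...>, True)
# classify_tool("divide", divide, allow_async=False) raises ValueError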

src/lmstudio/plugin/hooks/tools_provider.py

Lines changed: 26 additions & 5 deletions

@@ -292,10 +292,27 @@ async def start_tool_call(self, tool_call: ProvideToolsCallTool) -> None:
         await self._queue.put(tool_call)

     # TODO: Reduce code duplication with the ChatResponseEndpoint definition
+    async def _call_async_tool(
+        self,
+        call_id: str,
+        implementation: Callable[..., Awaitable[Any]],
+        kwds: DictObject,
+        send_json: SendMessageAsync,
+    ) -> PluginToolCallCompleteDict:
+        assert _LMS_TOOL_CALL_SYNC.get(None) is None
+        call_context = AsyncToolCallContext(self.session_id, call_id, send_json)
+        _LMS_TOOL_CALL_ASYNC.set(call_context)
+        call_result = await implementation(**kwds)
+        return PluginToolCallComplete(
+            session_id=self.session_id,
+            call_id=call_id,
+            result=call_result,
+        ).to_dict()
+
     def _call_sync_tool(
         self,
         call_id: str,
-        sync_tool: Callable[..., Any],
+        implementation: Callable[..., Any],
         kwds: DictObject,
         send_json: SendMessageAsync,
     ) -> Awaitable[PluginToolCallCompleteDict]:
@@ -305,7 +322,7 @@ def _call_sync_tool(
         def _call_requested_tool() -> PluginToolCallCompleteDict:
             assert _LMS_TOOL_CALL_ASYNC.get(None) is None
             _LMS_TOOL_CALL_SYNC.set(call_context)
-            call_result = sync_tool(**kwds)
+            call_result = implementation(**kwds)
             return PluginToolCallComplete(
                 session_id=self.session_id,
                 call_id=call_id,
@@ -325,15 +342,18 @@ async def _call_tool_implementation(
             f"Plugin does not provide a tool named {tool_name!r}."
         )
         # Validate parameters against their specification
-        params_struct, tool_impl = tool_details
+        params_struct, tool_impl, is_async = tool_details
         raw_kwds = tool_call.parameters
         try:
             parsed_kwds = convert(raw_kwds, params_struct)
         except Exception as exc:
             err_msg = f"Failed to parse arguments for tool {tool_name}: {exc}"
             raise ServerRequestError(err_msg)
         kwds = to_builtins(parsed_kwds)
-        # TODO: Also support async tool definitions and invocation
+        if is_async:
+            return await self._call_async_tool(
+                tool_call.call_id, tool_impl, kwds, send_json
+            )
         return await self._call_sync_tool(tool_call.call_id, tool_impl, kwds, send_json)

     # TODO: Reduce code duplication with the ChatResponseEndpoint definition
@@ -523,7 +543,8 @@ async def _invoke_hook(
         try:
             plugin_tools_list = await self.hook_impl(ctl)
             llm_tools_array, provided_tools = ChatResponseEndpoint.parse_tools(
-                plugin_tools_list
+                plugin_tools_list,
+                allow_async=True,
             )
             llm_tools_list = llm_tools_array.to_dict()["tools"]
             assert llm_tools_list is not None  # Ensured by the parse_tools method
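
Stripped of the plugin plumbing (call contexts, session bookkeeping), the dispatch the hook runner now performs reduces to a common asyncio pattern: await coroutine implementations directly, and hand blocking ones to asyncio.to_thread. A condensed, self-contained sketch (names are illustrative, not SDK API):

import asyncio
from inspect import iscoroutinefunction
from typing import Any, Callable

async def call_tool_implementation(impl: Callable[..., Any], **kwds: Any) -> Any:
    if iscoroutinefunction(impl):
        # Async tool: await it on the current event loop
        return await impl(**kwds)
    # Sync tool: run it in a worker thread so the loop stays responsive
    return await asyncio.to_thread(impl, **kwds)

def sync_add(a: int, b: int) -> int:
    return a + b

async def async_add(a: int, b: int) -> int:
    return a + b

async def main() -> None:
    print(await call_tool_implementation(sync_add, a=1, b=2))   # 3
    print(await call_tool_implementation(async_add, a=3, b=4))  # 7

asyncio.run(main())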

src/lmstudio/sync_api.py

Lines changed: 1 addition & 0 deletions

@@ -1395,6 +1395,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:

             on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress
         # TODO: Implementation to this point is common between the sync and async APIs
+        # (aside from the allow_async flag when parsing the tool definitions)
         # Implementation past this point differs (as the async API uses the loop's executor)

         # Request predictions until no more tool call requests are received in response

tests/async/test_inference_async.py

Lines changed: 7 additions & 1 deletion

@@ -39,7 +39,6 @@
     SHORT_PREDICTION_CONFIG,
     TOOL_LLM_ID,
     check_sdk_error,
-    divide,
 )


@@ -479,6 +478,13 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
     assert cloned_chat._messages == chat._messages


+# Also check coroutine support in the asynchronous API
+# (this becomes a regular sync tool in the sync API tests)
+async def divide(numerator: float, denominator: float) -> float:
+    """Divide the given numerator by the given denominator. Return the result."""
+    return numerator / denominator
+
+
 @pytest.mark.asyncio
 @pytest.mark.lmstudio
 async def test_tool_using_agent_error_handling_async(caplog: LogCap) -> None:
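
With this commit, a coroutine such as the divide tool above can be passed straight to the async .act() API. A hedged usage sketch follows; the AsyncClient setup and llm.model() accessor shown here are assumptions about the SDK's async API surface, not details taken from this commit:

import asyncio
import lmstudio as lms

# Coroutine tool, as defined in the test above
async def divide(numerator: float, denominator: float) -> float:
    """Divide the given numerator by the given denominator. Return the result."""
    return numerator / denominator

async def main() -> None:
    # Client construction and model lookup are assumed for illustration;
    # consult the SDK docs for the exact async entry points
    async with lms.AsyncClient() as client:
        model = await client.llm.model()
        result = await model.act("What is 104 divided by 8?", [divide])
        print(result)

asyncio.run(main())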

tests/support/__init__.py

Lines changed: 0 additions & 5 deletions

@@ -301,11 +301,6 @@ def check_unfiltered_error(
 ####################################################


-def divide(numerator: float, denominator: float) -> float:
-    """Divide the given numerator by the given denominator. Return the result."""
-    return numerator / denominator
-
-
 def log_adding_two_integers(a: int, b: int) -> int:
     """Log adding two integers together."""
     logging.info(f"Tool call: Adding {a!r} to {b!r} as integers")

tests/sync/test_inference_sync.py

Lines changed: 7 additions & 1 deletion

@@ -46,7 +46,6 @@
     SHORT_PREDICTION_CONFIG,
     TOOL_LLM_ID,
     check_sdk_error,
-    divide,
 )


@@ -464,6 +463,13 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
     assert cloned_chat._messages == chat._messages


+# Also check coroutine support in the asynchronous API
+# (this becomes a regular sync tool in the sync API tests)
+def divide(numerator: float, denominator: float) -> float:
+    """Divide the given numerator by the given denominator. Return the result."""
+    return numerator / denominator
+
+
 @pytest.mark.lmstudio
 def test_tool_using_agent_error_handling_sync(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)

tests/test_inference.py

Lines changed: 19 additions & 0 deletions

@@ -138,3 +138,22 @@ def test_duplicate_tool_names_rejected() -> None:
         LMStudioValueError, match="Duplicate tool names are not permitted"
     ):
         ChatResponseEndpoint.parse_tools(tools)
+
+
+async def example_async_tool() -> int:
+    """Example asynchronous tool definition"""
+    return 42
+
+
+def test_async_tool_rejected() -> None:
+    tools: list[Any] = [example_async_tool]
+    with pytest.raises(LMStudioValueError, match=".*example_async_tool.*not supported"):
+        ChatResponseEndpoint.parse_tools(tools)
+
+
+def test_async_tool_accepted() -> None:
+    tools: list[Any] = [example_async_tool]
+    llm_tools, client_map = ChatResponseEndpoint.parse_tools(tools, allow_async=True)
+    assert llm_tools.tools is not None
+    assert len(llm_tools.tools) == 1
+    assert len(client_map) == 1
