
Commit 65a510b

Add .act() to the async API (#132)
1 parent 70331ff commit 65a510b
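
For context, here is a minimal usage sketch of the asynchronous `.act()` entry point this commit adds. Only the `.act()` signature and the `ActResult` fields come from the diff below; the tool function, the `Chat` helper calls, and the pre-acquired async model handle `llm` are illustrative assumptions rather than part of this commit.

```python
# Hypothetical usage sketch (not part of this commit).
# Assumes `llm` is an already-acquired async model handle exposing the new
# `.act()` method, and that plain annotated callables are accepted as tool
# definitions (they are parsed via ChatResponseEndpoint.parse_tools).
from lmstudio import Chat


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


async def run_agent_turn(llm) -> None:
    chat = Chat("You are a terse assistant with tool access.")
    chat.add_user_message("What is 1234 + 4321?")
    result = await llm.act(
        chat,
        [add],
        max_prediction_rounds=3,
        on_message=chat.append,
    )
    # ActResult fields added by this commit: rounds and total_time_seconds
    print(f"Done in {result.rounds} round(s), {result.total_time_seconds:.2f}s")
```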

File tree

6 files changed: +509 -160 lines changed

src/lmstudio/async_api.py

Lines changed: 218 additions & 3 deletions
@@ -1,6 +1,8 @@
 """Async I/O protocol implementation for the LM Studio remote access API."""
 
 import asyncio
+import itertools
+import time
 
 from abc import abstractmethod
 from contextlib import AsyncExitStack, asynccontextmanager
@@ -28,16 +30,27 @@
 
 from httpx_ws import AsyncWebSocketSession
 
-from .sdk_api import LMStudioRuntimeError, sdk_public_api, sdk_public_api_async
+from .sdk_api import (
+    LMStudioRuntimeError,
+    LMStudioValueError,
+    sdk_callback_invocation,
+    sdk_public_api,
+    sdk_public_api_async,
+)
 from .schemas import AnyLMStudioStruct, DictObject
 from .history import (
+    AssistantResponse,
     Chat,
     ChatHistoryDataDict,
     FileHandle,
     LocalFileInput,
     _LocalFileData,
+    ToolCallResultData,
+    ToolCallRequest,
+    ToolResultMessage,
 )
 from .json_api import (
+    ActResult,
     AnyLoadConfig,
     AnyModelSpecifier,
     AvailableModelBase,
@@ -63,6 +76,7 @@
     LlmPredictionFragment,
     LMStudioCancelledError,
     LMStudioClientError,
+    LMStudioPredictionError,
     LMStudioWebsocket,
     LMStudioWebsocketError,
     LoadModelEndpoint,
@@ -80,11 +94,15 @@
     PredictionFragmentEvent,
     PredictionMessageCallback,
     PredictionResult,
+    PredictionRoundResult,
+    PredictionRxEvent,
+    PredictionToolCallEvent,
     PromptProcessingCallback,
     RemoteCallHandler,
     ResponseSchema,
     SendMessageAsync,
     TModelInfo,
+    ToolDefinition,
     check_model_namespace,
     load_struct,
     _model_spec_to_api_dict,
@@ -895,14 +913,18 @@ async def __aexit__(
         await self.aclose()
 
     async def __aiter__(self) -> AsyncIterator[LlmPredictionFragment]:
+        async for event in self._iter_events():
+            if isinstance(event, PredictionFragmentEvent):
+                yield event.arg
+
+    async def _iter_events(self) -> AsyncIterator[PredictionRxEvent]:
         endpoint = self._endpoint
         async with self:
             assert self._channel is not None
             async for contents in self._channel.rx_stream():
                 for event in endpoint.iter_message_events(contents):
                     endpoint.handle_rx_event(event)
-                    if isinstance(event, PredictionFragmentEvent):
-                        yield event.arg
+                    yield event
                 if endpoint.is_finished:
                     break
         self._mark_finished()
@@ -1244,6 +1266,199 @@ async def respond(
             pass
         return prediction_stream.result()
 
+    # TODO: Improve code sharing between sync and async multi-round predictions
+    # TODO: Accept async tools in the tools iterable
+    @sdk_public_api_async()
+    async def act(
+        self,
+        chat: Chat | ChatHistoryDataDict | str,
+        tools: Iterable[ToolDefinition],
+        *,
+        max_prediction_rounds: int | None = None,
+        max_parallel_tool_calls: int | None = 1,
+        config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
+        preset: str | None = None,
+        on_message: Callable[[AssistantResponse | ToolResultMessage], Any]
+        | None = None,
+        on_first_token: Callable[[int], Any] | None = None,
+        on_prediction_fragment: Callable[[LlmPredictionFragment, int], Any]
+        | None = None,
+        on_round_start: Callable[[int], Any] | None = None,
+        on_round_end: Callable[[int], Any] | None = None,
+        on_prediction_completed: Callable[[PredictionRoundResult], Any] | None = None,
+        on_prompt_processing_progress: Callable[[float, int], Any] | None = None,
+        handle_invalid_tool_request: Callable[
+            [LMStudioPredictionError, ToolCallRequest | None], str | None
+        ]
+        | None = None,
+    ) -> ActResult:
+        """Request a response (with implicit tool use) in an ongoing agent chat session.
+
+        Note: details of configuration fields may change in SDK feature releases.
+        """
+        start_time = time.perf_counter()
+        # It is not yet possible to combine tool calling with requests for structured responses
+        response_format = None
+        agent_chat: Chat = Chat.from_history(chat)
+        del chat  # Avoid any further access to the input chat history
+        # Multiple rounds, until all tool calls are resolved or limit is reached
+        round_counter: Iterable[int]
+        if max_prediction_rounds is not None:
+            if max_prediction_rounds < 1:
+                raise LMStudioValueError(
+                    f"Max prediction rounds must be at least 1 ({max_prediction_rounds!r} given)"
+                )
+            round_counter = range(max_prediction_rounds)
+            final_round_index = max_prediction_rounds - 1
+        else:
+            # Do not force a final round when no limit is specified
+            final_round_index = -1
+            round_counter = itertools.count()
+        llm_tool_args = ChatResponseEndpoint.parse_tools(tools)
+        del tools
+        # Supply the round index to any endpoint callbacks that expect one
+        round_index: int
+        on_first_token_for_endpoint: PredictionFirstTokenCallback | None = None
+        if on_first_token is not None:
+
+            def _wrapped_on_first_token() -> None:
+                assert on_first_token is not None
+                on_first_token(round_index)
+
+            on_first_token_for_endpoint = _wrapped_on_first_token
+        on_prediction_fragment_for_endpoint: PredictionFragmentCallback | None = None
+        if on_prediction_fragment is not None:
+
+            def _wrapped_on_prediction_fragment(
+                fragment: LlmPredictionFragment,
+            ) -> None:
+                assert on_prediction_fragment is not None
+                on_prediction_fragment(fragment, round_index)
+
+            on_prediction_fragment_for_endpoint = _wrapped_on_prediction_fragment
+        on_prompt_processing_for_endpoint: PromptProcessingCallback | None = None
+        if on_prompt_processing_progress is not None:
+
+            def _wrapped_on_prompt_processing_progress(progress: float) -> None:
+                assert on_prompt_processing_progress is not None
+                on_prompt_processing_progress(progress, round_index)
+
+            on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress
+        # TODO: Implementation to this point is common between the sync and async APIs
+        # Implementation past this point differs (as the sync API uses its own thread pool)
+
+        # Request predictions until no more tool call requests are received in response
+        # (or the maximum number of prediction rounds is reached)
+        for round_index in round_counter:
+            self._logger.debug(
+                "Starting .act() prediction round", round_index=round_index
+            )
+            if on_round_start is not None:
+                err_msg = f"Round start callback failed for {self!r}"
+                with sdk_callback_invocation(err_msg, self._logger):
+                    on_round_start(round_index)
+            # Update the endpoint definition on each iteration in order to:
+            # * update the chat history with the previous round result
+            # * be able to disallow tool use when the rounds are limited
+            # TODO: Refactor endpoint API to avoid repeatedly performing the
+            #       LlmPredictionConfig -> KvConfigStack transformation
+            endpoint = ChatResponseEndpoint(
+                self.identifier,
+                agent_chat,
+                response_format,
+                config,
+                preset,
+                None,  # on_message is invoked directly
+                on_first_token_for_endpoint,
+                on_prediction_fragment_for_endpoint,
+                on_prompt_processing_for_endpoint,
+                handle_invalid_tool_request,
+                *(llm_tool_args if round_index != final_round_index else (None, None)),
+            )
+            channel_cm = self._session._create_channel(endpoint)
+            prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
+            tool_call_requests: list[ToolCallRequest] = []
+            parsed_tool_calls: list[Callable[[], ToolCallResultData]] = []
+            async for event in prediction_stream._iter_events():
+                if isinstance(event, PredictionToolCallEvent):
+                    tool_call_request = event.arg
+                    tool_call_requests.append(tool_call_request)
+                    # TODO: Also handle async tool calls here
+                    tool_call = endpoint.request_tool_call(tool_call_request)
+                    parsed_tool_calls.append(tool_call)
+            prediction = prediction_stream.result()
+            self._logger.debug(
+                "Completed .act() prediction round", round_index=round_index
+            )
+            if on_prediction_completed:
+                round_result = PredictionRoundResult.from_result(
+                    prediction, round_index
+                )
+                err_msg = f"Prediction completed callback failed for {self!r}"
+                with sdk_callback_invocation(err_msg, self._logger):
+                    on_prediction_completed(round_result)
+            if parsed_tool_calls:
+                if max_parallel_tool_calls is None:
+                    max_parallel_tool_calls = len(parsed_tool_calls)
+                active_tool_calls = 0
+                tool_call_futures: list[asyncio.Future[ToolCallResultData]] = []
+                for tool_call in parsed_tool_calls:
+                    if active_tool_calls >= max_parallel_tool_calls:
+                        # Wait for a previous call to finish before starting another one
+                        _done, pending = await asyncio.wait(
+                            tool_call_futures, return_when=asyncio.FIRST_COMPLETED
+                        )
+                        active_tool_calls = len(pending)
+                    # TODO: Also handle async tool calls here
+                    tool_call_futures.append(
+                        asyncio.ensure_future(asyncio.to_thread(tool_call))
+                    )
+                    active_tool_calls += 1
+                tool_call_results: list[ToolCallResultData] = []
+                for tool_call_request, tool_call_future in zip(
+                    tool_call_requests, tool_call_futures
+                ):
+                    try:
+                        await tool_call_future
+                    except Exception as exc:
+                        tool_call_result = endpoint._handle_failed_tool_request(
+                            exc, tool_call_request
+                        )
+                    else:
+                        tool_call_result = tool_call_future.result()
+                    tool_call_results.append(tool_call_result)
+                requests_message = agent_chat.add_assistant_response(
+                    prediction, tool_call_requests
+                )
+                results_message = agent_chat.add_tool_results(tool_call_results)
+                if on_message is not None:
+                    err_msg = f"Tool request message callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        on_message(requests_message)
+                    err_msg = f"Tool result message callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        on_message(results_message)
+            elif on_message is not None:
+                err_msg = f"Final response message callback failed for {self!r}"
+                with sdk_callback_invocation(err_msg, self._logger):
+                    on_message(agent_chat.add_assistant_response(prediction))
+            if on_round_end is not None:
+                err_msg = f"Round end callback failed for {self!r}"
+                with sdk_callback_invocation(err_msg, self._logger):
+                    on_round_end(round_index)
+            if not tool_call_requests:
+                # No tool call requests -> we're done here
+                break
+            if round_index == final_round_index:
+                # We somehow received at least one tool call request,
+                # even though tools are omitted on the final round
+                err_msg = "Model requested tool use on final prediction round."
+                endpoint._handle_invalid_tool_request(err_msg)
+                break
+        num_rounds = round_index + 1
+        duration = time.perf_counter() - start_time
+        return ActResult(rounds=num_rounds, total_time_seconds=duration)
+
     @sdk_public_api_async()
     async def apply_prompt_template(
         self,
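
The tool dispatch loop in the new `.act()` implementation above caps in-flight tool calls at `max_parallel_tool_calls`: each synchronous tool callable is pushed onto a worker thread with `asyncio.to_thread`, and once the cap is reached the loop blocks on `asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED)` until a slot frees up. A standalone sketch of that bounded-parallelism pattern follows (the worker function and helper names are illustrative, not SDK API):

```python
# Illustrative sketch of the bounded-parallelism pattern used by .act();
# slow_tool and run_bounded are hypothetical names, not part of the SDK.
import asyncio
import time
from typing import Callable


def slow_tool(n: int) -> int:
    # Stand-in for a blocking (synchronous) tool implementation
    time.sleep(0.1)
    return n * n


async def run_bounded(calls: list[Callable[[], int]], limit: int) -> list[int]:
    futures: list[asyncio.Future[int]] = []
    active = 0
    for call in calls:
        if active >= limit:
            # Wait for at least one in-flight call to finish before starting another
            _done, pending = await asyncio.wait(
                futures, return_when=asyncio.FIRST_COMPLETED
            )
            active = len(pending)
        # Run the blocking callable in a worker thread without blocking the event loop
        futures.append(asyncio.ensure_future(asyncio.to_thread(call)))
        active += 1
    # Collect results in submission order
    return [await f for f in futures]


async def main() -> None:
    calls = [lambda i=i: slow_tool(i) for i in range(8)]
    print(await run_bounded(calls, limit=2))


asyncio.run(main())
```

Running synchronous tools via `asyncio.to_thread` keeps the event loop responsive while letting the async API reuse the same blocking tool callables as the sync API.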

src/lmstudio/sync_api.py

Lines changed: 7 additions & 4 deletions
@@ -43,13 +43,13 @@
 from .schemas import AnyLMStudioStruct, DictObject
 from .history import (
     AssistantResponse,
-    ToolResultMessage,
     Chat,
     ChatHistoryDataDict,
     FileHandle,
     LocalFileInput,
     _LocalFileData,
     ToolCallRequest,
+    ToolResultMessage,
 )
 from .json_api import (
     ActResult,
@@ -1316,9 +1316,8 @@ def respond(
             pass
         return prediction_stream.result()
 
-    # Multi-round predictions are currently a sync-only handle-only feature
-    # TODO: Refactor to allow for more code sharing with the async API
-    #       with defined aliases for the expected callback signatures
+    # TODO: Improve code sharing between sync and async multi-round predictions
+    # TODO: Accept async tools in the tools iterable
     @sdk_public_api()
     def act(
         self,
@@ -1395,6 +1394,9 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
                 on_prompt_processing_progress(progress, round_index)
 
             on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress
+        # TODO: Implementation to this point is common between the sync and async APIs
+        # Implementation past this point differs (as the async API uses the loop's executor)
+
         # Request predictions until no more tool call requests are received in response
         # (or the maximum number of prediction rounds is reached)
         with ThreadPoolExecutor(max_parallel_tool_calls) as pool:
@@ -1436,6 +1438,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
                     if isinstance(event, PredictionToolCallEvent):
                         tool_call_request = event.arg
                         tool_call_requests.append(tool_call_request)
+                        # TODO: Also handle async tool calls here
                        tool_call = endpoint.request_tool_call(tool_call_request)
                        pending_tool_calls[pool.submit(tool_call)] = tool_call_request
                prediction = prediction_stream.result()
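
For comparison, the synchronous `.act()` shown above bounds tool-call parallelism with a `ThreadPoolExecutor` sized by `max_parallel_tool_calls`, keying each pending future by the request that produced it. A minimal sketch of that pattern outside the SDK (the worker function and variable names are illustrative):

```python
# Illustrative sketch of the sync dispatch pattern; not SDK code.
from concurrent.futures import ThreadPoolExecutor, as_completed


def slow_tool(n: int) -> int:
    # Stand-in for a blocking tool implementation
    return n * n


requests = list(range(8))
with ThreadPoolExecutor(max_workers=2) as pool:
    # Map each pending future back to the request that produced it
    pending_tool_calls = {pool.submit(slow_tool, n): n for n in requests}
    for future in as_completed(pending_tool_calls):
        print(pending_tool_calls[future], future.result())
```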
