1 | 1 | """Async I/O protocol implementation for the LM Studio remote access API.""" |
2 | 2 |
3 | 3 | import asyncio |
| 4 | +import itertools |
| 5 | +import time |
4 | 6 |
5 | 7 | from abc import abstractmethod |
6 | 8 | from contextlib import AsyncExitStack, asynccontextmanager |
|
28 | 30 |
29 | 31 | from httpx_ws import AsyncWebSocketSession |
30 | 32 |
31 | | -from .sdk_api import LMStudioRuntimeError, sdk_public_api, sdk_public_api_async |
| 33 | +from .sdk_api import ( |
| 34 | + LMStudioRuntimeError, |
| 35 | + LMStudioValueError, |
| 36 | + sdk_callback_invocation, |
| 37 | + sdk_public_api, |
| 38 | + sdk_public_api_async, |
| 39 | +) |
32 | 40 | from .schemas import AnyLMStudioStruct, DictObject |
33 | 41 | from .history import ( |
| 42 | + AssistantResponse, |
34 | 43 | Chat, |
35 | 44 | ChatHistoryDataDict, |
36 | 45 | FileHandle, |
37 | 46 | LocalFileInput, |
38 | 47 | _LocalFileData, |
| 48 | + ToolCallResultData, |
| 49 | + ToolCallRequest, |
| 50 | + ToolResultMessage, |
39 | 51 | ) |
40 | 52 | from .json_api import ( |
| 53 | + ActResult, |
41 | 54 | AnyLoadConfig, |
42 | 55 | AnyModelSpecifier, |
43 | 56 | AvailableModelBase, |
|
63 | 76 | LlmPredictionFragment, |
64 | 77 | LMStudioCancelledError, |
65 | 78 | LMStudioClientError, |
| 79 | + LMStudioPredictionError, |
66 | 80 | LMStudioWebsocket, |
67 | 81 | LMStudioWebsocketError, |
68 | 82 | LoadModelEndpoint, |
|
80 | 94 | PredictionFragmentEvent, |
81 | 95 | PredictionMessageCallback, |
82 | 96 | PredictionResult, |
| 97 | + PredictionRoundResult, |
| 98 | + PredictionRxEvent, |
| 99 | + PredictionToolCallEvent, |
83 | 100 | PromptProcessingCallback, |
84 | 101 | RemoteCallHandler, |
85 | 102 | ResponseSchema, |
86 | 103 | SendMessageAsync, |
87 | 104 | TModelInfo, |
| 105 | + ToolDefinition, |
88 | 106 | check_model_namespace, |
89 | 107 | load_struct, |
90 | 108 | _model_spec_to_api_dict, |
@@ -895,14 +913,18 @@ async def __aexit__( |
895 | 913 | await self.aclose() |
896 | 914 |
897 | 915 | async def __aiter__(self) -> AsyncIterator[LlmPredictionFragment]: |
| 916 | + async for event in self._iter_events(): |
| 917 | + if isinstance(event, PredictionFragmentEvent): |
| 918 | + yield event.arg |
| 919 | + |
| 920 | + async def _iter_events(self) -> AsyncIterator[PredictionRxEvent]: |
898 | 921 | endpoint = self._endpoint |
899 | 922 | async with self: |
900 | 923 | assert self._channel is not None |
901 | 924 | async for contents in self._channel.rx_stream(): |
902 | 925 | for event in endpoint.iter_message_events(contents): |
903 | 926 | endpoint.handle_rx_event(event) |
904 | | - if isinstance(event, PredictionFragmentEvent): |
905 | | - yield event.arg |
| 927 | + yield event |
906 | 928 | if endpoint.is_finished: |
907 | 929 | break |
908 | 930 | self._mark_finished() |
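The split between `__aiter__` and the new `_iter_events()` helper keeps the public iteration contract unchanged: plain `async for` consumers still see only `LlmPredictionFragment` values, while `.act()` can drive the richer event stream (including tool call events) through `_iter_events()`. A minimal consumption sketch, assuming an `AsyncPredictionStream` obtained elsewhere and assuming each fragment carries its generated text in a `content` attribute:

```python
async def print_streamed_response(stream) -> "PredictionResult":
    """Consume an AsyncPredictionStream, echoing fragments as they arrive.

    A sketch only: ``stream`` is assumed to be an ``AsyncPredictionStream``
    like the one shown above, and each yielded fragment is assumed to carry
    its text in a ``content`` attribute.
    """
    async for fragment in stream:
        # __aiter__ filters the event stream down to PredictionFragmentEvent
        # payloads, so only LlmPredictionFragment instances show up here.
        print(fragment.content, end="", flush=True)
    print()
    # The full PredictionResult is available once the stream is exhausted.
    return stream.result()
```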
@@ -1244,6 +1266,199 @@ async def respond( |
1244 | 1266 | pass |
1245 | 1267 | return prediction_stream.result() |
1246 | 1268 |
| 1269 | + # TODO: Improve code sharing between sync and async multi-round predictions |
| 1270 | + # TODO: Accept async tools in the tools iterable |
| 1271 | + @sdk_public_api_async() |
| 1272 | + async def act( |
| 1273 | + self, |
| 1274 | + chat: Chat | ChatHistoryDataDict | str, |
| 1275 | + tools: Iterable[ToolDefinition], |
| 1276 | + *, |
| 1277 | + max_prediction_rounds: int | None = None, |
| 1278 | + max_parallel_tool_calls: int | None = 1, |
| 1279 | + config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, |
| 1280 | + preset: str | None = None, |
| 1281 | + on_message: Callable[[AssistantResponse | ToolResultMessage], Any] |
| 1282 | + | None = None, |
| 1283 | + on_first_token: Callable[[int], Any] | None = None, |
| 1284 | + on_prediction_fragment: Callable[[LlmPredictionFragment, int], Any] |
| 1285 | + | None = None, |
| 1286 | + on_round_start: Callable[[int], Any] | None = None, |
| 1287 | + on_round_end: Callable[[int], Any] | None = None, |
| 1288 | + on_prediction_completed: Callable[[PredictionRoundResult], Any] | None = None, |
| 1289 | + on_prompt_processing_progress: Callable[[float, int], Any] | None = None, |
| 1290 | + handle_invalid_tool_request: Callable[ |
| 1291 | + [LMStudioPredictionError, ToolCallRequest | None], str | None |
| 1292 | + ] |
| 1293 | + | None = None, |
| 1294 | + ) -> ActResult: |
| 1295 | + """Request a response (with implicit tool use) in an ongoing agent chat session. |
| 1296 | +
| 1297 | + Note: details of configuration fields may change in SDK feature releases. |
| 1298 | + """ |
| 1299 | + start_time = time.perf_counter() |
| 1300 | + # It is not yet possible to combine tool calling with requests for structured responses |
| 1301 | + response_format = None |
| 1302 | + agent_chat: Chat = Chat.from_history(chat) |
| 1303 | + del chat # Avoid any further access to the input chat history |
| 1304 | + # Multiple rounds, until all tool calls are resolved or limit is reached |
| 1305 | + round_counter: Iterable[int] |
| 1306 | + if max_prediction_rounds is not None: |
| 1307 | + if max_prediction_rounds < 1: |
| 1308 | + raise LMStudioValueError( |
| 1309 | + f"Max prediction rounds must be at least 1 ({max_prediction_rounds!r} given)" |
| 1310 | + ) |
| 1311 | + round_counter = range(max_prediction_rounds) |
| 1312 | + final_round_index = max_prediction_rounds - 1 |
| 1313 | + else: |
| 1314 | + # Do not force a final round when no limit is specified |
| 1315 | + final_round_index = -1 |
| 1316 | + round_counter = itertools.count() |
| 1317 | + llm_tool_args = ChatResponseEndpoint.parse_tools(tools) |
| 1318 | + del tools |
| 1319 | + # Supply the round index to any endpoint callbacks that expect one |
| 1320 | + round_index: int |
| 1321 | + on_first_token_for_endpoint: PredictionFirstTokenCallback | None = None |
| 1322 | + if on_first_token is not None: |
| 1323 | + |
| 1324 | + def _wrapped_on_first_token() -> None: |
| 1325 | + assert on_first_token is not None |
| 1326 | + on_first_token(round_index) |
| 1327 | + |
| 1328 | + on_first_token_for_endpoint = _wrapped_on_first_token |
| 1329 | + on_prediction_fragment_for_endpoint: PredictionFragmentCallback | None = None |
| 1330 | + if on_prediction_fragment is not None: |
| 1331 | + |
| 1332 | + def _wrapped_on_prediction_fragment( |
| 1333 | + fragment: LlmPredictionFragment, |
| 1334 | + ) -> None: |
| 1335 | + assert on_prediction_fragment is not None |
| 1336 | + on_prediction_fragment(fragment, round_index) |
| 1337 | + |
| 1338 | + on_prediction_fragment_for_endpoint = _wrapped_on_prediction_fragment |
| 1339 | + on_prompt_processing_for_endpoint: PromptProcessingCallback | None = None |
| 1340 | + if on_prompt_processing_progress is not None: |
| 1341 | + |
| 1342 | + def _wrapped_on_prompt_processing_progress(progress: float) -> None: |
| 1343 | + assert on_prompt_processing_progress is not None |
| 1344 | + on_prompt_processing_progress(progress, round_index) |
| 1345 | + |
| 1346 | + on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress |
| 1347 | + # TODO: Implementation to this point is common between the sync and async APIs |
| 1348 | + # Implementation past this point differs (as the sync API uses its own thread pool) |
| 1349 | + |
| 1350 | + # Request predictions until no more tool call requests are received in response |
| 1351 | + # (or the maximum number of prediction rounds is reached) |
| 1352 | + for round_index in round_counter: |
| 1353 | + self._logger.debug( |
| 1354 | + "Starting .act() prediction round", round_index=round_index |
| 1355 | + ) |
| 1356 | + if on_round_start is not None: |
| 1357 | + err_msg = f"Round start callback failed for {self!r}" |
| 1358 | + with sdk_callback_invocation(err_msg, self._logger): |
| 1359 | + on_round_start(round_index) |
| 1360 | + # Update the endpoint definition on each iteration in order to: |
| 1361 | + # * update the chat history with the previous round result |
| 1362 | + # * be able to disallow tool use when the rounds are limited |
| 1363 | + # TODO: Refactor endpoint API to avoid repeatedly performing the |
| 1364 | + # LlmPredictionConfig -> KvConfigStack transformation |
| 1365 | + endpoint = ChatResponseEndpoint( |
| 1366 | + self.identifier, |
| 1367 | + agent_chat, |
| 1368 | + response_format, |
| 1369 | + config, |
| 1370 | + preset, |
| 1371 | + None, # on_message is invoked directly |
| 1372 | + on_first_token_for_endpoint, |
| 1373 | + on_prediction_fragment_for_endpoint, |
| 1374 | + on_prompt_processing_for_endpoint, |
| 1375 | + handle_invalid_tool_request, |
| 1376 | + *(llm_tool_args if round_index != final_round_index else (None, None)), |
| 1377 | + ) |
| 1378 | + channel_cm = self._session._create_channel(endpoint) |
| 1379 | + prediction_stream = AsyncPredictionStream(channel_cm, endpoint) |
| 1380 | + tool_call_requests: list[ToolCallRequest] = [] |
| 1381 | + parsed_tool_calls: list[Callable[[], ToolCallResultData]] = [] |
| 1382 | + async for event in prediction_stream._iter_events(): |
| 1383 | + if isinstance(event, PredictionToolCallEvent): |
| 1384 | + tool_call_request = event.arg |
| 1385 | + tool_call_requests.append(tool_call_request) |
| 1386 | + # TODO: Also handle async tool calls here |
| 1387 | + tool_call = endpoint.request_tool_call(tool_call_request) |
| 1388 | + parsed_tool_calls.append(tool_call) |
| 1389 | + prediction = prediction_stream.result() |
| 1390 | + self._logger.debug( |
| 1391 | + "Completed .act() prediction round", round_index=round_index |
| 1392 | + ) |
| 1393 | + if on_prediction_completed: |
| 1394 | + round_result = PredictionRoundResult.from_result( |
| 1395 | + prediction, round_index |
| 1396 | + ) |
| 1397 | + err_msg = f"Prediction completed callback failed for {self!r}" |
| 1398 | + with sdk_callback_invocation(err_msg, self._logger): |
| 1399 | + on_prediction_completed(round_result) |
| 1400 | + if parsed_tool_calls: |
| 1401 | + if max_parallel_tool_calls is None: |
| 1402 | + max_parallel_tool_calls = len(parsed_tool_calls) |
| 1403 | + active_tool_calls = 0 |
| 1404 | + tool_call_futures: list[asyncio.Future[ToolCallResultData]] = [] |
| 1405 | + for tool_call in parsed_tool_calls: |
| 1406 | + if active_tool_calls >= max_parallel_tool_calls: |
| 1407 | + # Wait for a previous call to finish before starting another one |
| 1408 | + _done, pending = await asyncio.wait( |
| 1409 | + tool_call_futures, return_when=asyncio.FIRST_COMPLETED |
| 1410 | + ) |
| 1411 | + active_tool_calls = len(pending) |
| 1412 | + # TODO: Also handle async tool calls here |
| 1413 | + tool_call_futures.append( |
| 1414 | + asyncio.ensure_future(asyncio.to_thread(tool_call)) |
| 1415 | + ) |
| 1416 | + active_tool_calls += 1 |
| 1417 | + tool_call_results: list[ToolCallResultData] = [] |
| 1418 | + for tool_call_request, tool_call_future in zip( |
| 1419 | + tool_call_requests, tool_call_futures |
| 1420 | + ): |
| 1421 | + try: |
| 1422 | + await tool_call_future |
| 1423 | + except Exception as exc: |
| 1424 | + tool_call_result = endpoint._handle_failed_tool_request( |
| 1425 | + exc, tool_call_request |
| 1426 | + ) |
| 1427 | + else: |
| 1428 | + tool_call_result = tool_call_future.result() |
| 1429 | + tool_call_results.append(tool_call_result) |
| 1430 | + requests_message = agent_chat.add_assistant_response( |
| 1431 | + prediction, tool_call_requests |
| 1432 | + ) |
| 1433 | + results_message = agent_chat.add_tool_results(tool_call_results) |
| 1434 | + if on_message is not None: |
| 1435 | + err_msg = f"Tool request message callback failed for {self!r}" |
| 1436 | + with sdk_callback_invocation(err_msg, self._logger): |
| 1437 | + on_message(requests_message) |
| 1438 | + err_msg = f"Tool result message callback failed for {self!r}" |
| 1439 | + with sdk_callback_invocation(err_msg, self._logger): |
| 1440 | + on_message(results_message) |
| 1441 | + elif on_message is not None: |
| 1442 | + err_msg = f"Final response message callback failed for {self!r}" |
| 1443 | + with sdk_callback_invocation(err_msg, self._logger): |
| 1444 | + on_message(agent_chat.add_assistant_response(prediction)) |
| 1445 | + if on_round_end is not None: |
| 1446 | + err_msg = f"Round end callback failed for {self!r}" |
| 1447 | + with sdk_callback_invocation(err_msg, self._logger): |
| 1448 | + on_round_end(round_index) |
| 1449 | + if not tool_call_requests: |
| 1450 | + # No tool call requests -> we're done here |
| 1451 | + break |
| 1452 | + if round_index == final_round_index: |
| 1453 | + # We somehow received at least one tool call request, |
| 1454 | + # even though tools are omitted on the final round |
| 1455 | + err_msg = "Model requested tool use on final prediction round." |
| 1456 | + endpoint._handle_invalid_tool_request(err_msg) |
| 1457 | + break |
| 1458 | + num_rounds = round_index + 1 |
| 1459 | + duration = time.perf_counter() - start_time |
| 1460 | + return ActResult(rounds=num_rounds, total_time_seconds=duration) |
| 1461 | + |
1247 | 1462 | @sdk_public_api_async() |
1248 | 1463 | async def apply_prompt_template( |
1249 | 1464 | self, |
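For reference, a minimal, hypothetical sketch of driving the new asynchronous `.act()` entry point. The `model` parameter stands in for whatever async model handle exposes the method added above (obtaining one is outside the scope of this diff), and passing a plain annotated callable as a tool definition is an assumption based on the `ToolDefinition` handling in `ChatResponseEndpoint.parse_tools()`; the keyword arguments and the `ActResult` fields (`rounds`, `total_time_seconds`) are taken directly from the code in this commit:

```python
def multiply(a: float, b: float) -> float:
    """Multiply two numbers."""
    # Assumption: the name, annotations, and docstring of a plain callable
    # are enough for parse_tools() to build the tool schema.
    return a * b


async def run_agent(model) -> None:
    """Drive a short multi-round agent interaction via the act() API above.

    ``model`` is assumed to be an async model handle exposing ``act()``.
    """
    result = await model.act(
        "What is 12345 multiplied by 54321?",  # a plain str chat is accepted
        [multiply],
        max_prediction_rounds=3,     # tools are withheld on the final round
        max_parallel_tool_calls=1,   # the default: run tool calls one at a time
        on_message=print,            # assistant responses and tool result messages
        on_round_start=lambda round_index: print(f"--- round {round_index} ---"),
        on_round_end=lambda round_index: print(f"--- round {round_index} done ---"),
    )
    print(f"{result.rounds} round(s) in {result.total_time_seconds:.2f}s")


# Once an async model handle is available:  asyncio.run(run_agent(model))
```

On the concurrency design: with the default `max_parallel_tool_calls=1`, tool calls within a round run sequentially; a larger value (or `None`, which lifts the cap to the number of requests in that round) lets them overlap, with `asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED)` used in the loop above to keep at most that many calls in flight at once.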