221 changes: 218 additions & 3 deletions src/lmstudio/async_api.py
@@ -1,6 +1,8 @@
"""Async I/O protocol implementation for the LM Studio remote access API."""

import asyncio
import itertools
import time

from abc import abstractmethod
from contextlib import AsyncExitStack, asynccontextmanager
@@ -28,16 +30,27 @@

from httpx_ws import AsyncWebSocketSession

from .sdk_api import LMStudioRuntimeError, sdk_public_api, sdk_public_api_async
from .sdk_api import (
LMStudioRuntimeError,
LMStudioValueError,
sdk_callback_invocation,
sdk_public_api,
sdk_public_api_async,
)
from .schemas import AnyLMStudioStruct, DictObject
from .history import (
AssistantResponse,
Chat,
ChatHistoryDataDict,
FileHandle,
LocalFileInput,
_LocalFileData,
ToolCallResultData,
ToolCallRequest,
ToolResultMessage,
)
from .json_api import (
ActResult,
AnyLoadConfig,
AnyModelSpecifier,
AvailableModelBase,
@@ -63,6 +76,7 @@
LlmPredictionFragment,
LMStudioCancelledError,
LMStudioClientError,
LMStudioPredictionError,
LMStudioWebsocket,
LMStudioWebsocketError,
LoadModelEndpoint,
@@ -80,11 +94,15 @@
PredictionFragmentEvent,
PredictionMessageCallback,
PredictionResult,
PredictionRoundResult,
PredictionRxEvent,
PredictionToolCallEvent,
PromptProcessingCallback,
RemoteCallHandler,
ResponseSchema,
SendMessageAsync,
TModelInfo,
ToolDefinition,
check_model_namespace,
load_struct,
_model_spec_to_api_dict,
@@ -895,14 +913,18 @@ async def __aexit__(
await self.aclose()

async def __aiter__(self) -> AsyncIterator[LlmPredictionFragment]:
async for event in self._iter_events():
if isinstance(event, PredictionFragmentEvent):
yield event.arg

async def _iter_events(self) -> AsyncIterator[PredictionRxEvent]:
endpoint = self._endpoint
async with self:
assert self._channel is not None
async for contents in self._channel.rx_stream():
for event in endpoint.iter_message_events(contents):
endpoint.handle_rx_event(event)
if isinstance(event, PredictionFragmentEvent):
yield event.arg
yield event
if endpoint.is_finished:
break
self._mark_finished()
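Note: __aiter__() now delegates to the new _iter_events() helper, which yields every received event rather than only text fragments, so internal callers (such as act() below) can also observe tool call events. A minimal consumer sketch, assuming stream is an AsyncPredictionStream; handle_tool_call is a hypothetical helper, not part of this diff:

    # Sketch only: separates fragment events from tool call events
    async for event in stream._iter_events():
        if isinstance(event, PredictionFragmentEvent):
            print(event.arg.content, end="")  # streamed text fragment
        elif isinstance(event, PredictionToolCallEvent):
            handle_tool_call(event.arg)  # ToolCallRequest from the model
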
@@ -1244,6 +1266,199 @@ async def respond(
pass
return prediction_stream.result()

# TODO: Improve code sharing between sync and async multi-round predictions
# TODO: Accept async tools in the tools iterable
@sdk_public_api_async()
async def act(
self,
chat: Chat | ChatHistoryDataDict | str,
tools: Iterable[ToolDefinition],
*,
max_prediction_rounds: int | None = None,
max_parallel_tool_calls: int | None = 1,
config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
preset: str | None = None,
on_message: Callable[[AssistantResponse | ToolResultMessage], Any]
| None = None,
on_first_token: Callable[[int], Any] | None = None,
on_prediction_fragment: Callable[[LlmPredictionFragment, int], Any]
| None = None,
on_round_start: Callable[[int], Any] | None = None,
on_round_end: Callable[[int], Any] | None = None,
on_prediction_completed: Callable[[PredictionRoundResult], Any] | None = None,
on_prompt_processing_progress: Callable[[float, int], Any] | None = None,
handle_invalid_tool_request: Callable[
[LMStudioPredictionError, ToolCallRequest | None], str | None
]
| None = None,
) -> ActResult:
"""Request a response (with implicit tool use) in an ongoing agent chat session.

Note: details of configuration fields may change in SDK feature releases.
"""
start_time = time.perf_counter()
# It is not yet possible to combine tool calling with requests for structured responses
response_format = None
agent_chat: Chat = Chat.from_history(chat)
del chat # Avoid any further access to the input chat history
# Multiple rounds, until all tool calls are resolved or limit is reached
round_counter: Iterable[int]
if max_prediction_rounds is not None:
if max_prediction_rounds < 1:
raise LMStudioValueError(
f"Max prediction rounds must be at least 1 ({max_prediction_rounds!r} given)"
)
round_counter = range(max_prediction_rounds)
final_round_index = max_prediction_rounds - 1
else:
# Do not force a final round when no limit is specified
final_round_index = -1
round_counter = itertools.count()
llm_tool_args = ChatResponseEndpoint.parse_tools(tools)
del tools
# Supply the round index to any endpoint callbacks that expect one
round_index: int
on_first_token_for_endpoint: PredictionFirstTokenCallback | None = None
if on_first_token is not None:

def _wrapped_on_first_token() -> None:
assert on_first_token is not None
on_first_token(round_index)

on_first_token_for_endpoint = _wrapped_on_first_token
on_prediction_fragment_for_endpoint: PredictionFragmentCallback | None = None
if on_prediction_fragment is not None:

def _wrapped_on_prediction_fragment(
fragment: LlmPredictionFragment,
) -> None:
assert on_prediction_fragment is not None
on_prediction_fragment(fragment, round_index)

on_prediction_fragment_for_endpoint = _wrapped_on_prediction_fragment
on_prompt_processing_for_endpoint: PromptProcessingCallback | None = None
if on_prompt_processing_progress is not None:

def _wrapped_on_prompt_processing_progress(progress: float) -> None:
assert on_prompt_processing_progress is not None
on_prompt_processing_progress(progress, round_index)

on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress
# TODO: Implementation to this point is common between the sync and async APIs
# Implementation past this point differs (as the sync API uses its own thread pool)

# Request predictions until no more tool call requests are received in response
# (or the maximum number of prediction rounds is reached)
for round_index in round_counter:
self._logger.debug(
"Starting .act() prediction round", round_index=round_index
)
if on_round_start is not None:
err_msg = f"Round start callback failed for {self!r}"
with sdk_callback_invocation(err_msg, self._logger):
on_round_start(round_index)
# Update the endpoint definition on each iteration in order to:
# * update the chat history with the previous round result
# * be able to disallow tool use when the rounds are limited
# TODO: Refactor endpoint API to avoid repeatedly performing the
# LlmPredictionConfig -> KvConfigStack transformation
endpoint = ChatResponseEndpoint(
self.identifier,
agent_chat,
response_format,
config,
preset,
None, # on_message is invoked directly
on_first_token_for_endpoint,
on_prediction_fragment_for_endpoint,
on_prompt_processing_for_endpoint,
handle_invalid_tool_request,
*(llm_tool_args if round_index != final_round_index else (None, None)),
)
channel_cm = self._session._create_channel(endpoint)
prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
tool_call_requests: list[ToolCallRequest] = []
parsed_tool_calls: list[Callable[[], ToolCallResultData]] = []
async for event in prediction_stream._iter_events():
if isinstance(event, PredictionToolCallEvent):
tool_call_request = event.arg
tool_call_requests.append(tool_call_request)
# TODO: Also handle async tool calls here
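# request_tool_call returns a no-argument callable that runs the tool
# and produces a ToolCallResultData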
tool_call = endpoint.request_tool_call(tool_call_request)
parsed_tool_calls.append(tool_call)
prediction = prediction_stream.result()
self._logger.debug(
"Completed .act() prediction round", round_index=round_index
)
if on_prediction_completed:
round_result = PredictionRoundResult.from_result(
prediction, round_index
)
err_msg = f"Prediction completed callback failed for {self!r}"
with sdk_callback_invocation(err_msg, self._logger):
on_prediction_completed(round_result)
if parsed_tool_calls:
if max_parallel_tool_calls is None:
max_parallel_tool_calls = len(parsed_tool_calls)
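# Throttle concurrency: run at most max_parallel_tool_calls tool calls
# at a time, waiting for one to finish before launching the next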
active_tool_calls = 0
tool_call_futures: list[asyncio.Future[ToolCallResultData]] = []
for tool_call in parsed_tool_calls:
if active_tool_calls >= max_parallel_tool_calls:
# Wait for a previous call to finish before starting another one
_done, pending = await asyncio.wait(
tool_call_futures, return_when=asyncio.FIRST_COMPLETED
)
active_tool_calls = len(pending)
# TODO: Also handle async tool calls here
tool_call_futures.append(
asyncio.ensure_future(asyncio.to_thread(tool_call))
)
active_tool_calls += 1
tool_call_results: list[ToolCallResultData] = []
for tool_call_request, tool_call_future in zip(
tool_call_requests, tool_call_futures
):
try:
await tool_call_future
except Exception as exc:
tool_call_result = endpoint._handle_failed_tool_request(
exc, tool_call_request
)
else:
tool_call_result = tool_call_future.result()
tool_call_results.append(tool_call_result)
requests_message = agent_chat.add_assistant_response(
prediction, tool_call_requests
)
results_message = agent_chat.add_tool_results(tool_call_results)
if on_message is not None:
err_msg = f"Tool request message callback failed for {self!r}"
with sdk_callback_invocation(err_msg, self._logger):
on_message(requests_message)
err_msg = f"Tool result message callback failed for {self!r}"
with sdk_callback_invocation(err_msg, self._logger):
on_message(results_message)
elif on_message is not None:
err_msg = f"Final response message callback failed for {self!r}"
with sdk_callback_invocation(err_msg, self._logger):
on_message(agent_chat.add_assistant_response(prediction))
if on_round_end is not None:
err_msg = f"Round end callback failed for {self!r}"
with sdk_callback_invocation(err_msg, self._logger):
on_round_end(round_index)
if not tool_call_requests:
# No tool call requests -> we're done here
break
if round_index == final_round_index:
# We somehow received at least one tool call request,
# even though tools are omitted on the final round
err_msg = "Model requested tool use on final prediction round."
endpoint._handle_invalid_tool_request(err_msg)
break
num_rounds = round_index + 1
duration = time.perf_counter() - start_time
return ActResult(rounds=num_rounds, total_time_seconds=duration)
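
For context, a minimal usage sketch of the new async act() API. The client access path, model identifier, and the add tool below are illustrative assumptions for review purposes, not part of this diff:

    import asyncio

    import lmstudio as lms

    def add(a: int, b: int) -> int:
        """Add two integers."""
        return a + b

    async def main() -> None:
        # Assumes the async client exposes this session access path
        async with lms.AsyncClient() as client:
            model = await client.llm.model("qwen2.5-7b-instruct")  # hypothetical id
            chat = lms.Chat("You are a task-focused AI assistant.")
            chat.add_user_message("What is 17 + 25? Use the add tool.")
            result = await model.act(
                chat,
                [add],  # plain callables are parsed into tool definitions
                max_prediction_rounds=3,
                on_message=chat.append,  # keep the local history in sync
            )
            print(f"{result.rounds} round(s) in {result.total_time_seconds:.2f}s")

    asyncio.run(main())

The ActResult fields match the value constructed above: rounds counts completed prediction rounds, and total_time_seconds is measured with time.perf_counter().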

@sdk_public_api_async()
async def apply_prompt_template(
self,
11 changes: 7 additions & 4 deletions src/lmstudio/sync_api.py
@@ -43,13 +43,13 @@
from .schemas import AnyLMStudioStruct, DictObject
from .history import (
AssistantResponse,
ToolResultMessage,
Chat,
ChatHistoryDataDict,
FileHandle,
LocalFileInput,
_LocalFileData,
ToolCallRequest,
ToolResultMessage,
)
from .json_api import (
ActResult,
@@ -1316,9 +1316,8 @@ def respond(
pass
return prediction_stream.result()

# Multi-round predictions are currently a sync-only handle-only feature
# TODO: Refactor to allow for more code sharing with the async API
# with defined aliases for the expected callback signatures
# TODO: Improve code sharing between sync and async multi-round predictions
# TODO: Accept async tools in the tools iterable
@sdk_public_api()
def act(
self,
@@ -1395,6 +1394,9 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
on_prompt_processing_progress(progress, round_index)

on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress
# TODO: Implementation to this point is common between the sync and async APIs
# Implementation past this point differs (as the async API uses the loop's executor)

# Request predictions until no more tool call requests are received in response
# (or the maximum number of prediction rounds is reached)
with ThreadPoolExecutor(max_parallel_tool_calls) as pool:
@@ -1436,6 +1438,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
if isinstance(event, PredictionToolCallEvent):
tool_call_request = event.arg
tool_call_requests.append(tool_call_request)
# TODO: Also handle async tool calls here
tool_call = endpoint.request_tool_call(tool_call_request)
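# Map each submitted future back to its originating request so failed
# tool calls can be attributed when collecting results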
pending_tool_calls[pool.submit(tool_call)] = tool_call_request
prediction = prediction_stream.result()