Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ markers = [
# Warnings should only be emitted when being specifically tested
filterwarnings = [
"error",
"ignore:.*the async API is not yet stable:FutureWarning"
]
# Capture log info from network client libraries
log_format = "%(asctime)s %(levelname)s %(message)s"
Expand Down
13 changes: 11 additions & 2 deletions src/lmstudio/_ws_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

import asyncio

from concurrent.futures import Future as SyncFuture

# Python 3.10 compatibility: use concurrent.futures.TimeoutError instead of the builtin
# In 3.11+, these are the same type; in 3.10, futures have their own timeout exception
from concurrent.futures import Future as SyncFuture, TimeoutError as SyncFutureTimeout
from contextlib import AsyncExitStack, contextmanager
from functools import partial
from typing import (
Expand Down Expand Up @@ -47,6 +50,12 @@
# and omits the generalised features that the SDK doesn't need)
T = TypeVar("T")

__all__ = [
"SyncFutureTimeout",
"AsyncTaskManager",
"AsyncWebsocketHandler",
]


class AsyncTaskManager:
def __init__(self, *, on_activation: Callable[[], Any] | None = None) -> None:
Expand Down Expand Up @@ -429,7 +438,7 @@ def _rx_queue_get_threadsafe(self, rx_queue: RxQueue, timeout: float | None) ->
future = self._task_manager.run_coroutine_threadsafe(rx_queue.get())
try:
return future.result(timeout)
except TimeoutError:
except SyncFutureTimeout:
future.cancel()
raise

Expand Down
90 changes: 41 additions & 49 deletions src/lmstudio/async_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Async I/O protocol implementation for the LM Studio remote access API."""

import asyncio
import warnings

from abc import abstractmethod
from contextlib import AsyncExitStack, asynccontextmanager
Expand Down Expand Up @@ -109,8 +108,8 @@
# and similar tasks is published from `json_api`.
# Bypassing the high level API, and working more
# directly with the underlying websocket(s) is
# supported (hence the public names), but they're
# not exported via the top-level `lmstudio` API.
# not supported due to the complexity of the task
# management details (hence the private names).
__all__ = [
"AnyAsyncDownloadedModel",
"AsyncClient",
Expand Down Expand Up @@ -215,7 +214,7 @@ async def receive_result(self) -> Any:
return self._rpc.handle_rx_message(message)


class AsyncLMStudioWebsocket(LMStudioWebsocket[AsyncWebSocketSession]):
class _AsyncLMStudioWebsocket(LMStudioWebsocket[AsyncWebSocketSession]):
"""Asynchronous websocket client that handles demultiplexing of reply messages."""

def __init__(
Expand Down Expand Up @@ -331,7 +330,7 @@ async def remote_call(
return await rpc.receive_result()


class AsyncSession(ClientSession["AsyncClient", AsyncLMStudioWebsocket]):
class _AsyncSession(ClientSession["AsyncClient", _AsyncLMStudioWebsocket]):
"""Async client session interfaces applicable to all API namespaces."""

def __init__(self, client: "AsyncClient") -> None:
Expand All @@ -354,7 +353,7 @@ async def __aexit__(self, *args: Any) -> None:
await self.disconnect()

@sdk_public_api_async()
async def connect(self) -> AsyncLMStudioWebsocket:
async def connect(self) -> _AsyncLMStudioWebsocket:
"""Connect the client session."""
self._fail_if_connected("Attempted to connect already connected session")
api_host = self._client.api_host
Expand All @@ -367,7 +366,7 @@ async def connect(self) -> AsyncLMStudioWebsocket:
resources = self._resource_manager
client = self._client
self._lmsws = lmsws = await resources.enter_async_context(
AsyncLMStudioWebsocket(
_AsyncLMStudioWebsocket(
client._task_manager, session_url, client._auth_details
)
)
Expand Down Expand Up @@ -411,7 +410,7 @@ async def remote_call(


TAsyncSessionModel = TypeVar(
"TAsyncSessionModel", bound="AsyncSessionModel[Any, Any, Any, Any]"
"TAsyncSessionModel", bound="_AsyncSessionModel[Any, Any, Any, Any]"
)
TAsyncModelHandle = TypeVar("TAsyncModelHandle", bound="AsyncModelHandle[Any]")

Expand Down Expand Up @@ -467,7 +466,7 @@ async def model(
class AsyncDownloadedEmbeddingModel(
AsyncDownloadedModel[
EmbeddingModelInfo,
"AsyncSessionEmbedding",
"_AsyncSessionEmbedding",
EmbeddingLoadModelConfig,
EmbeddingLoadModelConfigDict,
"AsyncEmbeddingModel",
Expand All @@ -476,7 +475,7 @@ class AsyncDownloadedEmbeddingModel(
"""Asynchronous download listing for an embedding model."""

def __init__(
self, model_info: DictObject, session: "AsyncSessionEmbedding"
self, model_info: DictObject, session: "_AsyncSessionEmbedding"
) -> None:
"""Initialize downloaded embedding model details."""
super().__init__(EmbeddingModelInfo, model_info, session)
Expand All @@ -485,23 +484,23 @@ def __init__(
class AsyncDownloadedLlm(
AsyncDownloadedModel[
LlmInfo,
"AsyncSessionLlm",
"_AsyncSessionLlm",
LlmLoadModelConfig,
LlmLoadModelConfigDict,
"AsyncLLM",
]
):
"""Asynchronous download listing for an LLM."""

def __init__(self, model_info: DictObject, session: "AsyncSessionLlm") -> None:
def __init__(self, model_info: DictObject, session: "_AsyncSessionLlm") -> None:
"""Initialize downloaded LLM details."""
super().__init__(LlmInfo, model_info, session)


AnyAsyncDownloadedModel: TypeAlias = AsyncDownloadedModel[Any, Any, Any, Any, Any]


class AsyncSessionSystem(AsyncSession):
class _AsyncSessionSystem(_AsyncSession):
"""Async client session for the system namespace."""

API_NAMESPACE = "system"
Expand Down Expand Up @@ -531,7 +530,7 @@ def _process_download_listing(
)


class _AsyncSessionFiles(AsyncSession):
class _AsyncSessionFiles(_AsyncSession):
"""Async client session for the files namespace."""

API_NAMESPACE = "files"
Expand Down Expand Up @@ -562,7 +561,7 @@ async def prepare_image(
return await self._fetch_file_handle(file_data)


class AsyncModelDownloadOption(ModelDownloadOptionBase[AsyncSession]):
class AsyncModelDownloadOption(ModelDownloadOptionBase[_AsyncSession]):
"""A single download option for a model search result."""

@sdk_public_api_async()
Expand All @@ -577,10 +576,10 @@ async def download(
return await channel.wait_for_result()


class AsyncAvailableModel(AvailableModelBase[AsyncSession]):
class AsyncAvailableModel(AvailableModelBase[_AsyncSession]):
"""A model available for download from the model repository."""

_session: AsyncSession
_session: _AsyncSession

@sdk_public_api_async()
async def get_download_options(
Expand All @@ -595,7 +594,7 @@ async def get_download_options(
return final


class AsyncSessionRepository(AsyncSession):
class _AsyncSessionRepository(_AsyncSession):
"""Async client session for the repository namespace."""

API_NAMESPACE = "repository"
Expand All @@ -616,8 +615,8 @@ async def search_models(
TAsyncDownloadedModel = TypeVar("TAsyncDownloadedModel", bound=AnyAsyncDownloadedModel)


class AsyncSessionModel(
AsyncSession,
class _AsyncSessionModel(
_AsyncSession,
Generic[
TAsyncModelHandle,
TLoadConfig,
Expand All @@ -630,7 +629,7 @@ class AsyncSessionModel(
_API_TYPES: Type[ModelSessionTypes[TLoadConfig]]

@property
def _system_session(self) -> AsyncSessionSystem:
def _system_session(self) -> _AsyncSessionSystem:
return self._client.system

@property
Expand Down Expand Up @@ -922,8 +921,8 @@ async def cancel(self) -> None:
await self._channel.cancel()


class AsyncSessionLlm(
AsyncSessionModel[
class _AsyncSessionLlm(
_AsyncSessionModel[
"AsyncLLM",
LlmLoadModelConfig,
LlmLoadModelConfigDict,
Expand Down Expand Up @@ -1028,8 +1027,8 @@ async def _apply_prompt_template(
return response.get("formatted", "") if response else ""


class AsyncSessionEmbedding(
AsyncSessionModel[
class _AsyncSessionEmbedding(
_AsyncSessionModel[
"AsyncEmbeddingModel",
EmbeddingLoadModelConfig,
EmbeddingLoadModelConfigDict,
Expand Down Expand Up @@ -1115,7 +1114,7 @@ async def get_context_length(self) -> int:
AnyAsyncModel: TypeAlias = AsyncModelHandle[Any]


class AsyncLLM(AsyncModelHandle[AsyncSessionLlm]):
class AsyncLLM(AsyncModelHandle[_AsyncSessionLlm]):
"""Reference to a loaded LLM model."""

@sdk_public_api_async()
Expand Down Expand Up @@ -1258,7 +1257,7 @@ async def apply_prompt_template(
)


class AsyncEmbeddingModel(AsyncModelHandle[AsyncSessionEmbedding]):
class AsyncEmbeddingModel(AsyncModelHandle[_AsyncSessionEmbedding]):
"""Reference to a loaded embedding model."""

# Alas, type hints don't properly support distinguishing str vs Iterable[str]:
Expand All @@ -1271,24 +1270,17 @@ async def embed(
return await self._session._embed(self.identifier, input)


TAsyncSession = TypeVar("TAsyncSession", bound=AsyncSession)

_ASYNC_API_STABILITY_WARNING = """\
Note the async API is not yet stable and is expected to change in future releases
"""
TAsyncSession = TypeVar("TAsyncSession", bound=_AsyncSession)


class AsyncClient(ClientBase):
"""Async SDK client interface."""

def __init__(self, api_host: str | None = None) -> None:
"""Initialize API client."""
# Warn about the async API stability, since we expect it to change
# (in particular, accepting coroutine functions as callbacks)
warnings.warn(_ASYNC_API_STABILITY_WARNING, FutureWarning)
super().__init__(api_host)
self._resources = AsyncExitStack()
self._sessions: dict[str, AsyncSession] = {}
self._sessions: dict[str, _AsyncSession] = {}
self._task_manager = AsyncTaskManager()
# Unlike the sync API, we don't support GC-based resource
# management in the async API. Structured concurrency
Expand All @@ -1301,12 +1293,12 @@ def __init__(self, api_host: str | None = None) -> None:
# TODO: revisit lazy connections given the task manager implementation
# (for example, eagerly start tasks for all sessions, and lazily
# trigger events that allow them to initiate their connection)
_ALL_SESSIONS: tuple[Type[AsyncSession], ...] = (
AsyncSessionEmbedding,
_ALL_SESSIONS: tuple[Type[_AsyncSession], ...] = (
_AsyncSessionEmbedding,
_AsyncSessionFiles,
AsyncSessionLlm,
AsyncSessionRepository,
AsyncSessionSystem,
_AsyncSessionLlm,
_AsyncSessionRepository,
_AsyncSessionSystem,
)

async def __aenter__(self) -> Self:
Expand Down Expand Up @@ -1342,30 +1334,30 @@ def _get_session(self, cls: Type[TAsyncSession]) -> TAsyncSession:

@property
@sdk_public_api()
def llm(self) -> AsyncSessionLlm:
def llm(self) -> _AsyncSessionLlm:
"""Return the LLM API client session."""
return self._get_session(AsyncSessionLlm)
return self._get_session(_AsyncSessionLlm)

@property
@sdk_public_api()
def embedding(self) -> AsyncSessionEmbedding:
def embedding(self) -> _AsyncSessionEmbedding:
"""Return the embedding model API client session."""
return self._get_session(AsyncSessionEmbedding)
return self._get_session(_AsyncSessionEmbedding)

@property
def system(self) -> AsyncSessionSystem:
def system(self) -> _AsyncSessionSystem:
"""Return the system API client session."""
return self._get_session(AsyncSessionSystem)
return self._get_session(_AsyncSessionSystem)

@property
def files(self) -> _AsyncSessionFiles:
"""Return the files API client session."""
return self._get_session(_AsyncSessionFiles)

@property
def repository(self) -> AsyncSessionRepository:
def repository(self) -> _AsyncSessionRepository:
"""Return the repository API client session."""
return self._get_session(AsyncSessionRepository)
return self._get_session(_AsyncSessionRepository)

# Convenience methods
# Not yet implemented (server API only supports the same file types as prepare_image)
Expand Down
3 changes: 0 additions & 3 deletions src/lmstudio/plugin/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ def main(argv: Sequence[str] | None = None) -> int:
warnings.filterwarnings(
"ignore", ".*the plugin API is not yet stable", FutureWarning
)
warnings.filterwarnings(
"ignore", ".*the async API is not yet stable", FutureWarning
)
log_level = logging.DEBUG if args.debug else logging.INFO
logging.basicConfig(level=log_level)
if not args.dev:
Expand Down
8 changes: 4 additions & 4 deletions src/lmstudio/plugin/hooks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from anyio import move_on_after

from ...async_api import AsyncSession
from ...async_api import _AsyncSession
from ...schemas import DictObject
from ..._sdk_models import (
# TODO: Define aliases at schema generation time
Expand All @@ -32,13 +32,13 @@

# Available as lmstudio.plugin.hooks.*
__all__ = [
"AsyncSessionPlugins",
"_AsyncSessionPlugins",
"TPluginConfigSchema",
"TGlobalConfigSchema",
]


class AsyncSessionPlugins(AsyncSession):
class _AsyncSessionPlugins(_AsyncSession):
"""Async client session for the plugins namespace."""

API_NAMESPACE = "plugins"
Expand All @@ -63,7 +63,7 @@ class HookController(Generic[TPluginRequest, TPluginConfigSchema, TGlobalConfigS

def __init__(
self,
session: AsyncSessionPlugins,
session: _AsyncSessionPlugins,
request: TPluginRequest,
plugin_config_schema: type[TPluginConfigSchema],
global_config_schema: type[TGlobalConfigSchema],
Expand Down
Loading