Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion sdk-schema/sync-sdk-schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
PythonVersion,
)

_THIS_DIR = Path(__file__).parent

_THIS_FILE = Path(__file__)
_THIS_DIR = _THIS_FILE.parent
_LMSJS_DIR = _THIS_DIR / "lmstudio-js"
_EXPORTER_DIR = _LMSJS_DIR / "packages/lms-json-schema"
_SCHEMA_DIR = _EXPORTER_DIR / "schemas"
Expand All @@ -51,6 +53,12 @@
_TEMPLATE_DIR = _THIS_DIR / "_templates"
_MODEL_DIR = _THIS_DIR.parent / "src/lmstudio/_sdk_models"
_MODEL_PATH = _MODEL_DIR / "__init__.py"
_LMSJS_PORTS_PATH = _LMSJS_DIR / "packages/lms-common/src/apiServerPorts.ts"
_PY_PORTS_PATH = _THIS_DIR.parent / "src/lmstudio/_api_server_ports.py"

GENERATED_SOURCE_HEADER = f"""\
# Automatically generated by {_THIS_FILE.name}. DO NOT EDIT THIS FILE!
""".splitlines()

# The following schemas are not actually used anywhere,
# so they're excluded to avoid any conflicts with automatically
Expand Down Expand Up @@ -630,13 +638,40 @@ def _generate_data_model_from_json_schema() -> None:
if line.startswith("class"):
break
updated_source_lines[idx:idx] = lines_to_insert
# Insert auto-generated code header
updated_source_lines[0:0] = GENERATED_SOURCE_HEADER
_MODEL_PATH.write_text("\n".join(updated_source_lines) + "\n")


def _sync_default_port_list() -> None:
"""Copy the list of default ports to check for the local API server."""
print("Extracting default port list...")
print(f" Reading {_LMSJS_PORTS_PATH}")
lmsjs_source = _LMSJS_PORTS_PATH.read_text()
START_PORTS = "apiServerPorts = ["
END_PORTS = "];"
_, found, remaining_text = lmsjs_source.partition(START_PORTS)
if not found:
raise RuntimeError(f"Failed to find {START_PORTS} in {lmsjs_source}")
ports_text, suffix_found, _ = remaining_text.partition(END_PORTS)
if not suffix_found:
raise RuntimeError(f"Failed to find {END_PORTS} in {remaining_text}")
default_ports = [*map(int, ports_text.split(","))]
if not default_ports:
raise RuntimeError("Failed to extract any default ports")
py_source_lines = [
*GENERATED_SOURCE_HEADER,
f"default_api_ports = ({','.join(map(str, default_ports))})",
]
print(f" Writing {_PY_PORTS_PATH}")
_PY_PORTS_PATH.write_text("\n".join(py_source_lines) + "\n")


def _main() -> None:
if sys.argv[1:] == ["--regen-schema"] or not _SCHEMA_PATH.exists():
_export_zod_schemas_to_json_schema()
_generate_data_model_from_json_schema()
_sync_default_port_list()
print("Running automatic formatter after data model code generation")
subprocess.run(["tox", "-e", "format"])

Expand Down
2 changes: 2 additions & 0 deletions src/lmstudio/_api_server_ports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Automatically generated by sync-sdk-schema.py. DO NOT EDIT THIS FILE!
default_api_ports = (41343, 52993, 16141, 39414, 22931)
1 change: 1 addition & 0 deletions src/lmstudio/_sdk_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Automatically generated by sync-sdk-schema.py. DO NOT EDIT THIS FILE!
from __future__ import annotations
from typing import Annotated, Any, ClassVar, Literal, Mapping, Sequence, TypedDict
from msgspec import Meta, field
Expand Down
41 changes: 41 additions & 0 deletions src/lmstudio/async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
TypeIs,
)

import httpx

from httpx_ws import AsyncWebSocketSession

from .sdk_api import (
Expand Down Expand Up @@ -1518,6 +1520,8 @@ def __init__(self, api_host: str | None = None) -> None:
async def __aenter__(self) -> Self:
# Handle reentrancy the same way files do:
# allow nested use as a CM, but close on the first exit
with sdk_public_api():
await self._ensure_api_host_is_valid()
if not self._sessions:
rm = self._resources
await rm.enter_async_context(self._task_manager)
Expand All @@ -1536,6 +1540,43 @@ async def aclose(self) -> None:
"""Close any started client sessions."""
await self._resources.aclose()

@staticmethod
async def _query_probe_url(url: str) -> httpx.Response:
async with httpx.AsyncClient() as client:
return await client.get(url, timeout=1)

@classmethod
@sdk_public_api_async()
async def is_valid_api_host(cls, api_host: str) -> bool:
"""Report whether the given API host is running an API server instance."""
probe_url = cls._get_probe_url(api_host)
try:
probe_response = await cls._query_probe_url(probe_url)
except (httpx.ConnectTimeout, httpx.ConnectError):
return False
return cls._check_probe_response(probe_response)

@classmethod
@sdk_public_api_async()
async def find_default_local_api_host(cls) -> str | None:
"""Query local ports for a running API server instance."""
for api_host in cls._iter_default_api_hosts():
if await cls.is_valid_api_host(api_host):
return api_host
return None

async def _ensure_api_host_is_valid(self) -> None:
specified_api_host = self._api_host
if specified_api_host is None:
api_host = await self.find_default_local_api_host()
elif await self.is_valid_api_host(specified_api_host):
api_host = specified_api_host
else:
api_host = None
if api_host is None:
raise self._get_probe_failure_error(specified_api_host)
self._api_host = api_host

def _get_session(self, cls: Type[TAsyncSession]) -> TAsyncSession:
"""Get the client session of the given type."""
namespace = cls.API_NAMESPACE
Expand Down
41 changes: 39 additions & 2 deletions src/lmstudio/json_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@
Self,
)

import httpx

from msgspec import Struct, convert, defstruct, to_builtins

from . import _api_server_ports
from .sdk_api import (
LMStudioError,
LMStudioRuntimeError,
Expand Down Expand Up @@ -190,7 +192,6 @@
T = TypeVar("T")
TStruct = TypeVar("TStruct", bound=AnyLMStudioStruct)

DEFAULT_API_HOST = "localhost:1234"
DEFAULT_TTL = 60 * 60 # By default, leaves idle models loaded for an hour

# Require a coroutine (not just any awaitable) for run_coroutine_threadsafe compatibility
Expand Down Expand Up @@ -1964,9 +1965,45 @@ class ClientBase:

def __init__(self, api_host: str | None = None) -> None:
"""Initialize API client."""
self.api_host = api_host if api_host else DEFAULT_API_HOST
self._api_host = api_host
self._auth_details = self._create_auth_message()

@property
def api_host(self) -> str:
api_host = self._api_host
if api_host is None:
raise LMStudioRuntimeError("Local API host port is not yet resolved.")
return api_host

_DEFAULT_API_PORTS = _api_server_ports.default_api_ports

@staticmethod
def _get_probe_url(api_host: str) -> str:
return f"http://{api_host}/lmstudio-greeting"

@classmethod
def _iter_default_api_hosts(cls) -> Iterable[str]:
for port in cls._DEFAULT_API_PORTS:
api_host = f"127.0.0.1:{port}"
yield api_host

@staticmethod
def _check_probe_response(response: httpx.Response) -> bool:
"""Returns true if the probe response indicates a valid API server."""
if response.status_code != httpx.codes.OK:
return False
response_data = response.json()
# Valid probe response format: {"lmstudio":true}
return isinstance(response_data, dict) and response_data.get("lmstudio", False)

@staticmethod
def _get_probe_failure_error(api_host: str | None) -> LMStudioClientError:
if api_host is None:
api_host = "any default port"
problem = f"LM Studio is not reachable at {api_host}"
suggestion = "Is LM Studio running?"
return LMStudioClientError(f"{problem}. {suggestion}")

@staticmethod
def _format_auth_message(
client_id: str | None = None, client_key: str | None = None
Expand Down
Loading