Skip to content

Commit 1a08e81

Browse files
authored
Avoid relying on the optional REST HTTP server (#142)
* Default local API server ports are now extracted from lmstudio-js * The singular DEFAULT_API_HOST is no longer defined * New public APIs to check API host validity * New public APIs to scan for a valid local API host on default ports * Omitting the API host now results in a dynamic scan for a valid local API host when enter the client instance context * Attempting to access the client api_host field when it is still to be determined now raises an exception * Synchronous session API is now also marked as private Closes #141
1 parent cd22910 commit 1a08e81

File tree

12 files changed

+233
-66
lines changed

12 files changed

+233
-66
lines changed

sdk-schema/sync-sdk-schema.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@
4141
PythonVersion,
4242
)
4343

44-
_THIS_DIR = Path(__file__).parent
44+
45+
_THIS_FILE = Path(__file__)
46+
_THIS_DIR = _THIS_FILE.parent
4547
_LMSJS_DIR = _THIS_DIR / "lmstudio-js"
4648
_EXPORTER_DIR = _LMSJS_DIR / "packages/lms-json-schema"
4749
_SCHEMA_DIR = _EXPORTER_DIR / "schemas"
@@ -51,6 +53,12 @@
5153
_TEMPLATE_DIR = _THIS_DIR / "_templates"
5254
_MODEL_DIR = _THIS_DIR.parent / "src/lmstudio/_sdk_models"
5355
_MODEL_PATH = _MODEL_DIR / "__init__.py"
56+
_LMSJS_PORTS_PATH = _LMSJS_DIR / "packages/lms-common/src/apiServerPorts.ts"
57+
_PY_PORTS_PATH = _THIS_DIR.parent / "src/lmstudio/_api_server_ports.py"
58+
59+
GENERATED_SOURCE_HEADER = f"""\
60+
# Automatically generated by {_THIS_FILE.name}. DO NOT EDIT THIS FILE!
61+
""".splitlines()
5462

5563
# The following schemas are not actually used anywhere,
5664
# so they're excluded to avoid any conflicts with automatically
@@ -630,13 +638,40 @@ def _generate_data_model_from_json_schema() -> None:
630638
if line.startswith("class"):
631639
break
632640
updated_source_lines[idx:idx] = lines_to_insert
641+
# Insert auto-generated code header
642+
updated_source_lines[0:0] = GENERATED_SOURCE_HEADER
633643
_MODEL_PATH.write_text("\n".join(updated_source_lines) + "\n")
634644

635645

646+
def _sync_default_port_list() -> None:
647+
"""Copy the list of default ports to check for the local API server."""
648+
print("Extracting default port list...")
649+
print(f" Reading {_LMSJS_PORTS_PATH}")
650+
lmsjs_source = _LMSJS_PORTS_PATH.read_text()
651+
START_PORTS = "apiServerPorts = ["
652+
END_PORTS = "];"
653+
_, found, remaining_text = lmsjs_source.partition(START_PORTS)
654+
if not found:
655+
raise RuntimeError(f"Failed to find {START_PORTS} in {lmsjs_source}")
656+
ports_text, suffix_found, _ = remaining_text.partition(END_PORTS)
657+
if not suffix_found:
658+
raise RuntimeError(f"Failed to find {END_PORTS} in {remaining_text}")
659+
default_ports = [*map(int, ports_text.split(","))]
660+
if not default_ports:
661+
raise RuntimeError("Failed to extract any default ports")
662+
py_source_lines = [
663+
*GENERATED_SOURCE_HEADER,
664+
f"default_api_ports = ({','.join(map(str, default_ports))})",
665+
]
666+
print(f" Writing {_PY_PORTS_PATH}")
667+
_PY_PORTS_PATH.write_text("\n".join(py_source_lines) + "\n")
668+
669+
636670
def _main() -> None:
637671
if sys.argv[1:] == ["--regen-schema"] or not _SCHEMA_PATH.exists():
638672
_export_zod_schemas_to_json_schema()
639673
_generate_data_model_from_json_schema()
674+
_sync_default_port_list()
640675
print("Running automatic formatter after data model code generation")
641676
subprocess.run(["tox", "-e", "format"])
642677

src/lmstudio/_api_server_ports.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Automatically generated by sync-sdk-schema.py. DO NOT EDIT THIS FILE!
2+
default_api_ports = (41343, 52993, 16141, 39414, 22931)

src/lmstudio/_sdk_models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Automatically generated by sync-sdk-schema.py. DO NOT EDIT THIS FILE!
12
from __future__ import annotations
23
from typing import Annotated, Any, ClassVar, Literal, Mapping, Sequence, TypedDict
34
from msgspec import Meta, field

src/lmstudio/async_api.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
TypeIs,
2929
)
3030

31+
import httpx
32+
3133
from httpx_ws import AsyncWebSocketSession
3234

3335
from .sdk_api import (
@@ -1518,6 +1520,8 @@ def __init__(self, api_host: str | None = None) -> None:
15181520
async def __aenter__(self) -> Self:
15191521
# Handle reentrancy the same way files do:
15201522
# allow nested use as a CM, but close on the first exit
1523+
with sdk_public_api():
1524+
await self._ensure_api_host_is_valid()
15211525
if not self._sessions:
15221526
rm = self._resources
15231527
await rm.enter_async_context(self._task_manager)
@@ -1536,6 +1540,43 @@ async def aclose(self) -> None:
15361540
"""Close any started client sessions."""
15371541
await self._resources.aclose()
15381542

1543+
@staticmethod
1544+
async def _query_probe_url(url: str) -> httpx.Response:
1545+
async with httpx.AsyncClient() as client:
1546+
return await client.get(url, timeout=1)
1547+
1548+
@classmethod
1549+
@sdk_public_api_async()
1550+
async def is_valid_api_host(cls, api_host: str) -> bool:
1551+
"""Report whether the given API host is running an API server instance."""
1552+
probe_url = cls._get_probe_url(api_host)
1553+
try:
1554+
probe_response = await cls._query_probe_url(probe_url)
1555+
except (httpx.ConnectTimeout, httpx.ConnectError):
1556+
return False
1557+
return cls._check_probe_response(probe_response)
1558+
1559+
@classmethod
1560+
@sdk_public_api_async()
1561+
async def find_default_local_api_host(cls) -> str | None:
1562+
"""Query local ports for a running API server instance."""
1563+
for api_host in cls._iter_default_api_hosts():
1564+
if await cls.is_valid_api_host(api_host):
1565+
return api_host
1566+
return None
1567+
1568+
async def _ensure_api_host_is_valid(self) -> None:
1569+
specified_api_host = self._api_host
1570+
if specified_api_host is None:
1571+
api_host = await self.find_default_local_api_host()
1572+
elif await self.is_valid_api_host(specified_api_host):
1573+
api_host = specified_api_host
1574+
else:
1575+
api_host = None
1576+
if api_host is None:
1577+
raise self._get_probe_failure_error(specified_api_host)
1578+
self._api_host = api_host
1579+
15391580
def _get_session(self, cls: Type[TAsyncSession]) -> TAsyncSession:
15401581
"""Get the client session of the given type."""
15411582
namespace = cls.API_NAMESPACE

src/lmstudio/json_api.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@
4343
Self,
4444
)
4545

46+
import httpx
4647

4748
from msgspec import Struct, convert, defstruct, to_builtins
4849

50+
from . import _api_server_ports
4951
from .sdk_api import (
5052
LMStudioError,
5153
LMStudioRuntimeError,
@@ -190,7 +192,6 @@
190192
T = TypeVar("T")
191193
TStruct = TypeVar("TStruct", bound=AnyLMStudioStruct)
192194

193-
DEFAULT_API_HOST = "localhost:1234"
194195
DEFAULT_TTL = 60 * 60 # By default, leaves idle models loaded for an hour
195196

196197
# Require a coroutine (not just any awaitable) for run_coroutine_threadsafe compatibility
@@ -1964,9 +1965,45 @@ class ClientBase:
19641965

19651966
def __init__(self, api_host: str | None = None) -> None:
19661967
"""Initialize API client."""
1967-
self.api_host = api_host if api_host else DEFAULT_API_HOST
1968+
self._api_host = api_host
19681969
self._auth_details = self._create_auth_message()
19691970

1971+
@property
1972+
def api_host(self) -> str:
1973+
api_host = self._api_host
1974+
if api_host is None:
1975+
raise LMStudioRuntimeError("Local API host port is not yet resolved.")
1976+
return api_host
1977+
1978+
_DEFAULT_API_PORTS = _api_server_ports.default_api_ports
1979+
1980+
@staticmethod
1981+
def _get_probe_url(api_host: str) -> str:
1982+
return f"http://{api_host}/lmstudio-greeting"
1983+
1984+
@classmethod
1985+
def _iter_default_api_hosts(cls) -> Iterable[str]:
1986+
for port in cls._DEFAULT_API_PORTS:
1987+
api_host = f"127.0.0.1:{port}"
1988+
yield api_host
1989+
1990+
@staticmethod
1991+
def _check_probe_response(response: httpx.Response) -> bool:
1992+
"""Returns true if the probe response indicates a valid API server."""
1993+
if response.status_code != httpx.codes.OK:
1994+
return False
1995+
response_data = response.json()
1996+
# Valid probe response format: {"lmstudio":true}
1997+
return isinstance(response_data, dict) and response_data.get("lmstudio", False)
1998+
1999+
@staticmethod
2000+
def _get_probe_failure_error(api_host: str | None) -> LMStudioClientError:
2001+
if api_host is None:
2002+
api_host = "any default port"
2003+
problem = f"LM Studio is not reachable at {api_host}"
2004+
suggestion = "Is LM Studio running?"
2005+
return LMStudioClientError(f"{problem}. {suggestion}")
2006+
19702007
@staticmethod
19712008
def _format_auth_message(
19722009
client_id: str | None = None, client_key: str | None = None

0 commit comments

Comments
 (0)