Skip to content

Commit eecafa8

Browse files
API token in SDK
1 parent ae837ac commit eecafa8

File tree

3 files changed

+38
-8
lines changed

3 files changed

+38
-8
lines changed

src/lmstudio/async_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,9 +1492,11 @@ async def embed(
14921492
class AsyncClient(ClientBase):
14931493
"""Async SDK client interface."""
14941494

1495-
def __init__(self, api_host: str | None = None) -> None:
1495+
def __init__(
1496+
self, api_host: str | None = None, api_token: str | None = None
1497+
) -> None:
14961498
"""Initialize API client."""
1497-
super().__init__(api_host)
1499+
super().__init__(api_host, api_token)
14981500
self._resources = AsyncExitStack()
14991501
self._sessions: dict[str, _AsyncSession] = {}
15001502
self._task_manager = AsyncTaskManager()

src/lmstudio/json_api.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import copy
1414
import inspect
1515
import json
16+
import re
1617
import sys
1718
import uuid
1819
import warnings
@@ -197,6 +198,10 @@
197198

198199
DEFAULT_TTL = 60 * 60 # By default, leaves idle models loaded for an hour
199200

201+
_LMSTUDIO_API_TOKEN_REGEX = re.compile(
202+
r"^sk-lm-(?P<clientIdentifier>[A-Za-z0-9]{8}):(?P<clientPasskey>[A-Za-z0-9]{20})$"
203+
)
204+
200205
# Require a coroutine (not just any awaitable) for run_coroutine_threadsafe compatibility
201206
SendMessageAsync: TypeAlias = Callable[[DictObject], Coroutine[Any, Any, None]]
202207

@@ -2032,10 +2037,12 @@ def _ensure_connected(self, usage: str) -> None | NoReturn:
20322037
class ClientBase:
20332038
"""Common base class for SDK client interfaces."""
20342039

2035-
def __init__(self, api_host: str | None = None) -> None:
2040+
def __init__(
2041+
self, api_host: str | None = None, api_token: str | None = None
2042+
) -> None:
20362043
"""Initialize API client."""
20372044
self._api_host = api_host
2038-
self._auth_details = self._create_auth_message()
2045+
self._auth_details = self._create_auth_message(api_token)
20392046

20402047
@property
20412048
def api_host(self) -> str:
@@ -2086,16 +2093,35 @@ def _format_auth_message(
20862093
# sufficient to prevent accidental conflicts and, in combination with secure
20872094
# websocket support, would be sufficient to ensure that access to the running
20882095
# client was required to extract the auth details.
2089-
client_identifier = client_id if client_id is not None else str(uuid.uuid4())
2096+
client_identifier = (
2097+
client_id if client_id is not None else f"guest:{str(uuid.uuid4())}"
2098+
)
20902099
client_passkey = client_key if client_key is not None else str(uuid.uuid4())
20912100
return {
20922101
"authVersion": 1,
20932102
"clientIdentifier": client_identifier,
20942103
"clientPasskey": client_passkey,
20952104
}
20962105

2097-
def _create_auth_message(self) -> DictObject:
2106+
def _create_auth_message(self, api_token: str | None = None) -> DictObject:
20982107
"""Create an LM Studio websocket authentication message."""
2108+
if api_token is not None:
2109+
match = _LMSTUDIO_API_TOKEN_REGEX.match(api_token)
2110+
if match is None:
2111+
raise LMStudioValueError(
2112+
"The api_token argument does not look like a valid LM Studio API token.\n\n"
2113+
"LM Studio API tokens are obtained from LM Studio, and they look like this:\n"
2114+
"sk-lm-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx."
2115+
)
2116+
groups = match.groupdict()
2117+
client_identifier = groups.get("clientIdentifier")
2118+
client_passkey = groups.get("clientPasskey")
2119+
if client_identifier is None or client_passkey is None:
2120+
raise LMStudioValueError(
2121+
"Unexpected error parsing api_token: required token fields were not detected."
2122+
)
2123+
return self._format_auth_message(client_identifier, client_passkey)
2124+
20992125
return self._format_auth_message()
21002126

21012127

src/lmstudio/sync_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,9 +1540,11 @@ def embed(
15401540
class Client(ClientBase):
15411541
"""Synchronous SDK client interface."""
15421542

1543-
def __init__(self, api_host: str | None = None) -> None:
1543+
def __init__(
1544+
self, api_host: str | None = None, api_token: str | None = None
1545+
) -> None:
15441546
"""Initialize API client."""
1545-
super().__init__(api_host)
1547+
super().__init__(api_host, api_token)
15461548
self._resources = rm = ExitStack()
15471549
self._ws_thread = ws_thread = AsyncWebsocketThread(dict(client=repr(self)))
15481550
ws_thread.start()

0 commit comments

Comments
 (0)