|
13 | 13 | import copy |
14 | 14 | import inspect |
15 | 15 | import json |
| 16 | +import re |
16 | 17 | import sys |
17 | 18 | import uuid |
18 | 19 | import warnings |
|
197 | 198 |
|
198 | 199 | DEFAULT_TTL = 60 * 60 # By default, leaves idle models loaded for an hour |
199 | 200 |
|
| 201 | +_LMSTUDIO_API_TOKEN_REGEX = re.compile( |
| 202 | + r"^sk-lm-(?P<clientIdentifier>[A-Za-z0-9]{8}):(?P<clientPasskey>[A-Za-z0-9]{20})$" |
| 203 | +) |
| 204 | + |
200 | 205 | # Require a coroutine (not just any awaitable) for run_coroutine_threadsafe compatibility |
201 | 206 | SendMessageAsync: TypeAlias = Callable[[DictObject], Coroutine[Any, Any, None]] |
202 | 207 |
|
@@ -2032,10 +2037,12 @@ def _ensure_connected(self, usage: str) -> None | NoReturn: |
2032 | 2037 | class ClientBase: |
2033 | 2038 | """Common base class for SDK client interfaces.""" |
2034 | 2039 |
|
2035 | | - def __init__(self, api_host: str | None = None) -> None: |
| 2040 | + def __init__( |
| 2041 | + self, api_host: str | None = None, api_token: str | None = None |
| 2042 | + ) -> None: |
2036 | 2043 | """Initialize API client.""" |
2037 | 2044 | self._api_host = api_host |
2038 | | - self._auth_details = self._create_auth_message() |
| 2045 | + self._auth_details = self._create_auth_message(api_token) |
2039 | 2046 |
|
2040 | 2047 | @property |
2041 | 2048 | def api_host(self) -> str: |
@@ -2086,16 +2093,35 @@ def _format_auth_message( |
2086 | 2093 | # sufficient to prevent accidental conflicts and, in combination with secure |
2087 | 2094 | # websocket support, would be sufficient to ensure that access to the running |
2088 | 2095 | # client was required to extract the auth details. |
2089 | | - client_identifier = client_id if client_id is not None else str(uuid.uuid4()) |
| 2096 | + client_identifier = ( |
| 2097 | + client_id if client_id is not None else f"guest:{str(uuid.uuid4())}" |
| 2098 | + ) |
2090 | 2099 | client_passkey = client_key if client_key is not None else str(uuid.uuid4()) |
2091 | 2100 | return { |
2092 | 2101 | "authVersion": 1, |
2093 | 2102 | "clientIdentifier": client_identifier, |
2094 | 2103 | "clientPasskey": client_passkey, |
2095 | 2104 | } |
2096 | 2105 |
|
2097 | | - def _create_auth_message(self) -> DictObject: |
| 2106 | + def _create_auth_message(self, api_token: str | None = None) -> DictObject: |
2098 | 2107 | """Create an LM Studio websocket authentication message.""" |
| 2108 | + if api_token is not None: |
| 2109 | + match = _LMSTUDIO_API_TOKEN_REGEX.match(api_token) |
| 2110 | + if match is None: |
| 2111 | + raise LMStudioValueError( |
| 2112 | + "The api_token argument does not look like a valid LM Studio API token.\n\n" |
| 2113 | + "LM Studio API tokens are obtained from LM Studio, and they look like this:\n" |
| 2114 | + "sk-lm-xxxxxxxxxxxxxxxxxxxxxxxxxxxxx." |
| 2115 | + ) |
| 2116 | + groups = match.groupdict() |
| 2117 | + client_identifier = groups.get("clientIdentifier") |
| 2118 | + client_passkey = groups.get("clientPasskey") |
| 2119 | + if client_identifier is None or client_passkey is None: |
| 2120 | + raise LMStudioValueError( |
| 2121 | + "Unexpected error parsing api_token: required token fields were not detected." |
| 2122 | + ) |
| 2123 | + return self._format_auth_message(client_identifier, client_passkey) |
| 2124 | + |
2099 | 2125 | return self._format_auth_message() |
2100 | 2126 |
|
2101 | 2127 |
|
|
0 commit comments