diff --git a/sdk-schema/sync-sdk-schema.py b/sdk-schema/sync-sdk-schema.py index e45ee92..ab8a31a 100755 --- a/sdk-schema/sync-sdk-schema.py +++ b/sdk-schema/sync-sdk-schema.py @@ -41,7 +41,9 @@ PythonVersion, ) -_THIS_DIR = Path(__file__).parent + +_THIS_FILE = Path(__file__) +_THIS_DIR = _THIS_FILE.parent _LMSJS_DIR = _THIS_DIR / "lmstudio-js" _EXPORTER_DIR = _LMSJS_DIR / "packages/lms-json-schema" _SCHEMA_DIR = _EXPORTER_DIR / "schemas" @@ -51,6 +53,12 @@ _TEMPLATE_DIR = _THIS_DIR / "_templates" _MODEL_DIR = _THIS_DIR.parent / "src/lmstudio/_sdk_models" _MODEL_PATH = _MODEL_DIR / "__init__.py" +_LMSJS_PORTS_PATH = _LMSJS_DIR / "packages/lms-common/src/apiServerPorts.ts" +_PY_PORTS_PATH = _THIS_DIR.parent / "src/lmstudio/_api_server_ports.py" + +GENERATED_SOURCE_HEADER = f"""\ +# Automatically generated by {_THIS_FILE.name}. DO NOT EDIT THIS FILE! +""".splitlines() # The following schemas are not actually used anywhere, # so they're excluded to avoid any conflicts with automatically @@ -630,13 +638,40 @@ def _generate_data_model_from_json_schema() -> None: if line.startswith("class"): break updated_source_lines[idx:idx] = lines_to_insert + # Insert auto-generated code header + updated_source_lines[0:0] = GENERATED_SOURCE_HEADER _MODEL_PATH.write_text("\n".join(updated_source_lines) + "\n") +def _sync_default_port_list() -> None: + """Copy the list of default ports to check for the local API server.""" + print("Extracting default port list...") + print(f" Reading {_LMSJS_PORTS_PATH}") + lmsjs_source = _LMSJS_PORTS_PATH.read_text() + START_PORTS = "apiServerPorts = [" + END_PORTS = "];" + _, found, remaining_text = lmsjs_source.partition(START_PORTS) + if not found: + raise RuntimeError(f"Failed to find {START_PORTS} in {lmsjs_source}") + ports_text, suffix_found, _ = remaining_text.partition(END_PORTS) + if not suffix_found: + raise RuntimeError(f"Failed to find {END_PORTS} in {remaining_text}") + default_ports = [*map(int, ports_text.split(","))] + if not default_ports: + raise RuntimeError("Failed to extract any default ports") + py_source_lines = [ + *GENERATED_SOURCE_HEADER, + f"default_api_ports = ({','.join(map(str, default_ports))})", + ] + print(f" Writing {_PY_PORTS_PATH}") + _PY_PORTS_PATH.write_text("\n".join(py_source_lines) + "\n") + + def _main() -> None: if sys.argv[1:] == ["--regen-schema"] or not _SCHEMA_PATH.exists(): _export_zod_schemas_to_json_schema() _generate_data_model_from_json_schema() + _sync_default_port_list() print("Running automatic formatter after data model code generation") subprocess.run(["tox", "-e", "format"]) diff --git a/src/lmstudio/_api_server_ports.py b/src/lmstudio/_api_server_ports.py new file mode 100644 index 0000000..6a337fe --- /dev/null +++ b/src/lmstudio/_api_server_ports.py @@ -0,0 +1,2 @@ +# Automatically generated by sync-sdk-schema.py. DO NOT EDIT THIS FILE! +default_api_ports = (41343, 52993, 16141, 39414, 22931) diff --git a/src/lmstudio/_sdk_models/__init__.py b/src/lmstudio/_sdk_models/__init__.py index 5c2e83b..e485eeb 100644 --- a/src/lmstudio/_sdk_models/__init__.py +++ b/src/lmstudio/_sdk_models/__init__.py @@ -1,3 +1,4 @@ +# Automatically generated by sync-sdk-schema.py. DO NOT EDIT THIS FILE! from __future__ import annotations from typing import Annotated, Any, ClassVar, Literal, Mapping, Sequence, TypedDict from msgspec import Meta, field diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py index c1f26ee..a13050d 100644 --- a/src/lmstudio/async_api.py +++ b/src/lmstudio/async_api.py @@ -28,6 +28,8 @@ TypeIs, ) +import httpx + from httpx_ws import AsyncWebSocketSession from .sdk_api import ( @@ -1518,6 +1520,8 @@ def __init__(self, api_host: str | None = None) -> None: async def __aenter__(self) -> Self: # Handle reentrancy the same way files do: # allow nested use as a CM, but close on the first exit + with sdk_public_api(): + await self._ensure_api_host_is_valid() if not self._sessions: rm = self._resources await rm.enter_async_context(self._task_manager) @@ -1536,6 +1540,43 @@ async def aclose(self) -> None: """Close any started client sessions.""" await self._resources.aclose() + @staticmethod + async def _query_probe_url(url: str) -> httpx.Response: + async with httpx.AsyncClient() as client: + return await client.get(url, timeout=1) + + @classmethod + @sdk_public_api_async() + async def is_valid_api_host(cls, api_host: str) -> bool: + """Report whether the given API host is running an API server instance.""" + probe_url = cls._get_probe_url(api_host) + try: + probe_response = await cls._query_probe_url(probe_url) + except (httpx.ConnectTimeout, httpx.ConnectError): + return False + return cls._check_probe_response(probe_response) + + @classmethod + @sdk_public_api_async() + async def find_default_local_api_host(cls) -> str | None: + """Query local ports for a running API server instance.""" + for api_host in cls._iter_default_api_hosts(): + if await cls.is_valid_api_host(api_host): + return api_host + return None + + async def _ensure_api_host_is_valid(self) -> None: + specified_api_host = self._api_host + if specified_api_host is None: + api_host = await self.find_default_local_api_host() + elif await self.is_valid_api_host(specified_api_host): + api_host = specified_api_host + else: + api_host = None + if api_host is None: + raise self._get_probe_failure_error(specified_api_host) + self._api_host = api_host + def _get_session(self, cls: Type[TAsyncSession]) -> TAsyncSession: """Get the client session of the given type.""" namespace = cls.API_NAMESPACE diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py index 610f953..231e0fa 100644 --- a/src/lmstudio/json_api.py +++ b/src/lmstudio/json_api.py @@ -43,9 +43,11 @@ Self, ) +import httpx from msgspec import Struct, convert, defstruct, to_builtins +from . import _api_server_ports from .sdk_api import ( LMStudioError, LMStudioRuntimeError, @@ -190,7 +192,6 @@ T = TypeVar("T") TStruct = TypeVar("TStruct", bound=AnyLMStudioStruct) -DEFAULT_API_HOST = "localhost:1234" DEFAULT_TTL = 60 * 60 # By default, leaves idle models loaded for an hour # Require a coroutine (not just any awaitable) for run_coroutine_threadsafe compatibility @@ -1964,9 +1965,45 @@ class ClientBase: def __init__(self, api_host: str | None = None) -> None: """Initialize API client.""" - self.api_host = api_host if api_host else DEFAULT_API_HOST + self._api_host = api_host self._auth_details = self._create_auth_message() + @property + def api_host(self) -> str: + api_host = self._api_host + if api_host is None: + raise LMStudioRuntimeError("Local API host port is not yet resolved.") + return api_host + + _DEFAULT_API_PORTS = _api_server_ports.default_api_ports + + @staticmethod + def _get_probe_url(api_host: str) -> str: + return f"http://{api_host}/lmstudio-greeting" + + @classmethod + def _iter_default_api_hosts(cls) -> Iterable[str]: + for port in cls._DEFAULT_API_PORTS: + api_host = f"127.0.0.1:{port}" + yield api_host + + @staticmethod + def _check_probe_response(response: httpx.Response) -> bool: + """Returns true if the probe response indicates a valid API server.""" + if response.status_code != httpx.codes.OK: + return False + response_data = response.json() + # Valid probe response format: {"lmstudio":true} + return isinstance(response_data, dict) and response_data.get("lmstudio", False) + + @staticmethod + def _get_probe_failure_error(api_host: str | None) -> LMStudioClientError: + if api_host is None: + api_host = "any default port" + problem = f"LM Studio is not reachable at {api_host}" + suggestion = "Is LM Studio running?" + return LMStudioClientError(f"{problem}. {suggestion}") + @staticmethod def _format_auth_message( client_id: str | None = None, client_key: str | None = None diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py index dfdaafa..6f6ac4f 100644 --- a/src/lmstudio/sync_api.py +++ b/src/lmstudio/sync_api.py @@ -31,6 +31,8 @@ TypeIs, ) +import httpx + # Synchronous API still uses an async websocket (just in a background thread) from httpx_ws import AsyncWebSocketSession @@ -424,7 +426,7 @@ def remote_call( return rpc.receive_result() -class SyncSession(ClientSession["Client", SyncLMStudioWebsocket]): +class _SyncSession(ClientSession["Client", SyncLMStudioWebsocket]): """Sync client session interfaces applicable to all API namespaces.""" def __init__(self, client: "Client") -> None: @@ -502,7 +504,7 @@ def remote_call( TSyncSessionModel = TypeVar( - "TSyncSessionModel", bound="SyncSessionModel[Any, Any, Any, Any]" + "TSyncSessionModel", bound="_SyncSessionModel[Any, Any, Any, Any]" ) TModelHandle = TypeVar("TModelHandle", bound="SyncModelHandle[Any]") @@ -558,7 +560,7 @@ def model( class DownloadedEmbeddingModel( DownloadedModel[ EmbeddingModelInfo, - "SyncSessionEmbedding", + "_SyncSessionEmbedding", EmbeddingLoadModelConfig, EmbeddingLoadModelConfigDict, "EmbeddingModel", @@ -566,7 +568,9 @@ class DownloadedEmbeddingModel( ): """Download listing for an embedding model.""" - def __init__(self, model_info: DictObject, session: "SyncSessionEmbedding") -> None: + def __init__( + self, model_info: DictObject, session: "_SyncSessionEmbedding" + ) -> None: """Initialize downloaded embedding model details.""" super().__init__(EmbeddingModelInfo, model_info, session) @@ -574,7 +578,7 @@ def __init__(self, model_info: DictObject, session: "SyncSessionEmbedding") -> N class DownloadedLlm( DownloadedModel[ LlmInfo, - "SyncSessionLlm", + "_SyncSessionLlm", LlmLoadModelConfig, LlmLoadModelConfigDict, "LLM", @@ -582,7 +586,7 @@ class DownloadedLlm( ): """Download listing for an LLM.""" - def __init__(self, model_info: DictObject, session: "SyncSessionLlm") -> None: + def __init__(self, model_info: DictObject, session: "_SyncSessionLlm") -> None: """Initialize downloaded embedding model details.""" super().__init__(LlmInfo, model_info, session) @@ -590,7 +594,7 @@ def __init__(self, model_info: DictObject, session: "SyncSessionLlm") -> None: AnyDownloadedModel: TypeAlias = DownloadedModel[Any, Any, Any, Any, Any] -class SyncSessionSystem(SyncSession): +class _SyncSessionSystem(_SyncSession): """Sync client session for the system namespace.""" API_NAMESPACE = "system" @@ -618,7 +622,7 @@ def _process_download_listing(self, model_info: DictObject) -> AnyDownloadedMode ) -class _SyncSessionFiles(SyncSession): +class _SyncSessionFiles(_SyncSession): """Sync client session for the files namespace.""" API_NAMESPACE = "files" @@ -645,7 +649,7 @@ def prepare_image(self, src: LocalFileInput, name: str | None = None) -> FileHan return self._fetch_file_handle(file_data) -class ModelDownloadOption(ModelDownloadOptionBase[SyncSession]): +class ModelDownloadOption(ModelDownloadOptionBase[_SyncSession]): """A single download option for a model search result.""" @sdk_public_api() @@ -660,7 +664,7 @@ def download( return channel.wait_for_result() -class AvailableModel(AvailableModelBase[SyncSession]): +class AvailableModel(AvailableModelBase[_SyncSession]): """A model available for download from the model repository.""" @sdk_public_api() @@ -676,7 +680,7 @@ def get_download_options( return final -class SyncSessionRepository(SyncSession): +class _SyncSessionRepository(_SyncSession): """Sync client session for the repository namespace.""" API_NAMESPACE = "repository" @@ -697,8 +701,8 @@ def search_models( TDownloadedModel = TypeVar("TDownloadedModel", bound=AnyDownloadedModel) -class SyncSessionModel( - SyncSession, +class _SyncSessionModel( + _SyncSession, Generic[TModelHandle, TLoadConfig, TLoadConfigDict, TDownloadedModel], ): """Sync client session for a model (LLM/embedding) namespace.""" @@ -706,7 +710,7 @@ class SyncSessionModel( _API_TYPES: Type[ModelSessionTypes[TLoadConfig]] @property - def _system_session(self) -> SyncSessionSystem: + def _system_session(self) -> _SyncSessionSystem: return self._client.system @property @@ -990,8 +994,8 @@ def cancel(self) -> None: self._channel.cancel() -class SyncSessionLlm( - SyncSessionModel[ +class _SyncSessionLlm( + _SyncSessionModel[ "LLM", LlmLoadModelConfig, LlmLoadModelConfigDict, @@ -1100,8 +1104,8 @@ def _apply_prompt_template( return response.get("formatted", "") if response else "" -class SyncSessionEmbedding( - SyncSessionModel[ +class _SyncSessionEmbedding( + _SyncSessionModel[ "EmbeddingModel", EmbeddingLoadModelConfig, EmbeddingLoadModelConfigDict, @@ -1187,7 +1191,7 @@ def get_context_length(self) -> int: AnySyncModel: TypeAlias = SyncModelHandle[Any] -class LLM(SyncModelHandle[SyncSessionLlm]): +class LLM(SyncModelHandle[_SyncSessionLlm]): """Reference to a loaded LLM model.""" @sdk_public_api() @@ -1517,7 +1521,7 @@ def apply_prompt_template( ) -class EmbeddingModel(SyncModelHandle[SyncSessionEmbedding]): +class EmbeddingModel(SyncModelHandle[_SyncSessionEmbedding]): """Reference to a loaded embedding model.""" # Alas, type hints don't properly support distinguishing str vs Iterable[str]: @@ -1530,7 +1534,7 @@ def embed( return self._session._embed(self.identifier, input) -TSyncSession = TypeVar("TSyncSession", bound=SyncSession) +TSyncSession = TypeVar("TSyncSession", bound=_SyncSession) class Client(ClientBase): @@ -1543,7 +1547,7 @@ def __init__(self, api_host: str | None = None) -> None: self._ws_thread = ws_thread = AsyncWebsocketThread(dict(client=repr(self))) ws_thread.start() rm.callback(ws_thread.terminate) - self._sessions: dict[str, SyncSession] = {} + self._sessions: dict[str, _SyncSession] = {} # Support GC-based resource management in the sync API by # finalizing at the client layer, and letting its resource # manager handle clearing up everything else @@ -1553,6 +1557,8 @@ def __init__(self, api_host: str | None = None) -> None: def __enter__(self) -> Self: # Handle reentrancy the same way files do: # allow nested use as a CM, but close on the first exit + with sdk_public_api(): + self._ensure_api_host_is_valid() return self def __exit__(self, *args: Any) -> None: @@ -1562,6 +1568,42 @@ def close(self) -> None: """Close any started client sessions.""" self._resources.close() + @staticmethod + def _query_probe_url(url: str) -> httpx.Response: + return httpx.get(url, timeout=1) + + @classmethod + @sdk_public_api() + def is_valid_api_host(cls, api_host: str) -> bool: + """Report whether the given API host is running an API server instance.""" + probe_url = cls._get_probe_url(api_host) + try: + probe_response = cls._query_probe_url(probe_url) + except (httpx.ConnectTimeout, httpx.ConnectError): + return False + return cls._check_probe_response(probe_response) + + @classmethod + @sdk_public_api() + def find_default_local_api_host(cls) -> str | None: + """Query local ports for a running API server instance.""" + for api_host in cls._iter_default_api_hosts(): + if cls.is_valid_api_host(api_host): + return api_host + return None + + def _ensure_api_host_is_valid(self) -> None: + specified_api_host = self._api_host + if specified_api_host is None: + api_host = self.find_default_local_api_host() + elif self.is_valid_api_host(specified_api_host): + api_host = specified_api_host + else: + api_host = None + if api_host is None: + raise self._get_probe_failure_error(specified_api_host) + self._api_host = api_host + # Doing network I/O in properties is generally considered undesirable. # The async API can't perform network I/O in properties at all. # Unlike the async API (which follows the principles of structured @@ -1593,20 +1635,20 @@ def _get_session(self, cls: Type[TSyncSession]) -> TSyncSession: @property @sdk_public_api() - def llm(self) -> SyncSessionLlm: + def llm(self) -> _SyncSessionLlm: """Return the LLM API client session.""" - return self._get_session(SyncSessionLlm) + return self._get_session(_SyncSessionLlm) @property @sdk_public_api() - def embedding(self) -> SyncSessionEmbedding: + def embedding(self) -> _SyncSessionEmbedding: """Return the embedding model API client session.""" - return self._get_session(SyncSessionEmbedding) + return self._get_session(_SyncSessionEmbedding) @property - def system(self) -> SyncSessionSystem: + def system(self) -> _SyncSessionSystem: """Return the system API client session.""" - return self._get_session(SyncSessionSystem) + return self._get_session(_SyncSessionSystem) @property def files(self) -> _SyncSessionFiles: @@ -1614,9 +1656,9 @@ def files(self) -> _SyncSessionFiles: return self._get_session(_SyncSessionFiles) @property - def repository(self) -> SyncSessionRepository: + def repository(self) -> _SyncSessionRepository: """Return the repository API client session.""" - return self._get_session(SyncSessionRepository) + return self._get_session(_SyncSessionRepository) # Convenience methods # Not yet implemented (server API only supports the same file types as prepare_image) @@ -1681,6 +1723,7 @@ def get_default_client(api_host: str | None = None) -> Client: configure_default_client(api_host) if _default_client is None: _default_client = Client(_default_api_host) + _default_client._ensure_api_host_is_valid() return _default_client diff --git a/tests/async/test_sdk_bypass_async.py b/tests/async/test_sdk_bypass_async.py index ded4f19..0a09e6b 100644 --- a/tests/async/test_sdk_bypass_async.py +++ b/tests/async/test_sdk_bypass_async.py @@ -13,17 +13,20 @@ from httpx_ws import aconnect_ws, AsyncWebSocketSession +from lmstudio import AsyncClient + @pytest.mark.asyncio @pytest.mark.lmstudio async def test_connect_and_predict_async(caplog: Any) -> None: - base_url = "localhost:1234" + # Access the default API host directly + api_host = await AsyncClient.find_default_local_api_host() model_identifier = "hugging-quants/llama-3.2-1b-instruct" prompt = "Hello" caplog.set_level(logging.DEBUG) ws_cm: AsyncContextManager[AsyncWebSocketSession] = aconnect_ws( - f"ws://{base_url}/llm" + f"ws://{api_host}/llm" ) async with ws_cm as ws: diff --git a/tests/support/__init__.py b/tests/support/__init__.py index 5066542..fc2201d 100644 --- a/tests/support/__init__.py +++ b/tests/support/__init__.py @@ -32,8 +32,6 @@ THIS_DIR = Path(__file__).parent -LOCAL_API_HOST = "localhost:1234" - #################################################### # Embedding model testing #################################################### diff --git a/tests/sync/test_sdk_bypass_sync.py b/tests/sync/test_sdk_bypass_sync.py index c8dc521..e2ecc2f 100644 --- a/tests/sync/test_sdk_bypass_sync.py +++ b/tests/sync/test_sdk_bypass_sync.py @@ -20,15 +20,18 @@ from httpx_ws import connect_ws, WebSocketSession +from lmstudio import Client + @pytest.mark.lmstudio def test_connect_and_predict_sync(caplog: Any) -> None: - base_url = "localhost:1234" + # Access the default API host directly + api_host = Client.find_default_local_api_host() model_identifier = "hugging-quants/llama-3.2-1b-instruct" prompt = "Hello" caplog.set_level(logging.DEBUG) - ws_cm: ContextManager[WebSocketSession] = connect_ws(f"ws://{base_url}/llm") + ws_cm: ContextManager[WebSocketSession] = connect_ws(f"ws://{api_host}/llm") with ws_cm as ws: # Authenticate diff --git a/tests/test_convenience_api.py b/tests/test_convenience_api.py index 05eaf87..0aadcfa 100644 --- a/tests/test_convenience_api.py +++ b/tests/test_convenience_api.py @@ -23,13 +23,12 @@ def test_get_default_client() -> None: assert isinstance(client, lms.Client) # Setting the API host after creation is disallowed (even if it is consistent) with pytest.raises(lms.LMStudioClientError, match="already created"): - lms.get_default_client("localhost:1234") + lms.get_default_client(client.api_host) # Ensure configured API host is used lms.sync_api._reset_default_client() try: - with pytest.raises(lms.LMStudioWebsocketError, match="not reachable"): - # Actually try to use the client in order to force a connection attempt - lms.get_default_client(closed_api_host()).list_loaded_models() + with pytest.raises(lms.LMStudioClientError, match="not reachable"): + lms.get_default_client(closed_api_host()) finally: lms.sync_api._reset_default_client() @@ -41,14 +40,13 @@ def test_configure_default_client() -> None: assert isinstance(client, lms.Client) # Setting the API host after creation is disallowed (even if it is consistent) with pytest.raises(lms.LMStudioClientError, match="already created"): - lms.configure_default_client("localhost:1234") + lms.configure_default_client(client.api_host) # Ensure configured API host is used lms.sync_api._reset_default_client() try: lms.configure_default_client(closed_api_host()) - with pytest.raises(lms.LMStudioWebsocketError, match="not reachable"): - # Actually try to use the client in order to force a connection attempt - lms.get_default_client().list_loaded_models() + with pytest.raises(lms.LMStudioClientError, match="not reachable"): + lms.get_default_client() finally: lms.sync_api._reset_default_client() diff --git a/tests/test_session_errors.py b/tests/test_session_errors.py index f4e3e2a..f5a99c9 100644 --- a/tests/test_session_errors.py +++ b/tests/test_session_errors.py @@ -18,8 +18,8 @@ ) from lmstudio.sync_api import ( SyncLMStudioWebsocket, - SyncSession, - SyncSessionSystem, + _SyncSession, + _SyncSessionSystem, ) from .support import ( @@ -68,7 +68,8 @@ async def test_session_not_started_async(caplog: LogCap) -> None: @pytest.mark.asyncio async def test_session_disconnected_async(caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) - client = AsyncClient() + api_host = await AsyncClient.find_default_local_api_host() + client = AsyncClient(api_host) session = _AsyncSessionSystem(client) async with client._task_manager, session: assert session.connected @@ -114,7 +115,7 @@ async def test_session_nonresponsive_port_async(caplog: LogCap) -> None: await check_call_errors_async(session) -def check_call_errors_sync(session: SyncSession) -> None: +def check_call_errors_sync(session: _SyncSession) -> None: # Remote calls are expected to fail when not connected with pytest.raises( LMStudioWebsocketError, @@ -140,7 +141,7 @@ def check_call_errors_sync(session: SyncSession) -> None: def test_session_closed_port_sync(caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) - session = SyncSessionSystem(Client(closed_api_host())) + session = _SyncSessionSystem(Client(closed_api_host())) # Sessions start out disconnected assert not session.connected # Should get an SDK exception rather than the underlying exception @@ -156,7 +157,7 @@ def test_session_closed_port_sync(caplog: LogCap) -> None: def test_session_nonresponsive_port_sync(caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) with nonresponsive_api_host() as api_host: - session = SyncSessionSystem(Client(api_host)) + session = _SyncSessionSystem(Client(api_host)) # Sessions start out disconnected assert not session.connected # Should get an SDK exception rather than the underlying exception diff --git a/tests/test_sessions.py b/tests/test_sessions.py index e93c87c..5a71b2d 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -18,14 +18,12 @@ ) from lmstudio.sync_api import ( SyncLMStudioWebsocket, - SyncSession, - SyncSessionSystem, + _SyncSession, + _SyncSessionSystem, ) from lmstudio._ws_impl import AsyncTaskManager from lmstudio._ws_thread import AsyncWebsocketThread -from .support import LOCAL_API_HOST - async def check_connected_async_session(session: _AsyncSession) -> None: assert session.connected @@ -49,7 +47,8 @@ async def check_connected_async_session(session: _AsyncSession) -> None: @pytest.mark.lmstudio async def test_session_cm_async(caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) - client = AsyncClient() + api_host = await AsyncClient.find_default_local_api_host() + client = AsyncClient(api_host) session = _AsyncSessionSystem(client) # Sessions start out disconnected assert not session.connected @@ -66,7 +65,7 @@ async def test_session_cm_async(caplog: LogCap) -> None: # Check the synchronous session API -def check_connected_sync_session(session: SyncSession) -> None: +def check_connected_sync_session(session: _SyncSession) -> None: assert session.connected session_ws = session._lmsws assert session_ws is not None @@ -87,7 +86,9 @@ def check_connected_sync_session(session: SyncSession) -> None: @pytest.mark.lmstudio def test_session_cm_sync(caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) - session = SyncSessionSystem(Client()) + api_host = Client.find_default_local_api_host() + client = Client(api_host) + session = _SyncSessionSystem(client) # Sessions start out disconnected assert not session.connected # Disconnecting should run without error @@ -106,7 +107,9 @@ def test_session_cm_sync(caplog: LogCap) -> None: @pytest.mark.lmstudio def test_implicit_connection_sync(caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) - session = SyncSessionSystem(Client()) + api_host = Client.find_default_local_api_host() + client = Client(api_host) + session = _SyncSessionSystem(client) # Sessions start out disconnected assert not session.connected try: @@ -123,7 +126,9 @@ def test_implicit_connection_sync(caplog: LogCap) -> None: @pytest.mark.lmstudio def test_implicit_reconnection_sync(caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) - session = SyncSessionSystem(Client()) + api_host = Client.find_default_local_api_host() + client = Client(api_host) + session = _SyncSessionSystem(client) with session: assert session.connected # Session is disconnected after use @@ -154,9 +159,10 @@ def test_implicit_reconnection_sync(caplog: LogCap) -> None: @pytest.mark.lmstudio async def test_websocket_cm_async(caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) + api_host = await AsyncClient.find_default_local_api_host() auth_details = AsyncClient._format_auth_message() tm = AsyncTaskManager(on_activation=None) - lmsws = _AsyncLMStudioWebsocket(tm, f"http://{LOCAL_API_HOST}/system", auth_details) + lmsws = _AsyncLMStudioWebsocket(tm, f"http://{api_host}/system", auth_details) # SDK client websockets start out disconnected assert not lmsws.connected # Entering the CM opens the websocket if it isn't already open @@ -193,10 +199,9 @@ def ws_thread() -> Generator[AsyncWebsocketThread, None, None]: @pytest.mark.lmstudio def test_websocket_cm_sync(ws_thread: AsyncWebsocketThread, caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) + api_host = Client.find_default_local_api_host() auth_details = Client._format_auth_message() - lmsws = SyncLMStudioWebsocket( - ws_thread, f"http://{LOCAL_API_HOST}/system", auth_details - ) + lmsws = SyncLMStudioWebsocket(ws_thread, f"http://{api_host}/system", auth_details) # SDK client websockets start out disconnected assert not lmsws.connected # Entering the CM opens the websocket if it isn't already open