diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py index 83816be..05435f9 100644 --- a/src/lmstudio/json_api.py +++ b/src/lmstudio/json_api.py @@ -1392,7 +1392,8 @@ def request_tool_call( def _call_requested_tool() -> ToolCallResultData: call_result = implementation(**kwds) return ToolCallResultData( - content=json.dumps(call_result, ensure_ascii=False), tool_call_id=tool_call_id + content=json.dumps(call_result, ensure_ascii=False), + tool_call_id=tool_call_id, ) return _call_requested_tool diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py index 3d87dbf..b338791 100644 --- a/src/lmstudio/sync_api.py +++ b/src/lmstudio/sync_api.py @@ -139,6 +139,7 @@ "LLM", "SyncModelHandle", "PredictionStream", + "configure_default_client", "get_default_client", "embedding_model", "list_downloaded_models", @@ -1606,22 +1607,43 @@ def list_loaded_models( # Convenience API +_default_api_host = None _default_client: Client | None = None +@sdk_public_api() +def configure_default_client(api_host: str) -> None: + """Set the server API host for the default global client (without creating the client).""" + global _default_api_host + if _default_client is not None: + raise LMStudioClientError( + "Default client is already created, cannot set its API host." + ) + _default_api_host = api_host + + @sdk_public_api() def get_default_client(api_host: str | None = None) -> Client: """Get the default global client (creating it if necessary).""" global _default_client + if api_host is not None: + # This will raise an exception if the client already exists + configure_default_client(api_host) if _default_client is None: - _default_client = Client(api_host) - elif api_host is not None: - raise LMStudioClientError( - "Default session already connected, cannot set API host." - ) + _default_client = Client(_default_api_host) return _default_client +def _reset_default_client() -> None: + # Allow the test suite to reset the client without + # having to poke directly at the module's internals + global _default_api_host, _default_client + previous_client = _default_client + _default_api_host = _default_client = None + if previous_client is not None: + previous_client.close() + + @sdk_public_api() def llm( model_key: str | None = None, diff --git a/tests/test_convenience_api.py b/tests/test_convenience_api.py index c2bc031..05eaf87 100644 --- a/tests/test_convenience_api.py +++ b/tests/test_convenience_api.py @@ -13,6 +13,7 @@ EXPECTED_VLM_ID, IMAGE_FILEPATH, TOOL_LLM_ID, + closed_api_host, ) @@ -20,6 +21,36 @@ def test_get_default_client() -> None: client = lms.get_default_client() assert isinstance(client, lms.Client) + # Setting the API host after creation is disallowed (even if it is consistent) + with pytest.raises(lms.LMStudioClientError, match="already created"): + lms.get_default_client("localhost:1234") + # Ensure configured API host is used + lms.sync_api._reset_default_client() + try: + with pytest.raises(lms.LMStudioWebsocketError, match="not reachable"): + # Actually try to use the client in order to force a connection attempt + lms.get_default_client(closed_api_host()).list_loaded_models() + finally: + lms.sync_api._reset_default_client() + + +@pytest.mark.lmstudio +def test_configure_default_client() -> None: + # Ensure the default client already exists + client = lms.get_default_client() + assert isinstance(client, lms.Client) + # Setting the API host after creation is disallowed (even if it is consistent) + with pytest.raises(lms.LMStudioClientError, match="already created"): + lms.configure_default_client("localhost:1234") + # Ensure configured API host is used + lms.sync_api._reset_default_client() + try: + lms.configure_default_client(closed_api_host()) + with pytest.raises(lms.LMStudioWebsocketError, match="not reachable"): + # Actually try to use the client in order to force a connection attempt + lms.get_default_client().list_loaded_models() + finally: + lms.sync_api._reset_default_client() @pytest.mark.lmstudio