Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lmstudio/json_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,8 @@ def request_tool_call(
def _call_requested_tool() -> ToolCallResultData:
call_result = implementation(**kwds)
return ToolCallResultData(
content=json.dumps(call_result, ensure_ascii=False), tool_call_id=tool_call_id
content=json.dumps(call_result, ensure_ascii=False),
tool_call_id=tool_call_id,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code formatting fix from a previous PR

)

return _call_requested_tool
Expand Down
32 changes: 27 additions & 5 deletions src/lmstudio/sync_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
"LLM",
"SyncModelHandle",
"PredictionStream",
"configure_default_client",
"get_default_client",
"embedding_model",
"list_downloaded_models",
Expand Down Expand Up @@ -1606,22 +1607,43 @@ def list_loaded_models(


# Convenience API
_default_api_host = None
_default_client: Client | None = None


@sdk_public_api()
def configure_default_client(api_host: str) -> None:
"""Set the server API host for the default global client (without creating the client)."""
global _default_api_host
if _default_client is not None:
raise LMStudioClientError(
"Default client is already created, cannot set its API host."
)
_default_api_host = api_host


@sdk_public_api()
def get_default_client(api_host: str | None = None) -> Client:
"""Get the default global client (creating it if necessary)."""
global _default_client
if api_host is not None:
# This will raise an exception if the client already exists
configure_default_client(api_host)
if _default_client is None:
_default_client = Client(api_host)
elif api_host is not None:
raise LMStudioClientError(
"Default session already connected, cannot set API host."
)
_default_client = Client(_default_api_host)
return _default_client


def _reset_default_client() -> None:
# Allow the test suite to reset the client without
# having to poke directly at the module's internals
global _default_client
previous_client = _default_client
_default_client = None
if previous_client is not None:
previous_client.close()


@sdk_public_api()
def llm(
model_key: str | None = None,
Expand Down
31 changes: 31 additions & 0 deletions tests/test_convenience_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,44 @@
EXPECTED_VLM_ID,
IMAGE_FILEPATH,
TOOL_LLM_ID,
closed_api_host,
)


@pytest.mark.lmstudio
def test_get_default_client() -> None:
client = lms.get_default_client()
assert isinstance(client, lms.Client)
# Setting the API host after creation is disallowed (even if it is consistent)
with pytest.raises(lms.LMStudioClientError, match="already created"):
lms.get_default_client("localhost:1234")
# Ensure configured API host is used
lms.sync_api._reset_default_client()
try:
with pytest.raises(lms.LMStudioWebsocketError, match="not reachable"):
# Actually try to use the client in order to force a connection attempt
lms.get_default_client(closed_api_host()).list_loaded_models()
finally:
lms.sync_api._reset_default_client()


@pytest.mark.lmstudio
def test_configure_default_client() -> None:
# Ensure the default client already exists
client = lms.get_default_client()
assert isinstance(client, lms.Client)
# Setting the API host after creation is disallowed (even if it is consistent)
with pytest.raises(lms.LMStudioClientError, match="already created"):
lms.configure_default_client("localhost:1234")
# Ensure configured API host is used
lms.sync_api._reset_default_client()
try:
lms.configure_default_client(closed_api_host())
with pytest.raises(lms.LMStudioWebsocketError, match="not reachable"):
# Actually try to use the client in order to force a connection attempt
lms.get_default_client().list_loaded_models()
finally:
lms.sync_api._reset_default_client()


@pytest.mark.lmstudio
Expand Down