Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions src/lmstudio/json_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy
import json
import uuid
import warnings

from abc import ABC, abstractmethod
from contextlib import contextmanager
Expand Down Expand Up @@ -152,7 +153,7 @@
"LMStudioPredictionError",
"LMStudioPresetNotFoundError",
"LMStudioServerError",
"LMStudioUnknownMessageError",
"LMStudioUnknownMessageWarning",
"LMStudioWebsocketError",
"ModelInfo",
"ModelInstanceInfo",
Expand Down Expand Up @@ -414,7 +415,7 @@ class LMStudioClientError(LMStudioError):


@sdk_public_type
class LMStudioUnknownMessageError(LMStudioClientError):
class LMStudioUnknownMessageWarning(LMStudioClientError, UserWarning):
"""Client has received a message in a format it wasn't expecting."""


Expand Down Expand Up @@ -699,11 +700,19 @@ def result(self) -> T:
assert self._result is not None
return self._result

def raise_unknown_message_error(self, unknown_message: Any) -> NoReturn:
# TODO: improve forward compatibility by switching this to use warnings.warn
# instead of failing immediately for all unknown messages
raise LMStudioUnknownMessageError(
f"{self._NOTICE_PREFIX} unexpected message contents: {unknown_message!r}"
def report_unknown_message(self, unknown_message: Any) -> None:
# By default, each unique unknown message will be reported once per
# calling code location, NOT per channel instance. This is reasonable,
# since it generally indicates an SDK/server version compatibility issue,
# not a problem with any specific instance
# Potentially useful warnings filters:
# * Always show: "always:LMStudioUnknownMessageWarning"
# * Never show: "ignore:LMStudioUnknownMessageWarning"
# * Client exception: "error:LMStudioUnknownMessageWarning"
warnings.warn(
f"{self._NOTICE_PREFIX} unexpected message contents: {unknown_message!r}",
LMStudioUnknownMessageWarning,
stacklevel=2, # Handle based on caller's code location
)

# See ChannelHandler below for more details on the routing of received messages
Expand Down Expand Up @@ -797,7 +806,7 @@ def iter_message_events(
case {"type": "success", "defaultIdentifier": str(default_identifier)}:
yield self._set_result(default_identifier)
case unmatched:
self.raise_unknown_message_error(unmatched)
self.report_unknown_message(unmatched)

def handle_rx_event(self, event: ModelDownloadRxEvent) -> None:
match event:
Expand Down Expand Up @@ -921,7 +930,7 @@ def iter_message_events(
)
yield self._set_result(result)
case unmatched:
self.raise_unknown_message_error(unmatched)
self.report_unknown_message(unmatched)

def handle_rx_event(self, event: ModelLoadingRxEvent) -> None:
match event:
Expand Down Expand Up @@ -1290,7 +1299,7 @@ def iter_message_events(
)
)
case unmatched:
self.raise_unknown_message_error(unmatched)
self.report_unknown_message(unmatched)

def handle_rx_event(self, event: PredictionRxEvent) -> None:
match event:
Expand Down