diff --git a/src/lmstudio/_kv_config.py b/src/lmstudio/_kv_config.py index e6fd26f..3c3ec8a 100644 --- a/src/lmstudio/_kv_config.py +++ b/src/lmstudio/_kv_config.py @@ -123,7 +123,7 @@ def update_client_config( "contextLength": ConfigField("contextLength"), } -_SUPPORTED_SERVER_KEYS: dict[str, DictObject] = { +SUPPORTED_SERVER_KEYS: dict[str, DictObject] = { "load": { "gpuSplitConfig": MultiPartField( "gpuOffload", ("mainGpu", "splitStrategy", "disabledGpus") @@ -189,7 +189,7 @@ def _iter_server_keys(*namespaces: str) -> Iterable[tuple[str, ConfigField]]: # Map dotted config field names to their client config field counterparts for namespace in namespaces: scopes: list[tuple[str, DictObject]] = [ - (namespace, _SUPPORTED_SERVER_KEYS[namespace]) + (namespace, SUPPORTED_SERVER_KEYS[namespace]) ] for prefix, scope in scopes: for k, v in scope.items(): @@ -204,6 +204,7 @@ def _iter_server_keys(*namespaces: str) -> Iterable[tuple[str, ConfigField]]: FROM_SERVER_LOAD_LLM = dict(_iter_server_keys("load", "llm.load")) FROM_SERVER_LOAD_EMBEDDING = dict(_iter_server_keys("load", "embedding.load")) FROM_SERVER_PREDICTION = dict(_iter_server_keys("llm.prediction")) +FROM_SERVER_CONFIG = dict(_iter_server_keys(*SUPPORTED_SERVER_KEYS)) # Define mappings to translate client config instances to server KV configs @@ -237,8 +238,26 @@ def dict_from_kvconfig(config: KvConfig) -> DictObject: return {kv.key: kv.value for kv in config.fields} -def dict_from_fields_key(config: DictObject) -> DictObject: - return {kv["key"]: kv["value"] for kv in config.get("fields", [])} +def parse_server_config(server_config: DictObject) -> DictObject: + """Map server config fields to client config fields.""" + result: MutableDictObject = {} + for kv in server_config.get("fields", []): + key = kv["key"] + config_field = FROM_SERVER_CONFIG.get(key, None) + if config_field is None: + # Skip unknown keys (server might be newer than the SDK) + continue + value = kv["value"] + config_field.update_client_config(result, value) + return result + + +def parse_llm_load_config(server_config: DictObject) -> LlmLoadModelConfig: + return LlmLoadModelConfig._from_any_api_dict(parse_server_config(server_config)) + + +def parse_prediction_config(server_config: DictObject) -> LlmPredictionConfig: + return LlmPredictionConfig._from_any_api_dict(parse_server_config(server_config)) def _api_override_kv_config_stack( diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py index d4d4752..a9912c4 100644 --- a/src/lmstudio/async_api.py +++ b/src/lmstudio/async_api.py @@ -43,6 +43,7 @@ _LocalFileData, ) from .json_api import ( + AnyLoadConfig, AnyModelSpecifier, AvailableModelBase, ChannelEndpoint, @@ -93,7 +94,7 @@ _model_spec_to_api_dict, _redact_json, ) -from ._kv_config import TLoadConfig, TLoadConfigDict, dict_from_fields_key +from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config from ._sdk_models import ( EmbeddingRpcEmbedStringParameter, EmbeddingRpcTokenizeParameter, @@ -693,7 +694,9 @@ def _system_session(self) -> AsyncSessionSystem: def _files_session(self) -> _AsyncSessionFiles: return self._client.files - async def _get_load_config(self, model_specifier: AnyModelSpecifier) -> DictObject: + async def _get_load_config( + self, model_specifier: AnyModelSpecifier + ) -> AnyLoadConfig: """Get the model load config for the specified model.""" # Note that the configuration reported here uses the *server* config names, # not the attributes used to set the configuration in the client SDK @@ -703,7 +706,8 @@ async def 
_get_load_config(self, model_specifier: AnyModelSpecifier) -> DictObje } ) config = await self.remote_call("getLoadConfig", params) - return dict_from_fields_key(config) + result_type = self._API_TYPES.MODEL_LOAD_CONFIG + return result_type._from_any_api_dict(parse_server_config(config)) async def _get_api_model_info(self, model_specifier: AnyModelSpecifier) -> Any: """Get the raw model info (if any) for a model matching the given criteria.""" @@ -1158,7 +1162,9 @@ async def _embed( ) -class AsyncModelHandle(ModelHandleBase[TAsyncSessionModel]): +class AsyncModelHandle( + Generic[TAsyncSessionModel], ModelHandleBase[TAsyncSessionModel] +): """Reference to a loaded LM Studio model.""" @sdk_public_api_async() @@ -1171,9 +1177,8 @@ async def get_info(self) -> ModelInstanceInfo: """Get the model info for this model.""" return await self._session.get_model_info(self.identifier) - # Private until this API can emit the client config types @sdk_public_api_async() - async def _get_load_config(self) -> DictObject: + async def get_load_config(self) -> AnyLoadConfig: """Get the model load config for this model.""" return await self._session._get_load_config(self.identifier) diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py index d8b92fa..bd9b143 100644 --- a/src/lmstudio/json_api.py +++ b/src/lmstudio/json_api.py @@ -59,8 +59,9 @@ from ._kv_config import ( TLoadConfig, TLoadConfigDict, - dict_from_fields_key, load_config_to_kv_config_stack, + parse_llm_load_config, + parse_prediction_config, prediction_config_to_kv_config_stack, ) from ._sdk_models import ( @@ -128,6 +129,7 @@ # implicitly as part of the top-level `lmstudio` API. __all__ = [ "ActResult", + "AnyLoadConfig", "AnyModelSpecifier", "DownloadFinalizedCallback", "DownloadProgressCallback", @@ -180,6 +182,7 @@ DEFAULT_TTL = 60 * 60 # By default, leaves idle models loaded for an hour AnyModelSpecifier: TypeAlias = str | ModelSpecifier | ModelQuery | DictObject +AnyLoadConfig: TypeAlias = EmbeddingLoadModelConfig | LlmLoadModelConfig GetOrLoadChannelRequest: TypeAlias = ( EmbeddingChannelGetOrLoadCreationParameter | LlmChannelGetOrLoadCreationParameter @@ -441,12 +444,9 @@ class PredictionResult(Generic[TPrediction]): parsed: TPrediction # dict for structured predictions, str otherwise stats: LlmPredictionStats # Statistics about the prediction process model_info: LlmInfo # Information about the model used - structured: bool = field(init=False) # Whether the result is structured or not - # Note that the configuration reported here uses the *server* config names, - # not the attributes used to set the configuration in the client SDK - # Private until these attributes store the client config types - _load_config: DictObject # The configuration used to load the model - _prediction_config: DictObject # The configuration used for the prediction + structured: bool = field(init=False) # Whether the result is structured or not + load_config: LlmLoadModelConfig # The configuration used to load the model + prediction_config: LlmPredictionConfig # The configuration used for the prediction # fmt: on def __post_init__(self) -> None: @@ -1262,8 +1262,8 @@ def iter_message_events( parsed=parsed_content, stats=LlmPredictionStats._from_any_api_dict(stats), model_info=LlmInfo._from_any_api_dict(model_info), - _load_config=dict_from_fields_key(load_kvconfig), - _prediction_config=dict_from_fields_key(prediction_kvconfig), + load_config=parse_llm_load_config(load_kvconfig), + prediction_config=parse_prediction_config(prediction_kvconfig), ) ) case 
unmatched: @@ -1477,19 +1477,19 @@ def model_info(self) -> LlmInfo | None: # Private until this API can emit the client config types @property - def _load_config(self) -> DictObject | None: + def _load_config(self) -> LlmLoadModelConfig | None: """Get the load configuration used for the current prediction if available.""" if self._final_result is None: return None - return self._final_result._load_config + return self._final_result.load_config # Private until this API can emit the client config types @property - def _prediction_config(self) -> DictObject | None: + def _prediction_config(self) -> LlmPredictionConfig | None: """Get the prediction configuration used for the current prediction if available.""" if self._final_result is None: return None - return self._final_result._prediction_config + return self._final_result.prediction_config @sdk_public_api() def result(self) -> PredictionResult[TPrediction]: diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py index 486a705..474b09b 100644 --- a/src/lmstudio/sync_api.py +++ b/src/lmstudio/sync_api.py @@ -66,6 +66,7 @@ ) from .json_api import ( ActResult, + AnyLoadConfig, AnyModelSpecifier, AvailableModelBase, ChannelEndpoint, @@ -121,7 +122,7 @@ _model_spec_to_api_dict, _redact_json, ) -from ._kv_config import TLoadConfig, TLoadConfigDict, dict_from_fields_key +from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config from ._sdk_models import ( EmbeddingRpcEmbedStringParameter, EmbeddingRpcTokenizeParameter, @@ -866,7 +867,7 @@ def _system_session(self) -> SyncSessionSystem: def _files_session(self) -> _SyncSessionFiles: return self._client.files - def _get_load_config(self, model_specifier: AnyModelSpecifier) -> DictObject: + def _get_load_config(self, model_specifier: AnyModelSpecifier) -> AnyLoadConfig: """Get the model load config for the specified model.""" # Note that the configuration reported here uses the *server* config names, # not the attributes used to set the configuration in the client SDK @@ -876,7 +877,8 @@ def _get_load_config(self, model_specifier: AnyModelSpecifier) -> DictObject: } ) config = self.remote_call("getLoadConfig", params) - return dict_from_fields_key(config) + result_type = self._API_TYPES.MODEL_LOAD_CONFIG + return result_type._from_any_api_dict(parse_server_config(config)) def _get_api_model_info(self, model_specifier: AnyModelSpecifier) -> Any: """Get the raw model info (if any) for a model matching the given criteria.""" @@ -1339,7 +1341,7 @@ def get_info(self) -> ModelInstanceInfo: # Private until this API can emit the client config types @sdk_public_api() - def _get_load_config(self) -> DictObject: + def _get_load_config(self) -> AnyLoadConfig: """Get the model load config for this model.""" return self._session._get_load_config(self.identifier) diff --git a/tests/async/test_embedding_async.py b/tests/async/test_embedding_async.py index d61acca..6432ae2 100644 --- a/tests/async/test_embedding_async.py +++ b/tests/async/test_embedding_async.py @@ -6,7 +6,7 @@ import pytest from pytest import LogCaptureFixture as LogCap -from lmstudio import AsyncClient, LMStudioModelNotFoundError +from lmstudio import AsyncClient, EmbeddingLoadModelConfig, LMStudioModelNotFoundError from ..support import ( EXPECTED_EMBEDDING, @@ -114,7 +114,7 @@ async def test_get_load_config_async(model_id: str, caplog: LogCap) -> None: response = await client.embedding._get_load_config(model_id) logging.info(f"Load config response: {response}") assert response - assert isinstance(response, dict) + assert 
isinstance(response, EmbeddingLoadModelConfig) @pytest.mark.asyncio diff --git a/tests/async/test_inference_async.py b/tests/async/test_inference_async.py index f653e89..21ad452 100644 --- a/tests/async/test_inference_async.py +++ b/tests/async/test_inference_async.py @@ -15,6 +15,8 @@ Chat, DictSchema, LlmInfo, + LlmLoadModelConfig, + LlmPredictionConfig, LlmPredictionConfigDict, LlmPredictionFragment, LlmPredictionStats, @@ -269,12 +271,12 @@ async def test_complete_prediction_metadata_async(caplog: LogCap) -> None: logging.info(f"LLM response: {response.content!r}") assert response.stats assert response.model_info - assert response._load_config - assert response._prediction_config + assert response.load_config + assert response.prediction_config assert isinstance(response.stats, LlmPredictionStats) assert isinstance(response.model_info, LlmInfo) - assert isinstance(response._load_config, dict) - assert isinstance(response._prediction_config, dict) + assert isinstance(response.load_config, LlmLoadModelConfig) + assert isinstance(response.prediction_config, LlmPredictionConfig) @pytest.mark.asyncio diff --git a/tests/async/test_llm_async.py b/tests/async/test_llm_async.py index acb5fbc..949ca16 100644 --- a/tests/async/test_llm_async.py +++ b/tests/async/test_llm_async.py @@ -5,7 +5,7 @@ import pytest from pytest import LogCaptureFixture as LogCap -from lmstudio import AsyncClient, history +from lmstudio import AsyncClient, LlmLoadModelConfig, history from ..support import EXPECTED_LLM, EXPECTED_LLM_ID @@ -96,7 +96,7 @@ async def test_get_load_config_async(model_id: str, caplog: LogCap) -> None: response = await client.llm._get_load_config(model_id) logging.info(f"Load config response: {response}") assert response - assert isinstance(response, dict) + assert isinstance(response, LlmLoadModelConfig) @pytest.mark.asyncio diff --git a/tests/sync/test_embedding_sync.py b/tests/sync/test_embedding_sync.py index e6dc521..f1a11f7 100644 --- a/tests/sync/test_embedding_sync.py +++ b/tests/sync/test_embedding_sync.py @@ -13,7 +13,7 @@ import pytest from pytest import LogCaptureFixture as LogCap -from lmstudio import Client, LMStudioModelNotFoundError +from lmstudio import Client, EmbeddingLoadModelConfig, LMStudioModelNotFoundError from ..support import ( EXPECTED_EMBEDDING, @@ -115,7 +115,7 @@ def test_get_load_config_sync(model_id: str, caplog: LogCap) -> None: response = client.embedding._get_load_config(model_id) logging.info(f"Load config response: {response}") assert response - assert isinstance(response, dict) + assert isinstance(response, EmbeddingLoadModelConfig) @pytest.mark.lmstudio diff --git a/tests/sync/test_inference_sync.py b/tests/sync/test_inference_sync.py index b0bd21b..4e0ec07 100644 --- a/tests/sync/test_inference_sync.py +++ b/tests/sync/test_inference_sync.py @@ -22,6 +22,8 @@ Chat, DictSchema, LlmInfo, + LlmLoadModelConfig, + LlmPredictionConfig, LlmPredictionConfigDict, LlmPredictionFragment, LlmPredictionStats, @@ -264,12 +266,12 @@ def test_complete_prediction_metadata_sync(caplog: LogCap) -> None: logging.info(f"LLM response: {response.content!r}") assert response.stats assert response.model_info - assert response._load_config - assert response._prediction_config + assert response.load_config + assert response.prediction_config assert isinstance(response.stats, LlmPredictionStats) assert isinstance(response.model_info, LlmInfo) - assert isinstance(response._load_config, dict) - assert isinstance(response._prediction_config, dict) + assert 
isinstance(response.load_config, LlmLoadModelConfig) + assert isinstance(response.prediction_config, LlmPredictionConfig) @pytest.mark.lmstudio diff --git a/tests/sync/test_llm_sync.py b/tests/sync/test_llm_sync.py index d0dad09..b3f4ece 100644 --- a/tests/sync/test_llm_sync.py +++ b/tests/sync/test_llm_sync.py @@ -12,7 +12,7 @@ import pytest from pytest import LogCaptureFixture as LogCap -from lmstudio import Client, history +from lmstudio import Client, LlmLoadModelConfig, history from ..support import EXPECTED_LLM, EXPECTED_LLM_ID @@ -96,7 +96,7 @@ def test_get_load_config_sync(model_id: str, caplog: LogCap) -> None: response = client.llm._get_load_config(model_id) logging.info(f"Load config response: {response}") assert response - assert isinstance(response, dict) + assert isinstance(response, LlmLoadModelConfig) @pytest.mark.lmstudio diff --git a/tests/test_history.py b/tests/test_history.py index c73074c..233cdfd 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -24,6 +24,8 @@ ) from lmstudio.json_api import ( LlmInfo, + LlmLoadModelConfig, + LlmPredictionConfig, LlmPredictionStats, PredictionResult, TPrediction, @@ -347,8 +349,8 @@ def _make_prediction_result(data: TPrediction) -> PredictionResult[TPrediction]: trained_for_tool_use=False, max_context_length=32, ), - _load_config={}, - _prediction_config={}, + load_config=LlmLoadModelConfig(), + prediction_config=LlmPredictionConfig(), )
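
Usage sketch for the surface this patch exposes: `AsyncModelHandle.get_load_config()` is now public and returns a typed `AnyLoadConfig` instead of a raw server-keyed dict, and `PredictionResult` gains public `load_config` / `prediction_config` attributes carrying `LlmLoadModelConfig` and `LlmPredictionConfig`. The snippet below is a minimal illustration, not part of the patch; the model identifier is a placeholder, and the `client.llm.model()` / `complete()` calls assume the existing SDK surface shown in the tests above.

```python
import asyncio

from lmstudio import AsyncClient, LlmLoadModelConfig, LlmPredictionConfig


async def main() -> None:
    async with AsyncClient() as client:
        # "my-model" is a placeholder identifier for an already available LLM
        model = await client.llm.model("my-model")

        # Previously private (_get_load_config) and dict-valued; now public and typed
        load_config = await model.get_load_config()
        assert isinstance(load_config, LlmLoadModelConfig)

        result = await model.complete("Hello")
        # PredictionResult now exposes the client config types directly
        assert isinstance(result.load_config, LlmLoadModelConfig)
        assert isinstance(result.prediction_config, LlmPredictionConfig)


asyncio.run(main())
```

The translation from server KV config fields to these client config types goes through the new `parse_server_config()` helper, which silently skips server keys the SDK does not recognise, so the typed results stay forward-compatible with newer servers.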