27 changes: 23 additions & 4 deletions src/lmstudio/_kv_config.py
@@ -123,7 +123,7 @@ def update_client_config(
"contextLength": ConfigField("contextLength"),
}

_SUPPORTED_SERVER_KEYS: dict[str, DictObject] = {
SUPPORTED_SERVER_KEYS: dict[str, DictObject] = {
"load": {
"gpuSplitConfig": MultiPartField(
"gpuOffload", ("mainGpu", "splitStrategy", "disabledGpus")
@@ -189,7 +189,7 @@ def _iter_server_keys(*namespaces: str) -> Iterable[tuple[str, ConfigField]]:
# Map dotted config field names to their client config field counterparts
for namespace in namespaces:
scopes: list[tuple[str, DictObject]] = [
(namespace, _SUPPORTED_SERVER_KEYS[namespace])
(namespace, SUPPORTED_SERVER_KEYS[namespace])
]
for prefix, scope in scopes:
for k, v in scope.items():
@@ -204,6 +204,7 @@ def _iter_server_keys(*namespaces: str) -> Iterable[tuple[str, ConfigField]]:
FROM_SERVER_LOAD_LLM = dict(_iter_server_keys("load", "llm.load"))
FROM_SERVER_LOAD_EMBEDDING = dict(_iter_server_keys("load", "embedding.load"))
FROM_SERVER_PREDICTION = dict(_iter_server_keys("llm.prediction"))
FROM_SERVER_CONFIG = dict(_iter_server_keys(*SUPPORTED_SERVER_KEYS))


# Define mappings to translate client config instances to server KV configs
@@ -237,8 +238,26 @@ def dict_from_kvconfig(config: KvConfig) -> DictObject:
return {kv.key: kv.value for kv in config.fields}


def dict_from_fields_key(config: DictObject) -> DictObject:
return {kv["key"]: kv["value"] for kv in config.get("fields", [])}
def parse_server_config(server_config: DictObject) -> DictObject:
"""Map server config fields to client config fields."""
result: MutableDictObject = {}
for kv in server_config.get("fields", []):
key = kv["key"]
config_field = FROM_SERVER_CONFIG.get(key, None)
if config_field is None:
# Skip unknown keys (server might be newer than the SDK)
continue
value = kv["value"]
config_field.update_client_config(result, value)
return result


def parse_llm_load_config(server_config: DictObject) -> LlmLoadModelConfig:
return LlmLoadModelConfig._from_any_api_dict(parse_server_config(server_config))


def parse_prediction_config(server_config: DictObject) -> LlmPredictionConfig:
return LlmPredictionConfig._from_any_api_dict(parse_server_config(server_config))


def _api_override_kv_config_stack(
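The helpers above replace `dict_from_fields_key`: instead of returning the raw `{key: value}` mapping, `parse_server_config` translates the server's flat `fields` list into client config field names (silently skipping keys the SDK does not recognise), and the typed wrappers turn that into config dataclasses. A minimal sketch of the expected behaviour; the dotted server key name used here is an illustrative assumption based on the namespace mapping in `SUPPORTED_SERVER_KEYS`, not verified server output.

```python
# Illustrative sketch only: exercises parse_llm_load_config as defined above.
# The key "llm.load.contextLength" is an assumed server key name; unknown keys
# are skipped, so the call still succeeds even if the assumption is wrong.
from lmstudio._kv_config import parse_llm_load_config

server_payload = {
    "fields": [
        {"key": "llm.load.contextLength", "value": 4096},  # assumed key name
        {"key": "some.future.key", "value": True},          # unknown -> skipped
    ]
}

load_config = parse_llm_load_config(server_payload)
# Prints an LlmLoadModelConfig instance (with context_length set if the
# assumed key name matches the real server key).
print(load_config)
```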
17 changes: 11 additions & 6 deletions src/lmstudio/async_api.py
@@ -43,6 +43,7 @@
_LocalFileData,
)
from .json_api import (
AnyLoadConfig,
AnyModelSpecifier,
AvailableModelBase,
ChannelEndpoint,
@@ -93,7 +94,7 @@
_model_spec_to_api_dict,
_redact_json,
)
from ._kv_config import TLoadConfig, TLoadConfigDict, dict_from_fields_key
from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config
from ._sdk_models import (
EmbeddingRpcEmbedStringParameter,
EmbeddingRpcTokenizeParameter,
@@ -693,7 +694,7 @@ def _system_session(self) -> AsyncSessionSystem:
def _files_session(self) -> _AsyncSessionFiles:
return self._client.files

async def _get_load_config(self, model_specifier: AnyModelSpecifier) -> DictObject:
async def _get_load_config(
self, model_specifier: AnyModelSpecifier
) -> AnyLoadConfig:
"""Get the model load config for the specified model."""
# Note that the configuration reported here uses the *server* config names,
# not the attributes used to set the configuration in the client SDK
@@ -703,7 +706,8 @@ async def _get_load_config(self, model_specifier: AnyModelSpecifier) -> DictObject:
}
)
config = await self.remote_call("getLoadConfig", params)
return dict_from_fields_key(config)
result_type = self._API_TYPES.MODEL_LOAD_CONFIG
return result_type._from_any_api_dict(parse_server_config(config))

async def _get_api_model_info(self, model_specifier: AnyModelSpecifier) -> Any:
"""Get the raw model info (if any) for a model matching the given criteria."""
Expand Down Expand Up @@ -1158,7 +1162,9 @@ async def _embed(
)


class AsyncModelHandle(ModelHandleBase[TAsyncSessionModel]):
class AsyncModelHandle(
Generic[TAsyncSessionModel], ModelHandleBase[TAsyncSessionModel]
):
"""Reference to a loaded LM Studio model."""

@sdk_public_api_async()
@@ -1171,9 +1177,8 @@ async def get_info(self) -> ModelInstanceInfo:
"""Get the model info for this model."""
return await self._session.get_model_info(self.identifier)

# Private until this API can emit the client config types
@sdk_public_api_async()
async def _get_load_config(self) -> DictObject:
async def get_load_config(self) -> AnyLoadConfig:
"""Get the model load config for this model."""
return await self._session._get_load_config(self.identifier)

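With the handle method made public and the session helper returning `AnyLoadConfig`, async callers now get a typed config object rather than a server-keyed dict. A rough usage sketch, assuming a running LM Studio instance with an LLM already loaded and that the argument-free `model()` lookup behaves as it does elsewhere in the SDK:

```python
# Sketch only: assumes a local LM Studio server with at least one LLM loaded.
import asyncio

from lmstudio import AsyncClient, LlmLoadModelConfig


async def main() -> None:
    async with AsyncClient() as client:
        llm = await client.llm.model()        # any currently loaded LLM
        config = await llm.get_load_config()  # now public and typed
        assert isinstance(config, LlmLoadModelConfig)
        print(config)


asyncio.run(main())
```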
26 changes: 13 additions & 13 deletions src/lmstudio/json_api.py
@@ -59,8 +59,9 @@
from ._kv_config import (
TLoadConfig,
TLoadConfigDict,
dict_from_fields_key,
load_config_to_kv_config_stack,
parse_llm_load_config,
parse_prediction_config,
prediction_config_to_kv_config_stack,
)
from ._sdk_models import (
@@ -128,6 +129,7 @@
# implicitly as part of the top-level `lmstudio` API.
__all__ = [
"ActResult",
"AnyLoadConfig",
"AnyModelSpecifier",
"DownloadFinalizedCallback",
"DownloadProgressCallback",
@@ -180,6 +182,7 @@
DEFAULT_TTL = 60 * 60 # By default, leaves idle models loaded for an hour

AnyModelSpecifier: TypeAlias = str | ModelSpecifier | ModelQuery | DictObject
AnyLoadConfig: TypeAlias = EmbeddingLoadModelConfig | LlmLoadModelConfig

GetOrLoadChannelRequest: TypeAlias = (
EmbeddingChannelGetOrLoadCreationParameter | LlmChannelGetOrLoadCreationParameter
@@ -441,12 +444,9 @@ class PredictionResult(Generic[TPrediction]):
parsed: TPrediction # dict for structured predictions, str otherwise
stats: LlmPredictionStats # Statistics about the prediction process
model_info: LlmInfo # Information about the model used
structured: bool = field(init=False) # Whether the result is structured or not
# Note that the configuration reported here uses the *server* config names,
# not the attributes used to set the configuration in the client SDK
# Private until these attributes store the client config types
_load_config: DictObject # The configuration used to load the model
_prediction_config: DictObject # The configuration used for the prediction
structured: bool = field(init=False) # Whether the result is structured or not
load_config: LlmLoadModelConfig # The configuration used to load the model
prediction_config: LlmPredictionConfig # The configuration used for the prediction
# fmt: on

def __post_init__(self) -> None:
@@ -1262,8 +1262,8 @@ def iter_message_events(
parsed=parsed_content,
stats=LlmPredictionStats._from_any_api_dict(stats),
model_info=LlmInfo._from_any_api_dict(model_info),
_load_config=dict_from_fields_key(load_kvconfig),
_prediction_config=dict_from_fields_key(prediction_kvconfig),
load_config=parse_llm_load_config(load_kvconfig),
prediction_config=parse_prediction_config(prediction_kvconfig),
)
)
case unmatched:
@@ -1477,19 +1477,19 @@ def model_info(self) -> LlmInfo | None:

# Private until this API can emit the client config types
@property
def _load_config(self) -> DictObject | None:
def _load_config(self) -> LlmLoadModelConfig | None:
"""Get the load configuration used for the current prediction if available."""
if self._final_result is None:
return None
return self._final_result._load_config
return self._final_result.load_config

# Private until this API can emit the client config types
@property
def _prediction_config(self) -> DictObject | None:
def _prediction_config(self) -> LlmPredictionConfig | None:
"""Get the prediction configuration used for the current prediction if available."""
if self._final_result is None:
return None
return self._final_result._prediction_config
return self._final_result.prediction_config

@sdk_public_api()
def result(self) -> PredictionResult[TPrediction]:
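The net effect of the `json_api.py` changes is that `PredictionResult` exposes `load_config` and `prediction_config` as public, typed attributes, and the streaming properties return the client config types. A short sketch of how a caller might rely on this, assuming the synchronous convenience API used in the test suite (model lookup and prompt are placeholders):

```python
# Sketch only: assumes a local LM Studio server with an LLM loaded.
import lmstudio as lms
from lmstudio import LlmLoadModelConfig, LlmPredictionConfig

llm = lms.llm()                    # any currently loaded LLM
result = llm.respond("Say hello.")

# The configs are now typed dataclasses rather than server-keyed dicts.
assert isinstance(result.load_config, LlmLoadModelConfig)
assert isinstance(result.prediction_config, LlmPredictionConfig)
print(result.prediction_config)
```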
10 changes: 6 additions & 4 deletions src/lmstudio/sync_api.py
@@ -66,6 +66,7 @@
)
from .json_api import (
ActResult,
AnyLoadConfig,
AnyModelSpecifier,
AvailableModelBase,
ChannelEndpoint,
@@ -121,7 +122,7 @@
_model_spec_to_api_dict,
_redact_json,
)
from ._kv_config import TLoadConfig, TLoadConfigDict, dict_from_fields_key
from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config
from ._sdk_models import (
EmbeddingRpcEmbedStringParameter,
EmbeddingRpcTokenizeParameter,
@@ -866,7 +867,7 @@ def _system_session(self) -> SyncSessionSystem:
def _files_session(self) -> _SyncSessionFiles:
return self._client.files

def _get_load_config(self, model_specifier: AnyModelSpecifier) -> DictObject:
def _get_load_config(self, model_specifier: AnyModelSpecifier) -> AnyLoadConfig:
"""Get the model load config for the specified model."""
# Note that the configuration reported here uses the *server* config names,
# not the attributes used to set the configuration in the client SDK
@@ -876,7 +877,8 @@ def _get_load_config(self, model_specifier: AnyModelSpecifier) -> DictObject:
}
)
config = self.remote_call("getLoadConfig", params)
return dict_from_fields_key(config)
result_type = self._API_TYPES.MODEL_LOAD_CONFIG
return result_type._from_any_api_dict(parse_server_config(config))

def _get_api_model_info(self, model_specifier: AnyModelSpecifier) -> Any:
"""Get the raw model info (if any) for a model matching the given criteria."""
@@ -1339,7 +1341,7 @@ def get_info(self) -> ModelInstanceInfo:

# Private until this API can emit the client config types
@sdk_public_api()
def _get_load_config(self) -> DictObject:
def _get_load_config(self) -> AnyLoadConfig:
"""Get the model load config for this model."""
return self._session._get_load_config(self.identifier)

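Because the sync session helper also returns the `AnyLoadConfig` union, it yields `LlmLoadModelConfig` for LLMs and `EmbeddingLoadModelConfig` for embedding models, matching the updated tests below. A minimal sketch under the same assumptions as the async example; note the handle method stays underscore-private in this diff, so the example calls it the same way the tests do, and the model identifier is a placeholder:

```python
# Sketch only: assumes a local LM Studio server with an embedding model loaded.
from lmstudio import Client, EmbeddingLoadModelConfig

with Client() as client:
    # Placeholder identifier; the tests pass a model_id fixture here instead.
    config = client.embedding._get_load_config("nomic-embed-text-v1.5")
    assert isinstance(config, EmbeddingLoadModelConfig)
```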
4 changes: 2 additions & 2 deletions tests/async/test_embedding_async.py
@@ -6,7 +6,7 @@
import pytest
from pytest import LogCaptureFixture as LogCap

from lmstudio import AsyncClient, LMStudioModelNotFoundError
from lmstudio import AsyncClient, EmbeddingLoadModelConfig, LMStudioModelNotFoundError

from ..support import (
EXPECTED_EMBEDDING,
@@ -114,7 +114,7 @@ async def test_get_load_config_async(model_id: str, caplog: LogCap) -> None:
response = await client.embedding._get_load_config(model_id)
logging.info(f"Load config response: {response}")
assert response
assert isinstance(response, dict)
assert isinstance(response, EmbeddingLoadModelConfig)


@pytest.mark.asyncio
10 changes: 6 additions & 4 deletions tests/async/test_inference_async.py
@@ -15,6 +15,8 @@
Chat,
DictSchema,
LlmInfo,
LlmLoadModelConfig,
LlmPredictionConfig,
LlmPredictionConfigDict,
LlmPredictionFragment,
LlmPredictionStats,
@@ -269,12 +271,12 @@ async def test_complete_prediction_metadata_async(caplog: LogCap) -> None:
logging.info(f"LLM response: {response.content!r}")
assert response.stats
assert response.model_info
assert response._load_config
assert response._prediction_config
assert response.load_config
assert response.prediction_config
assert isinstance(response.stats, LlmPredictionStats)
assert isinstance(response.model_info, LlmInfo)
assert isinstance(response._load_config, dict)
assert isinstance(response._prediction_config, dict)
assert isinstance(response.load_config, LlmLoadModelConfig)
assert isinstance(response.prediction_config, LlmPredictionConfig)


@pytest.mark.asyncio
4 changes: 2 additions & 2 deletions tests/async/test_llm_async.py
@@ -5,7 +5,7 @@
import pytest
from pytest import LogCaptureFixture as LogCap

from lmstudio import AsyncClient, history
from lmstudio import AsyncClient, LlmLoadModelConfig, history

from ..support import EXPECTED_LLM, EXPECTED_LLM_ID

@@ -96,7 +96,7 @@ async def test_get_load_config_async(model_id: str, caplog: LogCap) -> None:
response = await client.llm._get_load_config(model_id)
logging.info(f"Load config response: {response}")
assert response
assert isinstance(response, dict)
assert isinstance(response, LlmLoadModelConfig)


@pytest.mark.asyncio
4 changes: 2 additions & 2 deletions tests/sync/test_embedding_sync.py
@@ -13,7 +13,7 @@
import pytest
from pytest import LogCaptureFixture as LogCap

from lmstudio import Client, LMStudioModelNotFoundError
from lmstudio import Client, EmbeddingLoadModelConfig, LMStudioModelNotFoundError

from ..support import (
EXPECTED_EMBEDDING,
@@ -115,7 +115,7 @@ def test_get_load_config_sync(model_id: str, caplog: LogCap) -> None:
response = client.embedding._get_load_config(model_id)
logging.info(f"Load config response: {response}")
assert response
assert isinstance(response, dict)
assert isinstance(response, EmbeddingLoadModelConfig)


@pytest.mark.lmstudio
10 changes: 6 additions & 4 deletions tests/sync/test_inference_sync.py
@@ -22,6 +22,8 @@
Chat,
DictSchema,
LlmInfo,
LlmLoadModelConfig,
LlmPredictionConfig,
LlmPredictionConfigDict,
LlmPredictionFragment,
LlmPredictionStats,
@@ -264,12 +266,12 @@ def test_complete_prediction_metadata_sync(caplog: LogCap) -> None:
logging.info(f"LLM response: {response.content!r}")
assert response.stats
assert response.model_info
assert response._load_config
assert response._prediction_config
assert response.load_config
assert response.prediction_config
assert isinstance(response.stats, LlmPredictionStats)
assert isinstance(response.model_info, LlmInfo)
assert isinstance(response._load_config, dict)
assert isinstance(response._prediction_config, dict)
assert isinstance(response.load_config, LlmLoadModelConfig)
assert isinstance(response.prediction_config, LlmPredictionConfig)


@pytest.mark.lmstudio
4 changes: 2 additions & 2 deletions tests/sync/test_llm_sync.py
@@ -12,7 +12,7 @@
import pytest
from pytest import LogCaptureFixture as LogCap

from lmstudio import Client, history
from lmstudio import Client, LlmLoadModelConfig, history

from ..support import EXPECTED_LLM, EXPECTED_LLM_ID

@@ -96,7 +96,7 @@ def test_get_load_config_sync(model_id: str, caplog: LogCap) -> None:
response = client.llm._get_load_config(model_id)
logging.info(f"Load config response: {response}")
assert response
assert isinstance(response, dict)
assert isinstance(response, LlmLoadModelConfig)


@pytest.mark.lmstudio
6 changes: 4 additions & 2 deletions tests/test_history.py
@@ -24,6 +24,8 @@
)
from lmstudio.json_api import (
LlmInfo,
LlmLoadModelConfig,
LlmPredictionConfig,
LlmPredictionStats,
PredictionResult,
TPrediction,
@@ -347,8 +349,8 @@ def _make_prediction_result(data: TPrediction) -> PredictionResult[TPrediction]:
trained_for_tool_use=False,
max_context_length=32,
),
_load_config={},
_prediction_config={},
load_config=LlmLoadModelConfig(),
prediction_config=LlmPredictionConfig(),
)

