Skip to content

Commit 0d9ef67

Browse files
authored
Fix GPU offload split configuration (#89)
* Client API and server API describe strategies differently (`favorMainGpu` on the client is a specific case of `priorityOrder` on the LM Studio server).
* Multi-part config fields may only be partially populated.
* Also add test cases for the reverse server-to-client config mapping (this picked up an error in handling checkbox fields).

Closes #88
1 parent 6be1164 commit 0d9ef67

File tree

2 files changed

+162
-21
lines changed

2 files changed

+162
-21
lines changed

src/lmstudio/_kv_config.py

Lines changed: 89 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,32 @@
33
# Known KV config settings are defined in
44
# https://github.com/lmstudio-ai/lmstudio-js/blob/main/packages/lms-kv-config/src/schema.ts
55
from dataclasses import dataclass
6-
from typing import Any, Container, Iterable, Sequence, Type, TypeAlias, TypeVar, cast
6+
from typing import (
7+
TYPE_CHECKING,
8+
Any,
9+
Callable,
10+
Container,
11+
Iterable,
12+
Sequence,
13+
Type,
14+
TypeAlias,
15+
TypeVar,
16+
cast,
17+
get_args,
18+
)
19+
from typing_extensions import (
20+
# Native in 3.11+
21+
assert_never,
22+
)
723

824
from .sdk_api import LMStudioValueError
925
from .schemas import DictObject, DictSchema, ModelSchema, MutableDictObject
1026
from ._sdk_models import (
1127
EmbeddingLoadModelConfig,
1228
EmbeddingLoadModelConfigDict,
29+
GpuSettingDict,
30+
GpuSplitConfig,
31+
GpuSplitConfigDict,
1332
KvConfig,
1433
KvConfigFieldDict,
1534
KvConfigStack,
@@ -18,6 +37,7 @@
1837
LlmLoadModelConfigDict,
1938
LlmPredictionConfig,
2039
LlmPredictionConfigDict,
40+
LlmSplitStrategy,
2141
LlmStructuredPredictionSetting,
2242
LlmStructuredPredictionSettingDict,
2343
)
@@ -54,7 +74,7 @@ def to_kv_field(
5474
def update_client_config(
5575
self, client_config: MutableDictObject, value: DictObject
5676
) -> None:
57-
if value.get("key", False):
77+
if value.get("checked", False):
5878
client_config[self.client_key] = value["value"]
5979

6080

@@ -84,26 +104,24 @@ def update_client_config(
84104
@dataclass(frozen=True)
85105
class MultiPartField(ConfigField):
86106
nested_keys: tuple[str, ...]
107+
client_to_server: Callable[..., Any]
108+
server_to_client: Callable[[DictObject, MutableDictObject], None]
87109

88110
def to_kv_field(
89111
self, server_key: str, client_config: DictObject
90112
) -> KvConfigFieldDict | None:
91-
containing_value = client_config[self.client_key]
92-
value: dict[str, Any] = {}
93-
for key in self.nested_keys:
94-
value[key] = containing_value[key]
113+
client_container: DictObject = client_config[self.client_key]
114+
values = (client_container.get(key, None) for key in self.nested_keys)
95115
return {
96116
"key": server_key,
97-
"value": value,
117+
"value": self.client_to_server(*values),
98118
}
99119

100120
def update_client_config(
101-
self, client_config: MutableDictObject, value: DictObject
121+
self, client_config: MutableDictObject, server_value: DictObject
102122
) -> None:
103-
containing_value = client_config.setdefault(self.client_key, {})
104-
for key in self.nested_keys:
105-
if key in value:
106-
containing_value[key] = value[key]
123+
client_container: MutableDictObject = client_config.setdefault(self.client_key, {})
124+
self.server_to_client(server_value, client_container)
107125

108126

109127
# TODO: figure out a way to compare this module against the lmstudio-js mappings
@@ -125,10 +143,68 @@ def update_client_config(
125143
"contextLength": ConfigField("contextLength"),
126144
}
127145

146+
147+
def _gpu_settings_to_gpu_split_config(
148+
main_gpu: int | None,
149+
llm_split_strategy: LlmSplitStrategy | None,
150+
disabledGpus: Sequence[int] | None,
151+
) -> GpuSplitConfigDict:
152+
gpu_split_config: GpuSplitConfigDict = {
153+
"disabledGpus": [*disabledGpus] if disabledGpus else [],
154+
"strategy": "evenly",
155+
"priority": [],
156+
"customRatio": [],
157+
}
158+
match llm_split_strategy:
159+
case "evenly" | None:
160+
pass
161+
case "favorMainGpu":
162+
gpu_split_config["strategy"] = "priorityOrder"
163+
if main_gpu is not None:
164+
gpu_split_config["priority"] = [main_gpu]
165+
case _:
166+
if TYPE_CHECKING:
167+
assert_never(llm_split_strategy)
168+
err_msg = f"Unknown LLM GPU offload split strategy: {llm_split_strategy}"
169+
hint = f"Known strategies: {get_args(LlmSplitStrategy)}"
170+
raise LMStudioValueError(f"{err_msg} ({hint})")
171+
return gpu_split_config
172+
173+
174+
def _gpu_split_config_to_gpu_settings(
    server_dict: DictObject, client_dict: MutableDictObject
) -> None:
    """Map a server GPU split config back onto the client GPU settings dict."""
    settings = cast(GpuSettingDict, client_dict)
    split_config = GpuSplitConfig._from_any_api_dict(server_dict)
    if split_config.disabled_gpus is not None:
        settings["disabledGpus"] = split_config.disabled_gpus
    strategy = split_config.strategy
    if strategy == "evenly":
        settings["splitStrategy"] = "evenly"
    elif strategy == "priorityOrder":
        # Only a single-entry priority list maps back to "favorMainGpu";
        # other priority orders have no client representation, so the
        # GPU offload details are simply not reported for them
        priority = split_config.priority
        if priority is not None and len(priority) == 1:
            settings["splitStrategy"] = "favorMainGpu"
            settings["mainGpu"] = priority[0]
    elif strategy == "custom":
        # Currently no way to set up or report custom offload settings
        pass
    elif TYPE_CHECKING:
        assert_never(strategy)
    # Details for unknown server strategies are deliberately not reported
199+
200+
128201
SUPPORTED_SERVER_KEYS: dict[str, DictObject] = {
129202
"load": {
130203
"gpuSplitConfig": MultiPartField(
131-
"gpu", ("mainGpu", "splitStrategy", "disabledGpus")
204+
"gpu",
205+
("mainGpu", "splitStrategy", "disabledGpus"),
206+
_gpu_settings_to_gpu_split_config,
207+
_gpu_split_config_to_gpu_settings,
132208
),
133209
"gpuStrictVramCap": ConfigField("gpuStrictVramCap"),
134210
},

tests/test_kv_config.py

Lines changed: 73 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Test translation from flat dict configs to KvConfig layer stacks."""
22

3-
from typing import Any
3+
from copy import deepcopy
4+
from typing import Any, Iterator, cast, get_args
45

56
import msgspec
67

@@ -14,17 +15,21 @@
1415
TO_SERVER_LOAD_LLM,
1516
TO_SERVER_PREDICTION,
1617
load_config_to_kv_config_stack,
18+
parse_server_config,
1719
prediction_config_to_kv_config_stack,
1820
)
1921
from lmstudio._sdk_models import (
2022
EmbeddingLoadModelConfig,
2123
EmbeddingLoadModelConfigDict,
2224
GpuSetting,
2325
GpuSettingDict,
26+
GpuSplitConfigDict,
27+
KvConfigStackDict,
2428
LlmLoadModelConfig,
2529
LlmLoadModelConfigDict,
2630
LlmPredictionConfig,
2731
LlmPredictionConfigDict,
32+
LlmSplitStrategy,
2833
)
2934

3035
# Note: configurations below are just for data manipulation round-trip testing,
@@ -262,7 +267,7 @@ def test_kv_stack_field_coverage(
262267
assert not unknown_keys
263268

264269

265-
EXPECTED_KV_STACK_LOAD_EMBEDDING = {
270+
EXPECTED_KV_STACK_LOAD_EMBEDDING: KvConfigStackDict = {
266271
"layers": [
267272
{
268273
"config": {
@@ -275,9 +280,10 @@ def test_kv_stack_field_coverage(
275280
{
276281
"key": "load.gpuSplitConfig",
277282
"value": {
278-
"mainGpu": 0,
279-
"splitStrategy": "evenly",
280283
"disabledGpus": [1, 2],
284+
"strategy": "evenly",
285+
"priority": [],
286+
"customRatio": [],
281287
},
282288
},
283289
{"key": "embedding.load.llama.keepModelInMemory", "value": True},
@@ -297,7 +303,7 @@ def test_kv_stack_field_coverage(
297303
],
298304
}
299305

300-
EXPECTED_KV_STACK_LOAD_LLM = {
306+
EXPECTED_KV_STACK_LOAD_LLM: KvConfigStackDict = {
301307
"layers": [
302308
{
303309
"layerName": "apiOverride",
@@ -308,9 +314,10 @@ def test_kv_stack_field_coverage(
308314
{
309315
"key": "load.gpuSplitConfig",
310316
"value": {
311-
"mainGpu": 0,
312-
"splitStrategy": "evenly",
313317
"disabledGpus": [1, 2],
318+
"strategy": "evenly",
319+
"priority": [],
320+
"customRatio": [],
314321
},
315322
},
316323
{"key": "llm.load.llama.evalBatchSize", "value": 42},
@@ -343,7 +350,7 @@ def test_kv_stack_field_coverage(
343350
]
344351
}
345352

346-
EXPECTED_KV_STACK_PREDICTION = {
353+
EXPECTED_KV_STACK_PREDICTION: KvConfigStackDict = {
347354
"layers": [
348355
{
349356
"config": {
@@ -438,6 +445,64 @@ def test_kv_stack_load_config_llm(config_dict: DictObject) -> None:
438445
assert kv_stack.to_dict() == EXPECTED_KV_STACK_LOAD_LLM
439446

440447

448+
def test_parse_server_config_load_embedding() -> None:
449+
server_config = EXPECTED_KV_STACK_LOAD_EMBEDDING["layers"][0]["config"]
450+
expected_client_config = deepcopy(LOAD_CONFIG_EMBEDDING)
451+
gpu_settings_dict = expected_client_config["gpu"]
452+
assert gpu_settings_dict is not None
453+
del gpu_settings_dict["mainGpu"] # This is not reported with "evenly" strategy
454+
assert parse_server_config(server_config) == expected_client_config
455+
456+
457+
def test_parse_server_config_load_llm() -> None:
    """Server LLM load config maps back to the client config format."""
    server_config = EXPECTED_KV_STACK_LOAD_LLM["layers"][0]["config"]
    expected = deepcopy(LOAD_CONFIG_LLM)
    gpu_settings = expected["gpu"]
    assert gpu_settings is not None
    # "mainGpu" is only reported for the "favorMainGpu" strategy
    del gpu_settings["mainGpu"]
    assert parse_server_config(server_config) == expected
464+
465+
466+
def _other_gpu_split_strategies() -> Iterator[LlmSplitStrategy]:
    # The GPU split mappings aren't simple structural transforms, so every
    # strategy needs explicit coverage; yield all except the default one
    # already exercised by the standard test cases
    default_strategy = GPU_CONFIG["splitStrategy"]
    return (s for s in get_args(LlmSplitStrategy) if s != default_strategy)
473+
474+
475+
def _find_config_field(stack_dict: KvConfigStackDict, key: str) -> Any:
    """Return the value of *key* in the stack's first config layer."""
    fields = stack_dict["layers"][0]["config"]["fields"]
    matches = (entry["value"] for entry in fields if entry["key"] == key)
    try:
        return next(matches)
    except StopIteration:
        raise KeyError(key) from None
480+
481+
482+
@pytest.mark.parametrize("split_strategy", _other_gpu_split_strategies())
def test_other_gpu_split_strategy_config(split_strategy: LlmSplitStrategy) -> None:
    """Non-default GPU split strategies map to the expected server config."""
    expected_stack = deepcopy(EXPECTED_KV_STACK_LOAD_LLM)
    if split_strategy == "favorMainGpu":
        # "favorMainGpu" maps to "priorityOrder" with the main GPU as sole entry
        expected_split_config: GpuSplitConfigDict = _find_config_field(
            expected_stack, "load.gpuSplitConfig"
        )
        expected_split_config["strategy"] = "priorityOrder"
        main_gpu = GPU_CONFIG["mainGpu"]
        assert main_gpu is not None
        expected_split_config["priority"] = [main_gpu]
    else:
        # Fail loudly if a new strategy is added without a mapping check here
        assert split_strategy is None, "Unknown LLM GPU offload split strategy"
    # Check both the camelCase and snake_case client config spellings
    input_camelCase = deepcopy(LOAD_CONFIG_LLM)
    input_snake_case = deepcopy(SC_LOAD_CONFIG_LLM)
    gpu_camelCase: GpuSettingDict = cast(Any, input_camelCase["gpu"])
    gpu_snake_case: dict[str, Any] = cast(Any, input_snake_case["gpu"])
    gpu_camelCase["splitStrategy"] = gpu_snake_case["split_strategy"] = split_strategy
    kv_stack = load_config_to_kv_config_stack(input_camelCase, LlmLoadModelConfig)
    assert kv_stack.to_dict() == expected_stack
    kv_stack = load_config_to_kv_config_stack(input_snake_case, LlmLoadModelConfig)
    assert kv_stack.to_dict() == expected_stack
504+
505+
441506
@pytest.mark.parametrize("config_dict", (PREDICTION_CONFIG, SC_PREDICTION_CONFIG))
442507
def test_kv_stack_prediction_config(config_dict: DictObject) -> None:
443508
# MyPy complains here that it can't be sure the dict has all the right keys

0 commit comments

Comments
 (0)