|
59 | 59 | from ._kv_config import ( |
60 | 60 | TLoadConfig, |
61 | 61 | TLoadConfigDict, |
62 | | - dict_from_fields_key, |
63 | 62 | load_config_to_kv_config_stack, |
| 63 | + parse_llm_load_config, |
| 64 | + parse_prediction_config, |
64 | 65 | prediction_config_to_kv_config_stack, |
65 | 66 | ) |
66 | 67 | from ._sdk_models import ( |
|
128 | 129 | # implicitly as part of the top-level `lmstudio` API. |
129 | 130 | __all__ = [ |
130 | 131 | "ActResult", |
| 132 | + "AnyLoadConfig", |
131 | 133 | "AnyModelSpecifier", |
132 | 134 | "DownloadFinalizedCallback", |
133 | 135 | "DownloadProgressCallback", |
|
180 | 182 | DEFAULT_TTL = 60 * 60 # By default, leaves idle models loaded for an hour |
181 | 183 |
|
182 | 184 | AnyModelSpecifier: TypeAlias = str | ModelSpecifier | ModelQuery | DictObject |
| 185 | +AnyLoadConfig: TypeAlias = EmbeddingLoadModelConfig | LlmLoadModelConfig |
183 | 186 |
|
184 | 187 | GetOrLoadChannelRequest: TypeAlias = ( |
185 | 188 | EmbeddingChannelGetOrLoadCreationParameter | LlmChannelGetOrLoadCreationParameter |
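
The new `AnyLoadConfig` alias gives a single annotation covering both load-config types. A minimal sketch of how downstream code might use it, assuming both names are re-exported from the top-level `lmstudio` package as the `__all__` update suggests; `describe_load_config` is a hypothetical helper, not part of this diff:

```python
from lmstudio import AnyLoadConfig, LlmLoadModelConfig

# Hypothetical helper: the union alias lets one annotation cover
# both LLM and embedding load configurations.
def describe_load_config(config: AnyLoadConfig) -> str:
    if isinstance(config, LlmLoadModelConfig):
        return "LLM load config"
    return "embedding load config"
```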
@@ -441,12 +444,9 @@ class PredictionResult(Generic[TPrediction]): |
441 | 444 | parsed: TPrediction # dict for structured predictions, str otherwise |
442 | 445 | stats: LlmPredictionStats # Statistics about the prediction process |
443 | 446 | model_info: LlmInfo # Information about the model used |
444 | | - structured: bool = field(init=False) # Whether the result is structured or not |
445 | | - # Note that the configuration reported here uses the *server* config names, |
446 | | - # not the attributes used to set the configuration in the client SDK |
447 | | - # Private until these attributes store the client config types |
448 | | - _load_config: DictObject # The configuration used to load the model |
449 | | - _prediction_config: DictObject # The configuration used for the prediction |
| 447 | + structured: bool = field(init=False) # Whether the result is structured or not |
| 448 | + load_config: LlmLoadModelConfig # The configuration used to load the model |
| 449 | + prediction_config: LlmPredictionConfig # The configuration used for the prediction |
450 | 450 | # fmt: on |
451 | 451 |
|
452 | 452 | def __post_init__(self) -> None: |
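
With `load_config` and `prediction_config` now public, typed fields rather than private server-format dicts, callers can read them directly off a result. A sketch of the resulting usage, assuming the SDK's convenience entry points (`lms.llm`, `respond`); the model identifier is illustrative:

```python
import lmstudio as lms

model = lms.llm("some-model-key")  # hypothetical model identifier
result = model.respond("Hello!")

# Typed client dataclasses instead of raw server-format dicts:
print(type(result.prediction_config).__name__)  # LlmPredictionConfig
print(type(result.load_config).__name__)        # LlmLoadModelConfig
```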
@@ -1262,8 +1262,8 @@ def iter_message_events( |
1262 | 1262 | parsed=parsed_content, |
1263 | 1263 | stats=LlmPredictionStats._from_any_api_dict(stats), |
1264 | 1264 | model_info=LlmInfo._from_any_api_dict(model_info), |
1265 | | - _load_config=dict_from_fields_key(load_kvconfig), |
1266 | | - _prediction_config=dict_from_fields_key(prediction_kvconfig), |
| 1265 | + load_config=parse_llm_load_config(load_kvconfig), |
| 1266 | + prediction_config=parse_prediction_config(prediction_kvconfig), |
1267 | 1267 | ) |
1268 | 1268 | ) |
1269 | 1269 | case unmatched: |
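
The replaced `dict_from_fields_key` returned the raw server-format dict; the new `parse_llm_load_config` / `parse_prediction_config` helpers in `._kv_config` instead produce typed client configs. A conceptual sketch of that kind of translation, using a stand-in dataclass and assumed key names (the real helpers handle the full server schema):

```python
from dataclasses import dataclass
from typing import Any

# Hypothetical stand-in for the real client config type; the actual
# parse_llm_load_config in ._kv_config targets LlmLoadModelConfig.
@dataclass
class SketchLoadConfig:
    context_length: int | None = None
    gpu_offload_ratio: float | None = None

def sketch_parse_load_config(kv_config: dict[str, Any]) -> SketchLoadConfig:
    # Flatten the {"fields": [{"key": ..., "value": ...}]} payload.
    raw = {f["key"]: f["value"] for f in kv_config.get("fields", [])}
    # Map server-side keys to snake_case client attributes
    # (the key names here are illustrative assumptions).
    return SketchLoadConfig(
        context_length=raw.get("contextLength"),
        gpu_offload_ratio=raw.get("llama.acceleration.offloadRatio"),
    )
```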
@@ -1477,19 +1477,19 @@ def model_info(self) -> LlmInfo | None: |
1477 | 1477 |
|
1478 | 1478 | # Private until this API can emit the client config types |
1479 | 1479 | @property |
1480 | | - def _load_config(self) -> DictObject | None: |
| 1480 | + def _load_config(self) -> LlmLoadModelConfig | None: |
1481 | 1481 | """Get the load configuration used for the current prediction if available.""" |
1482 | 1482 | if self._final_result is None: |
1483 | 1483 | return None |
1484 | | - return self._final_result._load_config |
| 1484 | + return self._final_result.load_config |
1485 | 1485 |
|
1486 | 1486 | # Private until this API can emit the client config types |
1487 | 1487 | @property |
1488 | | - def _prediction_config(self) -> DictObject | None: |
| 1488 | + def _prediction_config(self) -> LlmPredictionConfig | None: |
1489 | 1489 | """Get the prediction configuration used for the current prediction if available.""" |
1490 | 1490 | if self._final_result is None: |
1491 | 1491 | return None |
1492 | | - return self._final_result._prediction_config |
| 1492 | + return self._final_result.prediction_config |
1493 | 1493 |
|
1494 | 1494 | @sdk_public_api() |
1495 | 1495 | def result(self) -> PredictionResult[TPrediction]: |
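
These properties keep their leading underscore (per the surrounding comments they remain private), but now surface the typed client configs once the final result is available. A rough sketch of the None-until-complete behavior, with `respond_stream` assumed as the streaming entry point and the model identifier illustrative:

```python
import lmstudio as lms

model = lms.llm("some-model-key")        # hypothetical model identifier
stream = model.respond_stream("Hello!")  # assumed streaming entry point
print(stream._prediction_config)         # None: no final result yet
for _fragment in stream:
    pass                                 # consume the stream to completion
print(stream._prediction_config)         # now an LlmPredictionConfig
print(stream._load_config)               # now an LlmLoadModelConfig
```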
|