Skip to content

Commit fe2089f

Browse files
ItsMrLinfacebook-github-bot
authored andcommitted
Handle LLMMessage serialization at the storage layer (#5024)
Summary: Move `LLMMessage` dict conversion from the `experiment.llm_messages` getter/setter to the storage encoders/decoders, following Ax convention that domain objects hold domain types and serialization happens at the storage boundary. **`experiment.py`**: The setter now stores `LLMMessage` objects directly in `_properties`. The getter handles both `LLMMessage` objects (new path) and plain dicts (backward compat with previously stored data). **JSON store**: No explicit changes needed — the encoder's generic dataclass fallback auto-serializes `LLMMessage` with a `__type` tag, and `LLMMessage` is already registered in `CORE_DECODER_REGISTRY`. **SQA store**: The encoder converts `LLMMessage` → dict via `dataclasses.asdict()` in the properties copy before DB write (same pattern as `pruning_target_parameterization`). The decoder converts dicts → `LLMMessage` after loading properties, in both `_init_experiment_from_sqa` and `_init_mt_experiment_from_sqa`. Reviewed By: lena-kashtelyan Differential Revision: D96434290
1 parent a6de70a commit fe2089f

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

ax/storage/sqa_store/decoder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ax.core.data import Data
3232
from ax.core.experiment import Experiment
3333
from ax.core.generator_run import GeneratorRun
34+
from ax.core.llm_provider import LLMMessage
3435
from ax.core.metric import Metric
3536
from ax.core.multi_type_experiment import MultiTypeExperiment
3637
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
@@ -220,6 +221,10 @@ def _init_experiment_from_sqa(
220221
# `experiment_sqa.properties` is `sqlalchemy.ext.mutable.MutableDict`
221222
# so need to convert it to regular dict.
222223
properties = dict(experiment_sqa.properties or {})
224+
if Keys.LLM_MESSAGES in properties:
225+
properties[Keys.LLM_MESSAGES] = [
226+
LLMMessage(**m) for m in properties[Keys.LLM_MESSAGES]
227+
]
223228
opt_config, tracking_metrics = self.opt_config_and_tracking_metrics_from_sqa(
224229
metrics_sqa=experiment_sqa.metrics,
225230
pruning_target_parameterization=(
@@ -281,6 +286,10 @@ def _init_mt_experiment_from_sqa(
281286
) -> MultiTypeExperiment:
282287
"""First step of conversion within experiment_from_sqa."""
283288
properties = dict(experiment_sqa.properties or {})
289+
if Keys.LLM_MESSAGES in properties:
290+
properties[Keys.LLM_MESSAGES] = [
291+
LLMMessage(**m) for m in properties[Keys.LLM_MESSAGES]
292+
]
284293
opt_config, tracking_metrics = self.opt_config_and_tracking_metrics_from_sqa(
285294
metrics_sqa=experiment_sqa.metrics,
286295
pruning_target_parameterization=(

ax/storage/sqa_store/encoder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
import dataclasses
910
from enum import Enum
1011
from logging import Logger
1112
from typing import Any, cast
@@ -28,6 +29,7 @@
2829
from ax.core.evaluations_to_data import DataType
2930
from ax.core.experiment import Experiment
3031
from ax.core.generator_run import GeneratorRun
32+
from ax.core.llm_provider import LLMMessage
3133
from ax.core.metric import Metric
3234
from ax.core.multi_type_experiment import MultiTypeExperiment
3335
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
@@ -234,6 +236,11 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
234236
properties["pruning_target_parameterization"] = arm_to_dict(
235237
oc.pruning_target_parameterization
236238
)
239+
if Keys.LLM_MESSAGES in properties:
240+
properties[Keys.LLM_MESSAGES] = [
241+
dataclasses.asdict(m) if isinstance(m, LLMMessage) else m
242+
for m in properties[Keys.LLM_MESSAGES]
243+
]
237244

238245
# pyre-ignore[9]: Expected `Base` for 1st...yping.Type[Experiment]`.
239246
experiment_class: type[SQAExperiment] = self.config.class_to_sqa_class[

0 commit comments

Comments
 (0)