Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.llm_provider import LLMMessage
from ax.core.metric import Metric
from ax.core.multi_type_experiment import MultiTypeExperiment
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
Expand Down Expand Up @@ -222,6 +223,10 @@ def _init_experiment_from_sqa(
# `experiment_sqa.properties` is `sqlalchemy.ext.mutable.MutableDict`
# so need to convert it to regular dict.
properties = dict(experiment_sqa.properties or {})
if Keys.LLM_MESSAGES in properties:
properties[Keys.LLM_MESSAGES] = [
LLMMessage(**m) for m in properties[Keys.LLM_MESSAGES]
]
pruning_target = (
self._get_pruning_target_parameterization_from_experiment_properties(
properties=properties
Expand Down Expand Up @@ -286,6 +291,10 @@ def _init_mt_experiment_from_sqa(
) -> MultiTypeExperiment:
"""First step of conversion within experiment_from_sqa."""
properties = dict(experiment_sqa.properties or {})
if Keys.LLM_MESSAGES in properties:
properties[Keys.LLM_MESSAGES] = [
LLMMessage(**m) for m in properties[Keys.LLM_MESSAGES]
]
pruning_target = (
self._get_pruning_target_parameterization_from_experiment_properties(
properties=properties
Expand Down
7 changes: 7 additions & 0 deletions ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-strict

import dataclasses
from enum import Enum
from logging import Logger
from typing import Any, cast
Expand All @@ -29,6 +30,7 @@
from ax.core.evaluations_to_data import DataType
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.llm_provider import LLMMessage
from ax.core.metric import Metric
from ax.core.multi_type_experiment import MultiTypeExperiment
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
Expand Down Expand Up @@ -239,6 +241,11 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
properties["pruning_target_parameterization"] = arm_to_dict(
oc.pruning_target_parameterization
)
if Keys.LLM_MESSAGES in properties:
properties[Keys.LLM_MESSAGES] = [
dataclasses.asdict(m) if isinstance(m, LLMMessage) else m
for m in properties[Keys.LLM_MESSAGES]
]

# pyre-ignore[9]: Expected `Base` for 1st...yping.Type[Experiment]`.
experiment_class: type[SQAExperiment] = self.config.class_to_sqa_class[
Expand Down
Loading