Skip to content

Commit 6f0933f

Browse files
committed
Create vocs in ax_client
1 parent bbc23bc commit 6f0933f

File tree

1 file changed

+40
-6
lines changed

1 file changed

+40
-6
lines changed

optimas/generators/ax/service/ax_client.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ax.core.objective import MultiObjective
77

88
from optimas.core import Objective, VaryingParameter, Parameter
9+
from generator_standard.vocs import VOCS
910
from .base import AxServiceGenerator
1011

1112

@@ -71,18 +72,19 @@ def __init__(
7172
model_save_period: Optional[int] = 5,
7273
model_history_dir: Optional[str] = "model_history",
7374
):
74-
varying_parameters = self._get_varying_parameters(ax_client)
75-
objectives = self._get_objectives(ax_client)
75+
# Create VOCS object from AxClient data
76+
vocs = self._create_vocs_from_ax_client(ax_client)
77+
78+
# Add constraints to analyzed parameters
7679
analyzed_parameters = self._add_constraints_to_analyzed_parameters(
7780
analyzed_parameters, ax_client
7881
)
82+
7983
use_cuda = self._use_cuda(ax_client)
8084
self._ax_client = ax_client
85+
8186
super().__init__(
82-
varying_parameters=varying_parameters,
83-
objectives=objectives,
84-
analyzed_parameters=analyzed_parameters,
85-
enforce_n_init=True,
87+
vocs=vocs,
8688
abandon_failed_trials=abandon_failed_trials,
8789
use_cuda=use_cuda,
8890
gpu_id=gpu_id,
@@ -92,6 +94,38 @@ def __init__(
9294
model_history_dir=model_history_dir,
9395
)
9496

97+
def _create_vocs_from_ax_client(self, ax_client: AxClient) -> VOCS:
98+
"""Create a VOCS object from the AxClient data."""
99+
# Extract variables from search space
100+
variables = {}
101+
for _, p in ax_client.experiment.search_space.parameters.items():
102+
variables[p.name] = [p.lower, p.upper]
103+
104+
# Extract objectives from optimization config
105+
objectives = {}
106+
ax_objective = ax_client.experiment.optimization_config.objective
107+
if isinstance(ax_objective, MultiObjective):
108+
ax_objectives = ax_objective.objectives
109+
else:
110+
ax_objectives = [ax_objective]
111+
112+
for ax_obj in ax_objectives:
113+
obj_type = "MINIMIZE" if ax_obj.minimize else "MAXIMIZE"
114+
objectives[ax_obj.metric_names[0]] = obj_type
115+
116+
# Extract observables from outcome constraints (if any)
117+
observables = set()
118+
ax_config = ax_client.experiment.optimization_config
119+
if ax_config.outcome_constraints:
120+
for constraint in ax_config.outcome_constraints:
121+
observables.add(constraint.metric.name)
122+
123+
return VOCS(
124+
variables=variables,
125+
objectives=objectives,
126+
observables=observables,
127+
)
128+
95129
def _get_varying_parameters(self, ax_client: AxClient):
96130
"""Obtain the list of varying parameters from the AxClient."""
97131
varying_parameters = []

0 commit comments

Comments
 (0)