6
6
from ax .core .objective import MultiObjective
7
7
8
8
from optimas .core import Objective , VaryingParameter , Parameter
9
+ from generator_standard .vocs import VOCS
9
10
from .base import AxServiceGenerator
10
11
11
12
@@ -71,18 +72,19 @@ def __init__(
71
72
model_save_period : Optional [int ] = 5 ,
72
73
model_history_dir : Optional [str ] = "model_history" ,
73
74
):
74
- varying_parameters = self ._get_varying_parameters (ax_client )
75
- objectives = self ._get_objectives (ax_client )
75
+ # Create VOCS object from AxClient data
76
+ vocs = self ._create_vocs_from_ax_client (ax_client )
77
+
78
+ # Add constraints to analyzed parameters
76
79
analyzed_parameters = self ._add_constraints_to_analyzed_parameters (
77
80
analyzed_parameters , ax_client
78
81
)
82
+
79
83
use_cuda = self ._use_cuda (ax_client )
80
84
self ._ax_client = ax_client
85
+
81
86
super ().__init__ (
82
- varying_parameters = varying_parameters ,
83
- objectives = objectives ,
84
- analyzed_parameters = analyzed_parameters ,
85
- enforce_n_init = True ,
87
+ vocs = vocs ,
86
88
abandon_failed_trials = abandon_failed_trials ,
87
89
use_cuda = use_cuda ,
88
90
gpu_id = gpu_id ,
@@ -92,6 +94,38 @@ def __init__(
92
94
model_history_dir = model_history_dir ,
93
95
)
94
96
97
+ def _create_vocs_from_ax_client (self , ax_client : AxClient ) -> VOCS :
98
+ """Create a VOCS object from the AxClient data."""
99
+ # Extract variables from search space
100
+ variables = {}
101
+ for _ , p in ax_client .experiment .search_space .parameters .items ():
102
+ variables [p .name ] = [p .lower , p .upper ]
103
+
104
+ # Extract objectives from optimization config
105
+ objectives = {}
106
+ ax_objective = ax_client .experiment .optimization_config .objective
107
+ if isinstance (ax_objective , MultiObjective ):
108
+ ax_objectives = ax_objective .objectives
109
+ else :
110
+ ax_objectives = [ax_objective ]
111
+
112
+ for ax_obj in ax_objectives :
113
+ obj_type = "MINIMIZE" if ax_obj .minimize else "MAXIMIZE"
114
+ objectives [ax_obj .metric_names [0 ]] = obj_type
115
+
116
+ # Extract observables from outcome constraints (if any)
117
+ observables = set ()
118
+ ax_config = ax_client .experiment .optimization_config
119
+ if ax_config .outcome_constraints :
120
+ for constraint in ax_config .outcome_constraints :
121
+ observables .add (constraint .metric .name )
122
+
123
+ return VOCS (
124
+ variables = variables ,
125
+ objectives = objectives ,
126
+ observables = observables ,
127
+ )
128
+
95
129
def _get_varying_parameters (self , ax_client : AxClient ):
96
130
"""Obtain the list of varying parameters from the AxClient."""
97
131
varying_parameters = []
0 commit comments