Skip to content

Commit b057090

Browse files
authored
Merge pull request #255 from optimas-org/enable_new_ax
Enable new ax
2 parents 7388649 + 9cd1fa0 commit b057090

File tree

11 files changed

+156
-50
lines changed

11 files changed

+156
-50
lines changed

.github/workflows/unix-openmpi.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
- shell: bash -l {0}
2929
name: Install dependencies
3030
run: |
31-
conda install numpy=1 pandas pytorch cpuonly -c pytorch
31+
conda install numpy pandas pytorch cpuonly -c pytorch
3232
conda install -c conda-forge mpi4py openmpi=5.*
3333
pip install .[test]
3434
- shell: bash -l {0}

.github/workflows/unix.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
- shell: bash -l {0}
2929
name: Install dependencies
3030
run: |
31-
conda install numpy=1 pandas pytorch cpuonly -c pytorch
31+
conda install numpy pandas pytorch cpuonly -c pytorch
3232
conda install -c conda-forge mpi4py mpich
3333
pip install .[test]
3434
- shell: bash -l {0}

doc/environment.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ dependencies:
55
- pip
66
- pip:
77
- -e ..
8-
- ax-platform == 0.4.0
8+
- ax-platform >= 0.5.0
99
- autodoc_pydantic >= 2.0.1
1010
- ipykernel
1111
- matplotlib

optimas/generators/ax/developer/multitask.py

Lines changed: 106 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
import os
44
from copy import deepcopy
55
from typing import List, Dict, Tuple, Optional, Union
6+
from pyre_extensions import assert_is_instance
67

78
import numpy as np
89
import torch
9-
from packaging import version
1010

11-
from ax.version import version as ax_version
1211
from ax.core.arm import Arm
1312
from ax.core.batch_trial import BatchTrial
1413
from ax.core.multi_type_experiment import MultiTypeExperiment
@@ -22,8 +21,40 @@
2221
from ax.core.observation import ObservationFeatures
2322
from ax.core.generator_run import GeneratorRun
2423
from ax.storage.json_store.save import save_experiment
25-
from ax.storage.metric_registry import register_metric
26-
from ax.modelbridge.factory import get_MTGP_LEGACY as get_MTGP
24+
from ax.storage.metric_registry import register_metrics
25+
26+
from ax.modelbridge.registry import Models, ST_MTGP_trans
27+
28+
try:
29+
# For Ax >= 0.5.0
30+
from ax.modelbridge.transforms.derelativize import Derelativize
31+
from ax.modelbridge.transforms.convert_metric_names import (
32+
ConvertMetricNames,
33+
)
34+
from ax.modelbridge.transforms.trial_as_task import TrialAsTask
35+
from ax.modelbridge.transforms.stratified_standardize_y import (
36+
StratifiedStandardizeY,
37+
)
38+
from ax.modelbridge.transforms.task_encode import TaskChoiceToIntTaskChoice
39+
from ax.modelbridge.registry import MBM_X_trans
40+
41+
MT_MTGP_trans = MBM_X_trans + [
42+
Derelativize,
43+
ConvertMetricNames,
44+
TrialAsTask,
45+
StratifiedStandardizeY,
46+
TaskChoiceToIntTaskChoice,
47+
]
48+
49+
except ImportError:
50+
# For Ax < 0.5.0
51+
from ax.modelbridge.registry import MT_MTGP_trans
52+
53+
from ax.core.experiment import Experiment
54+
from ax.core.data import Data
55+
from ax.modelbridge.transforms.convert_metric_names import (
56+
tconfig_from_mt_experiment,
57+
)
2758

2859
from optimas.generators.ax.base import AxGenerator
2960
from optimas.core import (
@@ -37,13 +68,80 @@
3768
)
3869
from .ax_metric import AxMetric
3970

40-
4171
# Define generator states.
4272
NOT_STARTED = "not_started"
4373
LOFI_RETURNED = "lofi_returned"
4474
HIFI_RETURNED = "hifi_returned"
4575

4676

77+
# get_MTGP is not part of the Ax codebase, as of Ax 0.4.1, due to this PR:
78+
# https://github.com/facebook/Ax/pull/2508
79+
# Here we use `get_MTGP` https://ax.dev/docs/tutorials/multi_task/
80+
def get_MTGP(
81+
experiment: Experiment,
82+
data: Data,
83+
search_space: Optional[SearchSpace] = None,
84+
trial_index: Optional[int] = None,
85+
device: torch.device = torch.device("cpu"),
86+
dtype: torch.dtype = torch.double,
87+
) -> TorchModelBridge:
88+
"""Instantiate a Multi-task Gaussian Process (MTGP) model.
89+
90+
Points are generated with EI (Expected Improvement).
91+
If the input experiment is a MultiTypeExperiment then a
92+
Multi-type Multi-task GP model will be instantiated.
93+
Otherwise, the model will be a Single-type Multi-task GP.
94+
"""
95+
if isinstance(experiment, MultiTypeExperiment):
96+
trial_index_to_type = {
97+
t.index: t.trial_type for t in experiment.trials.values()
98+
}
99+
transforms = MT_MTGP_trans
100+
transform_configs = {
101+
"TrialAsTask": {
102+
"trial_level_map": {"trial_type": trial_index_to_type}
103+
},
104+
"ConvertMetricNames": tconfig_from_mt_experiment(experiment),
105+
}
106+
else:
107+
# Set transforms for a Single-type MTGP model.
108+
transforms = ST_MTGP_trans
109+
transform_configs = None
110+
111+
# Choose the status quo features for the experiment from the selected
112+
# trial. If trial_index is None, we will look for a status quo from the
113+
# last experiment trial to use as a status quo for the experiment.
114+
if trial_index is None:
115+
trial_index = len(experiment.trials) - 1
116+
elif trial_index >= len(experiment.trials):
117+
raise ValueError(
118+
"trial_index is bigger than the number of experiment trials"
119+
)
120+
121+
status_quo = experiment.trials[trial_index].status_quo
122+
if status_quo is None:
123+
status_quo_features = None
124+
else:
125+
status_quo_features = ObservationFeatures(
126+
parameters=status_quo.parameters,
127+
trial_index=trial_index, # pyre-ignore[6]
128+
)
129+
130+
return assert_is_instance(
131+
Models.ST_MTGP(
132+
experiment=experiment,
133+
search_space=search_space or experiment.search_space,
134+
data=data,
135+
transforms=transforms,
136+
transform_configs=transform_configs,
137+
torch_dtype=dtype,
138+
torch_device=device,
139+
status_quo_features=status_quo_features,
140+
),
141+
TorchModelBridge,
142+
)
143+
144+
47145
class AxMultitaskGenerator(AxGenerator):
48146
"""Multitask Bayesian optimization using the Ax developer API.
49147
@@ -307,7 +405,9 @@ def _create_experiment(self) -> MultiTypeExperiment:
307405
)
308406

309407
# Register metric in order to be able to save experiment to json file.
310-
_, encoder_registry, decoder_registry = register_metric(AxMetric)
408+
_, encoder_registry, decoder_registry = register_metrics(
409+
{AxMetric: None}
410+
)
311411
self._encoder_registry = encoder_registry
312412
self._decoder_registry = decoder_registry
313413

optimas/generators/ax/import_error_dummy_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ def __init__(self, *args, **kwargs) -> None:
1212
raise RuntimeError(
1313
"You need to install ax-platform, in order "
1414
"to use Ax-based generators in optimas.\n"
15-
"e.g. with `pip install ax-platform >= 0.4.0`"
15+
"e.g. with `pip install ax-platform > 0.5.0`"
1616
)

optimas/generators/ax/service/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def _tell(self, trials: List[Trial]) -> None:
205205
# i.e., min trials).
206206
if isinstance(tc, (MinTrials, MaxTrials)):
207207
tc.threshold -= 1
208-
generation_strategy._maybe_move_to_next_step()
208+
generation_strategy._maybe_transition_to_next_node()
209209
finally:
210210
if trial.ignored:
211211
continue

optimas/generators/ax/service/multi_fidelity.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,13 @@ def _create_generation_steps(
124124
GenerationStep(
125125
model=Models.BOTORCH_MODULAR,
126126
num_trials=-1,
127-
model_kwargs=bo_model_kwargs,
127+
model_kwargs={
128+
**bo_model_kwargs,
129+
"acquisition_options": {
130+
"X_pending": None,
131+
"constraints": None,
132+
},
133+
},
128134
model_gen_kwargs={
129135
"model_gen_options": {
130136
Keys.ACQF_KWARGS: {

optimas/generators/ax/service/single_fidelity.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -134,23 +134,13 @@ def _create_generation_steps(
134134
) -> List[GenerationStep]:
135135
"""Create generation steps for single-fidelity optimization."""
136136
# Select BO model.
137+
# Ax 0.5.0 detects if multi-objective.
137138
if self._fully_bayesian:
138-
if len(self.objectives) > 1:
139-
# Use a SAAS model with qNEHVI acquisition function.
140-
MODEL_CLASS = Models.FULLYBAYESIANMOO
141-
else:
142-
# Use a SAAS model with qNEI acquisition function.
143-
MODEL_CLASS = Models.FULLYBAYESIAN
144-
# Disable additional logs from fully Bayesian model.
145-
bo_model_kwargs["disable_progbar"] = True
146-
bo_model_kwargs["verbose"] = False
139+
# Use a SAAS model with qNEHVI/qNEI acquisition function
140+
MODEL_CLASS = Models.SAASBO
147141
else:
148-
if len(self.objectives) > 1:
149-
# Use a model with qNEHVI acquisition function.
150-
MODEL_CLASS = Models.MOO
151-
else:
152-
# Use a model with qNEI acquisition function.
153-
MODEL_CLASS = Models.GPEI
142+
# Use a model with qNEHVI/qNEI acquisition function
143+
MODEL_CLASS = Models.BOTORCH_MODULAR
154144

155145
# Make generation strategy.
156146
steps = []

optimas/utils/ax/ax_model_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _build_ax_client_from_dataframe(
130130
# allow calling `model.predict`. Using MOO for multiobjective is
131131
# needed because otherwise calls to `get_pareto_optimal_parameters`
132132
# would fail.
133-
model = Models.GPEI if len(objectives) == 1 else Models.MOO
133+
model = Models.BOTORCH_MODULAR
134134
gs = GenerationStrategy([GenerationStep(model=model, num_trials=-1)])
135135
ax_client = AxClient(generation_strategy=gs, verbose_logging=False)
136136
ax_client.create_experiment(

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ test = [
3434
'flake8',
3535
'pytest',
3636
'pytest-mpi',
37-
'ax-platform == 0.4.0',
37+
'ax-platform >=0.5.0',
3838
'matplotlib',
3939
]
4040
all = [
41-
'ax-platform == 0.4.0',
41+
'ax-platform >=0.5.0',
4242
'matplotlib'
4343
]
4444

0 commit comments

Comments
 (0)