Skip to content

Commit 8121938

Browse files
saitcakmak authored and facebook-github-bot committed
Re-index TorchOptConfig.objective_thresholds from (n_outcomes,) to (n_objectives,)
Summary: Re-index `objective_thresholds` in `TorchOptConfig` from `(n_outcomes,)` with NaN for non-objective outcomes to `(n_objectives,)` or `None`. This is a prerequisite for supporting `ScalarizedObjective` as a sub-objective of `MultiObjective`.

Key changes:
- `extract_objective_thresholds` returns an `(n_objectives,)` maximization-aligned array (sign-flipped for minimize objectives), or `None` for single-objective.
- `_untransform_objective_thresholds` indexes by objective index and un-flips the sign when converting back to raw `ObjectiveThreshold.bound`.
- Replaced `get_weighted_mc_objective_and_objective_thresholds` with `get_weighted_mc_objective`, which only returns the objective (thresholds no longer need transformation, so callers use them directly).
- `infer_objective_thresholds` returns `(n_objectives,)` maximization-aligned.
- Removed `objective_thresholds` from `SubsetModelData` and `subset_model` (thresholds are per-objective, not per-outcome, so subsetting doesn't apply).
- Simplified `_objective_threshold_to_outcome_constraints` and pruning logic.
- Merged `_full_objective_thresholds` and `_objective_thresholds` into a single `_objective_thresholds` in the `Acquisition` class.
- Pass thresholds directly as `ref_point` to BoTorch input constructors (using the new `ref_point` parameter from D96473523), avoiding the need to convert back to outcome space.

Differential Revision: D96391935
1 parent 1e48d0f commit 8121938

File tree

15 files changed

+203
-285
lines changed

15 files changed

+203
-285
lines changed

ax/adapter/adapter_utils.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from ax.core.types import TBounds, TCandidateMetadata, TNumeric
4242
from ax.exceptions.core import DataRequiredError, UserInputError
4343
from ax.generators.torch.botorch_moo_utils import (
44-
get_weighted_mc_objective_and_objective_thresholds,
44+
get_weighted_mc_objective,
4545
pareto_frontier_evaluator,
4646
)
4747
from ax.utils.common.logger import get_logger
@@ -205,23 +205,26 @@ def extract_objective_thresholds(
205205
objective: Objective,
206206
outcomes: list[str],
207207
) -> npt.NDArray | None:
208-
"""Extracts objective thresholds' values, in the order of `outcomes`.
208+
"""Extracts objective thresholds' values, in the order of objectives.
209209
210-
Will return None if no objective thresholds, otherwise the extracted tensor
211-
will be the same length as `outcomes`.
210+
Will return None if no objective thresholds or if the objective is single-
211+
objective. Otherwise the extracted array will have length ``n_objectives``
212+
(matching the rows of the objective weight matrix).
212213
213-
Outcomes that are not part of an objective and the objectives that do no have
214-
a corresponding objective threshold will be given a threshold of NaN. We will
215-
later infer appropriate threshold values for the objectives that are given a
216-
threshold of NaN.
214+
Objectives that do not have a corresponding objective threshold will be
215+
given a threshold of NaN. We will later infer appropriate threshold values
216+
for those objectives.
217+
218+
The returned thresholds are maximization-aligned: for minimize objectives,
219+
the threshold is negated.
217220
218221
Args:
219222
objective_thresholds: Objective thresholds to extract values from.
220223
objective: The corresponding Objective, for validation purposes.
221224
outcomes: n-length list of names of metrics.
222225
223226
Returns:
224-
(n,) array of thresholds
227+
``(n_objectives,)`` array of maximization-aligned thresholds, or None.
225228
"""
226229
if len(objective_thresholds) == 0:
227230
return None
@@ -242,11 +245,19 @@ def extract_objective_thresholds(
242245
f"Got {objective_thresholds=} and {objective=}."
243246
)
244247

245-
# Initialize these to be NaN to make sure that objective thresholds for
246-
# non-objective metrics are never used.
247-
obj_t = np.full(len(outcomes), float("nan"))
248-
for metric, threshold in objective_threshold_dict.items():
249-
obj_t[outcomes.index(metric)] = threshold
248+
if not isinstance(objective, MultiObjective):
249+
# Single objective — thresholds not applicable.
250+
return None
251+
252+
n_objectives = len(objective.objectives)
253+
obj_t = np.full(n_objectives, float("nan"))
254+
for i, sub_obj in enumerate(objective.objectives):
255+
if isinstance(sub_obj, ScalarizedObjective):
256+
continue # NaN — will be inferred later.
257+
metric_name = sub_obj.metric.signature
258+
if metric_name in objective_threshold_dict:
259+
sign = -1.0 if sub_obj.minimize else 1.0
260+
obj_t[i] = sign * objective_threshold_dict[metric_name]
250261
return obj_t
251262

252263

@@ -743,10 +754,8 @@ def pareto_frontier(
743754
if obj_t is None:
744755
return frontier_observations
745756

746-
# Apply appropriate weights and thresholds
747-
obj, obj_t = get_weighted_mc_objective_and_objective_thresholds(
748-
objective_weights=obj_w, objective_thresholds=obj_t
749-
)
757+
# Apply appropriate weights
758+
obj = get_weighted_mc_objective(objective_weights=obj_w)
750759
f_t = obj(f)
751760

752761
# Compute individual hypervolumes by taking the difference between the observation
@@ -911,10 +920,9 @@ def hypervolume(
911920
dtype=torch.bool,
912921
device=f.device,
913922
)
914-
# Apply appropriate weights and thresholds
915-
obj, obj_t = get_weighted_mc_objective_and_objective_thresholds(
916-
objective_weights=obj_w, objective_thresholds=none_throws(obj_t)
917-
)
923+
# Apply appropriate weights
924+
obj = get_weighted_mc_objective(objective_weights=obj_w)
925+
obj_t = none_throws(obj_t)
918926
f_t = obj(f)
919927
obj_mask = (obj_w != 0).any(dim=0).nonzero().view(-1)
920928
selected_metrics_mask = selected_metrics_mask[obj_mask]

ax/adapter/tests/test_torch_moo_adapter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,10 @@ def helper_test_pareto_frontier(
191191
)
192192
)
193193
self.assertTrue(obj_t is not None)
194+
# Thresholds are now (n_objectives,) and maximization-aligned.
195+
# LEQ thresholds with bound=5.0 become -5.0 after sign flip.
194196
self.assertTrue(
195-
torch.equal(
196-
none_throws(obj_t)[:2], torch.full((2,), 5.0, dtype=torch.double)
197-
)
197+
torch.equal(none_throws(obj_t), torch.full((2,), -5.0, dtype=torch.double))
198198
)
199199
observed_frontier2 = pareto_frontier(
200200
adapter=adapter,

ax/adapter/tests/test_utils.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -122,32 +122,33 @@ def test_extract_objective_thresholds(self) -> None:
122122
for i, name in enumerate(outcomes[:3])
123123
]
124124

125-
# None of no thresholds
125+
# None if no thresholds
126126
self.assertIsNone(
127127
extract_objective_thresholds(
128128
objective_thresholds=[], objective=objective, outcomes=outcomes
129129
)
130130
)
131131

132-
# Working case
132+
# Working case: 3 objectives (all maximize), shape is (3,)
133133
obj_t = extract_objective_thresholds(
134134
objective_thresholds=objective_thresholds,
135135
objective=objective,
136136
outcomes=outcomes,
137137
)
138-
expected_obj_t_not_nan = np.array([2.0, 3.0, 4.0])
139-
self.assertTrue(np.array_equal(obj_t[:3], expected_obj_t_not_nan[:3]))
140-
self.assertTrue(np.isnan(obj_t[-1]))
141-
self.assertEqual(obj_t.shape[0], 4)
138+
# All maximize, so thresholds are unchanged (sign = +1).
139+
expected_obj_t = np.array([2.0, 3.0, 4.0])
140+
self.assertTrue(np.array_equal(obj_t, expected_obj_t))
141+
self.assertEqual(obj_t.shape[0], 3)
142142

143143
# Returns NaN for objectives without a threshold.
144144
obj_t = extract_objective_thresholds(
145145
objective_thresholds=objective_thresholds[:2],
146146
objective=objective,
147147
outcomes=outcomes,
148148
)
149-
self.assertTrue(np.array_equal(obj_t[:2], expected_obj_t_not_nan[:2]))
150-
self.assertTrue(np.isnan(obj_t[-2:]).all())
149+
self.assertTrue(np.array_equal(obj_t[:2], expected_obj_t[:2]))
150+
self.assertTrue(np.isnan(obj_t[2]))
151+
self.assertEqual(obj_t.shape[0], 3)
151152

152153
# Fails if a threshold does not have a corresponding metric.
153154
objective2 = Objective(Metric("m1"), minimize=False)
@@ -158,15 +159,46 @@ def test_extract_objective_thresholds(self) -> None:
158159
outcomes=outcomes,
159160
)
160161

161-
# Works with a single objective, single threshold
162+
# Single objective returns None.
163+
self.assertIsNone(
164+
extract_objective_thresholds(
165+
objective_thresholds=objective_thresholds[:1],
166+
objective=objective2,
167+
outcomes=outcomes,
168+
)
169+
)
170+
171+
# Maximize-alignment: minimize objectives get negated thresholds.
172+
objective_with_min = MultiObjective(
173+
objectives=[
174+
Objective(metric=Metric("m1"), minimize=False),
175+
Objective(metric=Metric("m2"), minimize=True),
176+
]
177+
)
178+
obj_thresholds_for_min = [
179+
ObjectiveThreshold(
180+
metric=Metric("m1"),
181+
op=ComparisonOp.LEQ,
182+
bound=2.0,
183+
relative=False,
184+
),
185+
ObjectiveThreshold(
186+
metric=Metric("m2"),
187+
op=ComparisonOp.LEQ,
188+
bound=3.0,
189+
relative=False,
190+
),
191+
]
162192
obj_t = extract_objective_thresholds(
163-
objective_thresholds=objective_thresholds[:1],
164-
objective=objective2,
193+
objective_thresholds=obj_thresholds_for_min,
194+
objective=objective_with_min,
165195
outcomes=outcomes,
166196
)
197+
# m1 maximize: sign=+1, threshold=2.0 → 2.0
198+
# m2 minimize: sign=-1, threshold=3.0 → -3.0
199+
self.assertEqual(obj_t.shape[0], 2)
167200
self.assertEqual(obj_t[0], 2.0)
168-
self.assertTrue(np.all(np.isnan(obj_t[1:])))
169-
self.assertEqual(obj_t.shape[0], 4)
201+
self.assertEqual(obj_t[1], -3.0)
170202

171203
# Fails if relative
172204
objective_thresholds[2] = ObjectiveThreshold(

ax/adapter/torch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,12 +1111,14 @@ def _untransform_objective_thresholds(
11111111
"""
11121112
obj_indices, obj_weights = extract_objectives(objective_weights)
11131113
thresholds = []
1114-
for idx, w in zip(obj_indices, obj_weights):
1114+
for i, (idx, w) in enumerate(zip(obj_indices, obj_weights)):
11151115
sign = torch.sign(w)
1116+
# Thresholds are maximization-aligned; undo sign flip to get raw bound.
1117+
raw_bound = float(sign * objective_thresholds[i].item())
11161118
thresholds.append(
11171119
ObjectiveThreshold(
11181120
metric=opt_config_metrics[self.outcomes[idx]],
1119-
bound=float(objective_thresholds[idx].item()),
1121+
bound=raw_bound,
11201122
relative=False,
11211123
op=ComparisonOp.LEQ if sign < 0 else ComparisonOp.GEQ,
11221124
)

ax/generators/tests/test_botorch_moo_utils.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111
from unittest import mock
1212
from warnings import catch_warnings, simplefilter
1313

14-
import numpy as np
1514
import torch
1615
from ax.core.search_space import SearchSpaceDigest
1716
from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
1817
from ax.generators.torch.botorch_moo_utils import (
19-
get_weighted_mc_objective_and_objective_thresholds,
18+
get_weighted_mc_objective,
2019
infer_objective_thresholds,
2120
pareto_frontier_evaluator,
2221
)
@@ -213,19 +212,13 @@ def test_pareto_frontier_evaluator_with_nan(self) -> None:
213212

214213

215214
class BotorchMOOUtilsTest(TestCase):
216-
def test_get_weighted_mc_objective_and_objective_thresholds(self) -> None:
215+
def test_get_weighted_mc_objective(self) -> None:
217216
objective_weights = torch.tensor([[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]])
218-
objective_thresholds = torch.arange(4, dtype=torch.float)
219-
(
220-
weighted_obj,
221-
new_obj_thresholds,
222-
) = get_weighted_mc_objective_and_objective_thresholds(
217+
weighted_obj = get_weighted_mc_objective(
223218
objective_weights=objective_weights,
224-
objective_thresholds=objective_thresholds,
225219
)
226220
self.assertTrue(torch.equal(weighted_obj.weights, torch.tensor([1.0, 1.0])))
227221
self.assertEqual(weighted_obj.outcomes.tolist(), [1, 3])
228-
self.assertTrue(torch.equal(new_obj_thresholds, objective_thresholds[[1, 3]]))
229222

230223
# test infer objective thresholds alone
231224
@mock.patch( # pyre-ignore
@@ -255,6 +248,10 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
255248
objective_weights = torch.tensor(
256249
[[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0]], **tkwargs
257250
)
251+
# Expected: infer_reference_point returns (n_objectives,) in
252+
# maximization-aligned space. With pareto_Y=[[-9, -3]] and
253+
# scale=0.1, the result is [-9.9, -3.3].
254+
expected_thresholds = torch.tensor([-9.9, -3.3], **tkwargs)
258255
with ExitStack() as es:
259256
_mock_infer_reference_point = es.enter_context(
260257
mock.patch(
@@ -282,10 +279,9 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
282279
torch.tensor([[-9.0, -3.0]], **tkwargs),
283280
)
284281
)
285-
self.assertTrue(
286-
torch.equal(obj_thresholds[:2], torch.tensor([9.9, 3.3], **tkwargs))
287-
)
288-
self.assertTrue(np.isnan(obj_thresholds[2].item()))
282+
# Result is (n_objectives,) maximization-aligned.
283+
self.assertEqual(obj_thresholds.shape[0], 2)
284+
self.assertTrue(torch.equal(obj_thresholds, expected_thresholds))
289285

290286
# test subset_model without subset_idcs
291287
with mock.patch.object(model, "posterior", return_value=posterior):
@@ -295,10 +291,8 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
295291
outcome_constraints=outcome_constraints,
296292
X_observed=Xs[0],
297293
)
298-
self.assertTrue(
299-
torch.equal(obj_thresholds[:2], torch.tensor([9.9, 3.3], **tkwargs))
300-
)
301-
self.assertTrue(np.isnan(obj_thresholds[2].item()))
294+
self.assertEqual(obj_thresholds.shape[0], 2)
295+
self.assertTrue(torch.equal(obj_thresholds, expected_thresholds))
302296
# test passing subset_idcs
303297
subset_idcs = torch.tensor(
304298
[0, 1], dtype=torch.long, device=tkwargs["device"]
@@ -312,10 +306,8 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
312306
X_observed=Xs[0],
313307
subset_idcs=subset_idcs,
314308
)
315-
self.assertTrue(
316-
torch.equal(obj_thresholds[:2], torch.tensor([9.9, 3.3], **tkwargs))
317-
)
318-
self.assertTrue(np.isnan(obj_thresholds[2].item()))
309+
self.assertEqual(obj_thresholds.shape[0], 2)
310+
self.assertTrue(torch.equal(obj_thresholds, expected_thresholds))
319311
# test without subsetting (e.g. if there are
320312
# 3 metrics for 2 objectives + 1 outcome constraint)
321313
outcome_constraints = (
@@ -350,10 +342,8 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
350342
X_observed=Xs[0],
351343
outcome_constraints=outcome_constraints,
352344
)
353-
self.assertTrue(
354-
torch.equal(obj_thresholds[:2], torch.tensor([9.9, 3.3], **tkwargs))
355-
)
356-
self.assertTrue(np.isnan(obj_thresholds[2].item()))
345+
self.assertEqual(obj_thresholds.shape[0], 2)
346+
self.assertTrue(torch.equal(obj_thresholds, expected_thresholds))
357347

358348
def test_infer_objective_thresholds_cuda(self) -> None:
359349
if torch.cuda.is_available():

0 commit comments

Comments
 (0)