Skip to content

Commit 6d8c782

Browse files
saitcakmak authored and facebook-github-bot committed
Re-index TorchOptConfig.objective_thresholds from (n_outcomes,) to (n_objectives,)
Summary: Re-index `objective_thresholds` in `TorchOptConfig` from `(n_outcomes,)` with NaN for non-objective outcomes to `(n_objectives,)` or `None`. This is a prerequisite for supporting `ScalarizedObjective` as a sub-objective of `MultiObjective`. Key changes: - `extract_objective_thresholds` returns `(n_objectives,)` maximization-aligned array (sign-flipped for minimize objectives), or `None` for single-objective. Rewritten to use expression-based API (`parse_objective_expression` / `extract_metric_weights_from_objective_expr`) instead of deprecated `MultiObjective.objectives` / `ScalarizedObjective` class checks. - `_untransform_objective_thresholds` indexes by objective index and un-flips the sign when converting back to raw `ObjectiveThreshold.bound`. - Replaced `get_weighted_mc_objective_and_objective_thresholds` with `get_weighted_mc_objective`, which only returns the objective (thresholds no longer need transformation so callers use them directly). - `infer_objective_thresholds` returns `(n_objectives,)` maximization-aligned. - Removed `objective_thresholds` from `SubsetModelData` and `subset_model` (thresholds are per-objective, not per-outcome, so subsetting doesn't apply). - Simplified `_objective_threshold_to_outcome_constraints` and pruning logic. - Merged `_full_objective_thresholds` and `_objective_thresholds` into single `_objective_thresholds` in the `Acquisition` class. - Pass thresholds directly as `ref_point` to BoTorch input constructors (using the new `ref_point` parameter from D96473523), avoiding the need to convert back to outcome space. - Fixed Pyre type errors in `test_acquisition.py` for `Optional[Tensor]` return from `Acquisition.objective_thresholds`. TRBO isolation (axoptics): - TRBO uses its own `objective_weights` and `max_reference_point` from constructor kwargs -- it does NOT use `TorchOptConfig` from `gen()`. 
Its `__init__` expects legacy format: `(n_outcomes,)` with raw bounds and NaN for non-objectives, then handles maximization alignment internally. - Replaced the `extract_objective_thresholds` call in axoptics `_mk_TRBO_generation_strategy` with inline legacy logic that produces the `(n_outcomes,)` raw-bounds format TRBO expects, keeping TRBO isolated from the new `(n_objectives,)` format. - Added comments to `trbo.py` documenting the legacy input contract and future refactoring TODO. Differential Revision: D96391935
1 parent bbea5f9 commit 6d8c782

File tree

15 files changed

+224
-290
lines changed

15 files changed

+224
-290
lines changed

ax/adapter/adapter_utils.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,16 @@
4141
from ax.core.types import TBounds, TCandidateMetadata, TNumeric
4242
from ax.exceptions.core import DataRequiredError, UserInputError
4343
from ax.generators.torch.botorch_moo_utils import (
44-
get_weighted_mc_objective_and_objective_thresholds,
44+
get_weighted_mc_objective,
4545
pareto_frontier_evaluator,
4646
)
4747
from ax.utils.common.constants import Keys
4848
from ax.utils.common.hash_utils import get_current_lilo_hash
4949
from ax.utils.common.logger import get_logger
50+
from ax.utils.common.sympy import (
51+
extract_metric_weights_from_objective_expr,
52+
parse_objective_expression,
53+
)
5054
from ax.utils.common.typeutils import (
5155
assert_is_instance_of_tuple,
5256
assert_is_instance_optional,
@@ -208,15 +212,18 @@ def extract_objective_thresholds(
208212
outcomes: list[str],
209213
metric_name_to_signature: Mapping[str, str],
210214
) -> npt.NDArray | None:
211-
"""Extracts objective thresholds' values, in the order of `outcomes`.
215+
"""Extracts objective thresholds' values, in the order of objectives.
216+
217+
Will return None if no objective thresholds or if the objective is single-
218+
objective. Otherwise the extracted array will have length ``n_objectives``
219+
(matching the rows of the objective weight matrix).
212220
213-
Will return None if no objective thresholds, otherwise the extracted tensor
214-
will be the same length as `outcomes`.
221+
Objectives that do not have a corresponding objective threshold will be
222+
given a threshold of NaN. We will later infer appropriate threshold values
223+
for those objectives.
215224
216-
Outcomes that are not part of an objective and the objectives that do no have
217-
a corresponding objective threshold will be given a threshold of NaN. We will
218-
later infer appropriate threshold values for the objectives that are given a
219-
threshold of NaN.
225+
The returned thresholds are maximization-aligned: for minimize objectives,
226+
the threshold is negated.
220227
221228
Args:
222229
objective_thresholds: Objective thresholds to extract values from.
@@ -225,7 +232,7 @@ def extract_objective_thresholds(
225232
metric_name_to_signature: Mapping from metric names to signatures.
226233
227234
Returns:
228-
(n,) array of thresholds
235+
``(n_objectives,)`` array of maximization-aligned thresholds, or None.
229236
"""
230237
if len(objective_thresholds) == 0:
231238
return None
@@ -250,11 +257,23 @@ def extract_objective_thresholds(
250257
f"Got {objective_thresholds=} and {objective=}."
251258
)
252259

253-
# Initialize these to be NaN to make sure that objective thresholds for
254-
# non-objective metrics are never used.
255-
obj_t = np.full(len(outcomes), float("nan"))
256-
for metric, threshold in objective_threshold_dict.items():
257-
obj_t[outcomes.index(metric)] = threshold
260+
if not objective.is_multi_objective:
261+
# Single objective — thresholds not applicable.
262+
return None
263+
264+
parsed = parse_objective_expression(objective.expression)
265+
sub_exprs = parsed if isinstance(parsed, tuple) else (parsed,)
266+
n_objectives = len(sub_exprs)
267+
obj_t = np.full(n_objectives, float("nan"))
268+
for i, sub_expr in enumerate(sub_exprs):
269+
sub_mw = extract_metric_weights_from_objective_expr(sub_expr)
270+
if len(sub_mw) > 1:
271+
continue # Scalarized sub-objective — NaN, will be inferred later.
272+
name, weight = sub_mw[0]
273+
sig = metric_name_to_signature[name]
274+
if sig in objective_threshold_dict:
275+
sign = 1.0 if weight > 0 else -1.0
276+
obj_t[i] = sign * objective_threshold_dict[sig]
258277
return obj_t
259278

260279

@@ -769,10 +788,8 @@ def pareto_frontier(
769788
if obj_t is None:
770789
return frontier_observations
771790

772-
# Apply appropriate weights and thresholds
773-
obj, obj_t = get_weighted_mc_objective_and_objective_thresholds(
774-
objective_weights=obj_w, objective_thresholds=obj_t
775-
)
791+
# Apply appropriate weights
792+
obj = get_weighted_mc_objective(objective_weights=obj_w)
776793
f_t = obj(f)
777794

778795
# Compute individual hypervolumes by taking the difference between the observation
@@ -937,10 +954,9 @@ def hypervolume(
937954
dtype=torch.bool,
938955
device=f.device,
939956
)
940-
# Apply appropriate weights and thresholds
941-
obj, obj_t = get_weighted_mc_objective_and_objective_thresholds(
942-
objective_weights=obj_w, objective_thresholds=none_throws(obj_t)
943-
)
957+
# Apply appropriate weights
958+
obj = get_weighted_mc_objective(objective_weights=obj_w)
959+
obj_t = none_throws(obj_t)
944960
f_t = obj(f)
945961
obj_mask = (obj_w != 0).any(dim=0).nonzero().view(-1)
946962
selected_metrics_mask = selected_metrics_mask[obj_mask]

ax/adapter/tests/test_torch_moo_adapter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,10 @@ def helper_test_pareto_frontier(
199199
)
200200
)
201201
self.assertTrue(obj_t is not None)
202+
# Thresholds are now (n_objectives,) and maximization-aligned.
203+
# LEQ thresholds with bound=5.0 become -5.0 after sign flip.
202204
self.assertTrue(
203-
torch.equal(
204-
none_throws(obj_t)[:2], torch.full((2,), 5.0, dtype=torch.double)
205-
)
205+
torch.equal(none_throws(obj_t), torch.full((2,), -5.0, dtype=torch.double))
206206
)
207207
observed_frontier2 = pareto_frontier(
208208
adapter=adapter,

ax/adapter/tests/test_utils.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def test_extract_objective_thresholds(self) -> None:
140140
for i, name in enumerate(outcomes[:3])
141141
]
142142

143-
# None of no thresholds
143+
# None if no thresholds
144144
self.assertIsNone(
145145
extract_objective_thresholds(
146146
objective_thresholds=[],
@@ -150,17 +150,17 @@ def test_extract_objective_thresholds(self) -> None:
150150
)
151151
)
152152

153-
# Working case
153+
# Working case: 3 objectives (all maximize), shape is (3,)
154154
obj_t = extract_objective_thresholds(
155155
objective_thresholds=objective_thresholds,
156156
objective=objective,
157157
outcomes=outcomes,
158158
metric_name_to_signature=metric_name_to_signature,
159159
)
160-
expected_obj_t_not_nan = np.array([2.0, 3.0, 4.0])
161-
self.assertTrue(np.array_equal(obj_t[:3], expected_obj_t_not_nan[:3]))
162-
self.assertTrue(np.isnan(obj_t[-1]))
163-
self.assertEqual(obj_t.shape[0], 4)
160+
# All maximize, so thresholds are unchanged (sign = +1).
161+
expected_obj_t = np.array([2.0, 3.0, 4.0])
162+
self.assertTrue(np.array_equal(obj_t, expected_obj_t))
163+
self.assertEqual(obj_t.shape[0], 3)
164164

165165
# Returns NaN for objectives without a threshold.
166166
obj_t = extract_objective_thresholds(
@@ -169,8 +169,9 @@ def test_extract_objective_thresholds(self) -> None:
169169
outcomes=outcomes,
170170
metric_name_to_signature=metric_name_to_signature,
171171
)
172-
self.assertTrue(np.array_equal(obj_t[:2], expected_obj_t_not_nan[:2]))
173-
self.assertTrue(np.isnan(obj_t[-2:]).all())
172+
self.assertTrue(np.array_equal(obj_t[:2], expected_obj_t[:2]))
173+
self.assertTrue(np.isnan(obj_t[2]))
174+
self.assertEqual(obj_t.shape[0], 3)
174175

175176
# Fails if a threshold does not have a corresponding metric.
176177
objective2 = Objective(expression="m1")
@@ -182,16 +183,48 @@ def test_extract_objective_thresholds(self) -> None:
182183
metric_name_to_signature=metric_name_to_signature,
183184
)
184185

185-
# Works with a single objective, single threshold
186+
# Single objective returns None.
187+
self.assertIsNone(
188+
extract_objective_thresholds(
189+
objective_thresholds=objective_thresholds[:1],
190+
objective=objective2,
191+
outcomes=outcomes,
192+
metric_name_to_signature=metric_name_to_signature,
193+
)
194+
)
195+
196+
# Maximize-alignment: minimize objectives get negated thresholds.
197+
objective_with_min = MultiObjective(
198+
objectives=[
199+
Objective(metric=Metric("m1"), minimize=False),
200+
Objective(metric=Metric("m2"), minimize=True),
201+
]
202+
)
203+
obj_thresholds_for_min = [
204+
ObjectiveThreshold(
205+
metric=Metric("m1"),
206+
op=ComparisonOp.LEQ,
207+
bound=2.0,
208+
relative=False,
209+
),
210+
ObjectiveThreshold(
211+
metric=Metric("m2"),
212+
op=ComparisonOp.LEQ,
213+
bound=3.0,
214+
relative=False,
215+
),
216+
]
186217
obj_t = extract_objective_thresholds(
187-
objective_thresholds=objective_thresholds[:1],
188-
objective=objective2,
218+
objective_thresholds=obj_thresholds_for_min,
219+
objective=objective_with_min,
189220
outcomes=outcomes,
190221
metric_name_to_signature=metric_name_to_signature,
191222
)
223+
# m1 maximize: sign=+1, threshold=2.0 → 2.0
224+
# m2 minimize: sign=-1, threshold=3.0 → -3.0
225+
self.assertEqual(obj_t.shape[0], 2)
192226
self.assertEqual(obj_t[0], 2.0)
193-
self.assertTrue(np.all(np.isnan(obj_t[1:])))
194-
self.assertEqual(obj_t.shape[0], 4)
227+
self.assertEqual(obj_t[1], -3.0)
195228

196229
# Fails if relative
197230
objective_thresholds[2] = ObjectiveThreshold(

ax/adapter/torch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,12 +1143,14 @@ def _untransform_objective_thresholds(
11431143
"""
11441144
obj_indices, obj_weights = extract_objectives(objective_weights)
11451145
thresholds = []
1146-
for idx, w in zip(obj_indices, obj_weights):
1146+
for i, (idx, w) in enumerate(zip(obj_indices, obj_weights)):
11471147
sign = torch.sign(w)
1148+
# Thresholds are maximization-aligned; undo sign flip to get raw bound.
1149+
raw_bound = float(sign * objective_thresholds[i].item())
11481150
thresholds.append(
11491151
ObjectiveThreshold(
11501152
metric=opt_config_metrics[self.outcomes[idx]],
1151-
bound=float(objective_thresholds[idx].item()),
1153+
bound=raw_bound,
11521154
relative=False,
11531155
op=ComparisonOp.LEQ if sign < 0 else ComparisonOp.GEQ,
11541156
)

ax/generators/tests/test_botorch_moo_utils.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111
from unittest import mock
1212
from warnings import catch_warnings, simplefilter
1313

14-
import numpy as np
1514
import torch
1615
from ax.core.search_space import SearchSpaceDigest
1716
from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
1817
from ax.generators.torch.botorch_moo_utils import (
19-
get_weighted_mc_objective_and_objective_thresholds,
18+
get_weighted_mc_objective,
2019
infer_objective_thresholds,
2120
pareto_frontier_evaluator,
2221
)
@@ -68,7 +67,8 @@ def setUp(self) -> None:
6867
]
6968
)
7069
self.Yvar = torch.zeros(5, 3)
71-
self.objective_thresholds = torch.tensor([0.5, 1.5, float("nan")])
70+
# Thresholds are (n_objectives,) in maximization-aligned space.
71+
self.objective_thresholds = torch.tensor([0.5, 1.5])
7272
self.objective_weights = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
7373

7474
def test_pareto_frontier_raise_error_when_missing_data(self) -> None:
@@ -110,11 +110,12 @@ def test_pareto_frontier_evaluator_raw(self) -> None:
110110
self.assertAllClose(expected_cov, cov)
111111
self.assertTrue(torch.equal(torch.arange(2, 5), indx))
112112

113-
# Change objective_weights so goal is to minimize b
113+
# Change objective_weights so goal is to minimize b.
114+
# Thresholds in maximization-aligned space: [0.5, -1.5].
114115
Y, cov, indx = pareto_frontier_evaluator(
115116
model=model,
116117
objective_weights=torch.tensor([[1.0, 0.0, 0.0], [0.0, -1.0, 0.0]]),
117-
objective_thresholds=self.objective_thresholds,
118+
objective_thresholds=torch.tensor([0.5, -1.5]),
118119
Y=self.Y,
119120
Yvar=Yvar,
120121
)
@@ -213,19 +214,13 @@ def test_pareto_frontier_evaluator_with_nan(self) -> None:
213214

214215

215216
class BotorchMOOUtilsTest(TestCase):
216-
def test_get_weighted_mc_objective_and_objective_thresholds(self) -> None:
217+
def test_get_weighted_mc_objective(self) -> None:
217218
objective_weights = torch.tensor([[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]])
218-
objective_thresholds = torch.arange(4, dtype=torch.float)
219-
(
220-
weighted_obj,
221-
new_obj_thresholds,
222-
) = get_weighted_mc_objective_and_objective_thresholds(
219+
weighted_obj = get_weighted_mc_objective(
223220
objective_weights=objective_weights,
224-
objective_thresholds=objective_thresholds,
225221
)
226222
self.assertTrue(torch.equal(weighted_obj.weights, torch.tensor([1.0, 1.0])))
227223
self.assertEqual(weighted_obj.outcomes.tolist(), [1, 3])
228-
self.assertTrue(torch.equal(new_obj_thresholds, objective_thresholds[[1, 3]]))
229224

230225
# test infer objective thresholds alone
231226
@mock.patch( # pyre-ignore
@@ -255,6 +250,10 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
255250
objective_weights = torch.tensor(
256251
[[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0]], **tkwargs
257252
)
253+
# Expected: infer_reference_point returns (n_objectives,) in
254+
# maximization-aligned space. With pareto_Y=[[-9, -3]] and
255+
# scale=0.1, the result is [-9.9, -3.3].
256+
expected_thresholds = torch.tensor([-9.9, -3.3], **tkwargs)
258257
with ExitStack() as es:
259258
_mock_infer_reference_point = es.enter_context(
260259
mock.patch(
@@ -282,10 +281,9 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
282281
torch.tensor([[-9.0, -3.0]], **tkwargs),
283282
)
284283
)
285-
self.assertTrue(
286-
torch.equal(obj_thresholds[:2], torch.tensor([9.9, 3.3], **tkwargs))
287-
)
288-
self.assertTrue(np.isnan(obj_thresholds[2].item()))
284+
# Result is (n_objectives,) maximization-aligned.
285+
self.assertEqual(obj_thresholds.shape[0], 2)
286+
self.assertTrue(torch.equal(obj_thresholds, expected_thresholds))
289287

290288
# test subset_model without subset_idcs
291289
with mock.patch.object(model, "posterior", return_value=posterior):
@@ -295,10 +293,8 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
295293
outcome_constraints=outcome_constraints,
296294
X_observed=Xs[0],
297295
)
298-
self.assertTrue(
299-
torch.equal(obj_thresholds[:2], torch.tensor([9.9, 3.3], **tkwargs))
300-
)
301-
self.assertTrue(np.isnan(obj_thresholds[2].item()))
296+
self.assertEqual(obj_thresholds.shape[0], 2)
297+
self.assertTrue(torch.equal(obj_thresholds, expected_thresholds))
302298
# test passing subset_idcs
303299
subset_idcs = torch.tensor(
304300
[0, 1], dtype=torch.long, device=tkwargs["device"]
@@ -312,10 +308,8 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
312308
X_observed=Xs[0],
313309
subset_idcs=subset_idcs,
314310
)
315-
self.assertTrue(
316-
torch.equal(obj_thresholds[:2], torch.tensor([9.9, 3.3], **tkwargs))
317-
)
318-
self.assertTrue(np.isnan(obj_thresholds[2].item()))
311+
self.assertEqual(obj_thresholds.shape[0], 2)
312+
self.assertTrue(torch.equal(obj_thresholds, expected_thresholds))
319313
# test without subsetting (e.g. if there are
320314
# 3 metrics for 2 objectives + 1 outcome constraint)
321315
outcome_constraints = (
@@ -350,10 +344,8 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
350344
X_observed=Xs[0],
351345
outcome_constraints=outcome_constraints,
352346
)
353-
self.assertTrue(
354-
torch.equal(obj_thresholds[:2], torch.tensor([9.9, 3.3], **tkwargs))
355-
)
356-
self.assertTrue(np.isnan(obj_thresholds[2].item()))
347+
self.assertEqual(obj_thresholds.shape[0], 2)
348+
self.assertTrue(torch.equal(obj_thresholds, expected_thresholds))
357349

358350
def test_infer_objective_thresholds_cuda(self) -> None:
359351
if torch.cuda.is_available():

0 commit comments

Comments (0)