Skip to content

Commit 6f92f2f

Browse files
esantorellafacebook-github-bot
authored andcommitted
BoxDecomposition cleanup (#1490)
Summary: Pull Request resolved: #1490 - Change `compute_hypervolume` so that each BoxDecomposition subclass uses shared logic for the no-data case - [debatable] When `Y` is `None`, functions of Y like `box_decomp._neg_Y` are `None` rather than being unset attributes, so we do "if self._neg_Y is None" rather than catching an AttributeError. This makes catching type errors easier since otherwise Pyre is unhappy about references to the potentially-uninitialized attribute. - Took out unnecessary "register_buffer" calls (this happens automatically with `torch.nn.Module.setattr`) Reviewed By: SebastianAment Differential Revision: D41172490 fbshipit-source-id: 4f5899b0deaa8ea1d250a3d840dd1a80f930d500
1 parent 4a1f8ed commit 6f92f2f

File tree

4 files changed

+88
-103
lines changed

4 files changed

+88
-103
lines changed

botorch/utils/multi_objective/box_decompositions/box_decomposition.py

Lines changed: 65 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,18 @@ def __init__(
5151
Y: A `(batch_shape) x n x m`-dim tensor of outcomes.
5252
"""
5353
super().__init__()
54-
self.register_buffer("_neg_ref_point", -ref_point)
55-
self.register_buffer("sort", torch.tensor(sort, dtype=torch.bool))
54+
self._neg_ref_point = -ref_point
55+
self.sort = torch.tensor(sort, dtype=torch.bool)
5656
self.num_outcomes = ref_point.shape[-1]
57+
5758
if Y is not None:
58-
self._update_neg_Y(Y=Y)
59-
self.reset()
59+
self._neg_Y = -Y
60+
self._validate_inputs()
61+
self._neg_pareto_Y = self._compute_pareto_Y()
62+
self.partition_space()
63+
else:
64+
self._neg_Y = None
65+
self._neg_pareto_Y = None
6066

6167
@property
6268
def pareto_Y(self) -> Tensor:
@@ -65,10 +71,9 @@ def pareto_Y(self) -> Tensor:
6571
Returns:
6672
A `n_pareto x m`-dim tensor of outcomes.
6773
"""
68-
try:
74+
if self._neg_pareto_Y is not None:
6975
return -self._neg_pareto_Y
70-
except AttributeError:
71-
raise BotorchError("pareto_Y has not been initialized")
76+
raise BotorchError("pareto_Y has not been initialized")
7277

7378
@property
7479
def ref_point(self) -> Tensor:
@@ -86,41 +91,47 @@ def Y(self) -> Tensor:
8691
Returns:
8792
A `n x m`-dim tensor of outcomes.
8893
"""
89-
return -self._neg_Y
94+
if self._neg_Y is not None:
95+
return -self._neg_Y
96+
raise BotorchError("Y data has not been initialized")
97+
98+
def _compute_pareto_Y(self) -> Tensor:
99+
if self._neg_Y is None:
100+
raise BotorchError("Y data has not been initialized")
101+
# is_non_dominated assumes maximization
102+
if self._neg_Y.shape[-2] == 0:
103+
return self._neg_Y
104+
# assumes maximization
105+
pareto_Y = -_pad_batch_pareto_frontier(
106+
Y=self.Y,
107+
ref_point=_expand_ref_point(
108+
ref_point=self.ref_point, batch_shape=self.batch_shape
109+
),
110+
)
111+
if not self.sort:
112+
return pareto_Y
113+
# sort by first objective
114+
if len(self.batch_shape) > 0:
115+
pareto_Y = pareto_Y.gather(
116+
index=torch.argsort(pareto_Y[..., :1], dim=-2).expand(pareto_Y.shape),
117+
dim=-2,
118+
)
119+
else:
120+
pareto_Y = pareto_Y[torch.argsort(pareto_Y[:, 0])]
121+
return pareto_Y
90122

91123
def _reset_pareto_Y(self) -> bool:
92124
r"""Update the non-dominated front.
93125
94126
Returns:
95127
A boolean indicating whether the Pareto frontier has changed.
96128
"""
97-
# is_non_dominated assumes maximization
98-
if self._neg_Y.shape[-2] == 0:
99-
pareto_Y = self._neg_Y
100-
else:
101-
# assumes maximization
102-
pareto_Y = -_pad_batch_pareto_frontier(
103-
Y=self.Y,
104-
ref_point=_expand_ref_point(
105-
ref_point=self.ref_point, batch_shape=self.batch_shape
106-
),
107-
)
108-
if self.sort:
109-
# sort by first objective
110-
if len(self.batch_shape) > 0:
111-
pareto_Y = pareto_Y.gather(
112-
index=torch.argsort(pareto_Y[..., :1], dim=-2).expand(
113-
pareto_Y.shape
114-
),
115-
dim=-2,
116-
)
117-
else:
118-
pareto_Y = pareto_Y[torch.argsort(pareto_Y[:, 0])]
129+
pareto_Y = self._compute_pareto_Y()
119130

120-
if not hasattr(self, "_neg_pareto_Y") or not torch.equal(
131+
if (self._neg_pareto_Y is None) or not torch.equal(
121132
pareto_Y, self._neg_pareto_Y
122133
):
123-
self.register_buffer("_neg_pareto_Y", pareto_Y)
134+
self._neg_pareto_Y = pareto_Y
124135
return True
125136
return False
126137

@@ -139,13 +150,12 @@ def _partition_space_2d(self) -> None:
139150
raise NotImplementedError
140151

141152
@abstractmethod
142-
def _partition_space(self):
153+
def _partition_space(self) -> None:
143154
r"""Partition the non-dominated space into disjoint hypercells.
144155
145156
This method supports an arbitrary number of outcomes, but is
146157
less efficient than `partition_space_2d` for the 2-outcome case.
147158
"""
148-
pass # pragma: no cover
149159

150160
@abstractmethod
151161
def get_hypercell_bounds(self) -> Tensor:
@@ -155,7 +165,6 @@ def get_hypercell_bounds(self) -> Tensor:
155165
A `2 x num_cells x num_outcomes`-dim tensor containing the
156166
lower and upper vertices bounding each hypercell.
157167
"""
158-
pass # pragma: no cover
159168

160169
def _update_neg_Y(self, Y: Tensor) -> bool:
161170
r"""Update the set of outcomes.
@@ -164,12 +173,11 @@ def _update_neg_Y(self, Y: Tensor) -> bool:
164173
A boolean indicating if _neg_Y was initialized.
165174
"""
166175
# multiply by -1, since internally we minimize.
167-
try:
176+
if self._neg_Y is not None:
168177
self._neg_Y = torch.cat([self._neg_Y, -Y], dim=-2)
169178
return False
170-
except AttributeError:
171-
self.register_buffer("_neg_Y", -Y)
172-
return True
179+
self._neg_Y = -Y
180+
return True
173181

174182
def update(self, Y: Tensor) -> None:
175183
r"""Update non-dominated front and decomposition.
@@ -183,8 +191,7 @@ def update(self, Y: Tensor) -> None:
183191
self._update_neg_Y(Y=Y)
184192
self.reset()
185193

186-
def reset(self) -> None:
187-
r"""Reset non-dominated front and decomposition."""
194+
def _validate_inputs(self) -> None:
188195
self.batch_shape = self.Y.shape[:-2]
189196
self.num_outcomes = self.Y.shape[-1]
190197
if len(self.batch_shape) > 1:
@@ -198,20 +205,36 @@ def reset(self) -> None:
198205
f"{type(self).__name__} only supports a batched box "
199206
f"decompositions in the 2-objective setting."
200207
)
208+
209+
def reset(self) -> None:
210+
r"""Reset non-dominated front and decomposition."""
211+
self._validate_inputs()
201212
is_new_pareto = self._reset_pareto_Y()
202213
# Update decomposition if the Pareto front changed
203214
if is_new_pareto:
204215
self.partition_space()
205216

206217
@abstractmethod
218+
def _compute_hypervolume_if_y_has_data(self) -> Tensor:
219+
"""Compute hypervolume for the case that there is data in self._neg_pareto_Y."""
220+
207221
def compute_hypervolume(self) -> Tensor:
208222
r"""Compute hypervolume that is dominated by the Pareto Froniter.
209223
210224
Returns:
211225
A `(batch_shape)`-dim tensor containing the hypervolume dominated by
212226
each Pareto frontier.
213227
"""
214-
pass # pragma: no cover
228+
if self._neg_pareto_Y is None:
229+
return torch.tensor(0.0)
230+
231+
if self._neg_pareto_Y.shape[-2] == 0:
232+
return torch.zeros(
233+
self._neg_pareto_Y.shape[:-2],
234+
dtype=self._neg_pareto_Y.dtype,
235+
device=self._neg_pareto_Y.device,
236+
)
237+
return self._compute_hypervolume_if_y_has_data()
215238

216239

217240
class FastPartitioning(BoxDecomposition, ABC):

botorch/utils/multi_objective/box_decompositions/dominated.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from __future__ import annotations
1010

11-
import torch
1211
from botorch.utils.multi_objective.box_decompositions.box_decomposition import (
1312
FastPartitioning,
1413
)
@@ -39,7 +38,7 @@ def _partition_space_2d(self) -> None:
3938
pareto_Y_sorted=self.pareto_Y.flip(-2),
4039
ref_point=self.ref_point,
4140
)
42-
self.register_buffer("hypercell_bounds", cell_bounds)
41+
self.hypercell_bounds = cell_bounds
4342

4443
def _get_partitioning(self) -> None:
4544
r"""Get the bounds of each hypercell in the decomposition."""
@@ -49,22 +48,13 @@ def _get_partitioning(self) -> None:
4948
cell_bounds = -minimization_cell_bounds.flip(0)
5049
self.register_buffer("hypercell_bounds", cell_bounds)
5150

52-
def compute_hypervolume(self) -> Tensor:
51+
def _compute_hypervolume_if_y_has_data(self) -> Tensor:
5352
r"""Compute hypervolume that is dominated by the Pareto Frontier.
5453
5554
Returns:
5655
A `(batch_shape)`-dim tensor containing the hypervolume dominated by
5756
each Pareto frontier.
5857
"""
59-
if not hasattr(self, "_neg_pareto_Y"):
60-
return torch.tensor(0.0).to(self._neg_ref_point)
61-
62-
if self._neg_pareto_Y.shape[-2] == 0:
63-
return torch.zeros(
64-
self._neg_pareto_Y.shape[:-2],
65-
dtype=self._neg_pareto_Y.dtype,
66-
device=self._neg_pareto_Y.device,
67-
)
6858
return (
6959
(self.hypercell_bounds[1] - self.hypercell_bounds[0])
7060
.prod(dim=-1)
@@ -77,4 +67,4 @@ def _get_single_cell(self) -> None:
7767
cell_bounds = self.ref_point.expand(
7868
2, *self._neg_pareto_Y.shape[:-2], 1, self.num_outcomes
7969
).clone()
80-
self.register_buffer("hypercell_bounds", cell_bounds)
70+
self.hypercell_bounds = cell_bounds

botorch/utils/multi_objective/box_decompositions/non_dominated.py

Lines changed: 9 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,8 @@ def _partition_space(self) -> None:
9898
# hypercells contains the indices of the (augmented) Pareto front
9999
# that specify that bounds of the each hypercell.
100100
# It is a `2 x num_cells x m`-dim tensor
101-
self.register_buffer(
102-
"hypercells",
103-
torch.empty(
104-
2, 0, self.num_outcomes, dtype=torch.long, device=self._neg_Y.device
105-
),
101+
self.hypercells = torch.empty(
102+
2, 0, self.num_outcomes, dtype=torch.long, device=self._neg_Y.device
106103
)
107104
outcome_idxr = torch.arange(
108105
self.num_outcomes, dtype=torch.long, device=self._neg_Y.device
@@ -216,7 +213,7 @@ def _partition_space_2d(self) -> None:
216213
dim=-1,
217214
)
218215
# 2 x batch_shape x n_cells x 2
219-
self.register_buffer("hypercells", torch.stack([lower, upper], dim=0))
216+
self.hypercells = torch.stack([lower, upper], dim=0)
220217

221218
def _get_augmented_pareto_front_indices(self) -> Tensor:
222219
r"""Get indices of augmented Pareto front."""
@@ -337,25 +334,7 @@ def _get_hypercell_bounds(self, aug_pareto_Y: Tensor) -> Tensor:
337334
view_shape = (2, *self.batch_shape, num_cells, self.num_outcomes)
338335
return cell_bounds_values.view(view_shape)
339336

340-
def compute_hypervolume(self) -> Tensor:
341-
r"""Compute the hypervolume for the given reference point.
342-
343-
This method computes the hypervolume of the non-dominated space
344-
and computes the difference between the hypervolume between the
345-
ideal point and hypervolume of the non-dominated space.
346-
347-
Returns:
348-
`(batch_shape)`-dim tensor containing the dominated hypervolume.
349-
"""
350-
if not hasattr(self, "_neg_pareto_Y"):
351-
return torch.tensor(0.0).to(self._neg_ref_point)
352-
353-
if self._neg_pareto_Y.shape[-2] == 0:
354-
return torch.zeros(
355-
self._neg_pareto_Y.shape[:-2],
356-
dtype=self._neg_pareto_Y.dtype,
357-
device=self._neg_pareto_Y.device,
358-
)
337+
def _compute_hypervolume_if_y_has_data(self) -> Tensor:
359338
ref_point = _expand_ref_point(
360339
ref_point=self.ref_point, batch_shape=self.batch_shape
361340
)
@@ -413,7 +392,7 @@ def _get_single_cell(self) -> None:
413392
device=self._neg_pareto_Y.device,
414393
)
415394
cell_bounds[0] = self.ref_point
416-
self.register_buffer("hypercell_bounds", cell_bounds)
395+
self.hypercell_bounds = cell_bounds
417396

418397
def _get_partitioning(self) -> None:
419398
r"""Compute non-dominated partitioning.
@@ -432,7 +411,7 @@ def _get_partitioning(self) -> None:
432411
device=self._neg_ref_point.device,
433412
)
434413
# initialize local upper bounds for the second minimization problem
435-
self.register_buffer("_U2", new_ref_point)
414+
self._U2 = new_ref_point
436415
# initialize defining points for the second minimization problem
437416
# use ref point for maximization as the ideal point for minimization.
438417
self._Z2 = self.ref_point.expand(
@@ -450,7 +429,7 @@ def _get_partitioning(self) -> None:
450429
cell_bounds = get_partition_bounds(
451430
Z=self._Z2, U=self._U2, ref_point=new_ref_point.view(-1)
452431
)
453-
self.register_buffer("hypercell_bounds", cell_bounds)
432+
self.hypercell_bounds = cell_bounds
454433

455434
def _partition_space_2d(self) -> None:
456435
r"""Partition the non-dominated space into disjoint hypercells.
@@ -461,23 +440,9 @@ def _partition_space_2d(self) -> None:
461440
pareto_Y_sorted=self.pareto_Y.flip(-2),
462441
ref_point=self.ref_point,
463442
)
464-
self.register_buffer("hypercell_bounds", cell_bounds)
465-
466-
def compute_hypervolume(self) -> Tensor:
467-
r"""Compute hypervolume that is dominated by the Pareto Froniter.
443+
self.hypercell_bounds = cell_bounds
468444

469-
Returns:
470-
A `(batch_shape)`-dim tensor containing the hypervolume dominated by
471-
each Pareto frontier.
472-
"""
473-
if not hasattr(self, "_neg_pareto_Y"):
474-
return torch.tensor(0.0).to(self._neg_ref_point)
475-
if self._neg_pareto_Y.shape[-2] == 0:
476-
return torch.zeros(
477-
self._neg_pareto_Y.shape[:-2],
478-
dtype=self._neg_pareto_Y.dtype,
479-
device=self._neg_pareto_Y.device,
480-
)
445+
def _compute_hypervolume_if_y_has_data(self) -> Tensor:
481446
ideal_point = self.pareto_Y.max(dim=-2, keepdim=True).values
482447
total_volume = (
483448
(ideal_point.squeeze(-2) - self.ref_point).clamp_min(0.0).prod(dim=-1)

test/utils/multi_objective/box_decompositions/test_box_decomposition.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class DummyBoxDecomposition(BoxDecomposition):
3232
def _partition_space(self):
3333
pass
3434

35-
def compute_hypervolume(self):
35+
def _compute_hypervolume_if_y_has_data(self):
3636
pass
3737

3838
def get_hypercell_bounds(self):
@@ -66,7 +66,7 @@ def setUp(self):
6666
device=self.device,
6767
)
6868

69-
def test_box_decomposition(self):
69+
def test_box_decomposition(self) -> None:
7070
with self.assertRaises(TypeError):
7171
BoxDecomposition()
7272
for dtype, m, sort in product(
@@ -271,7 +271,7 @@ def test_fast_partitioning(self):
271271
DummyFastPartitioning(ref_point=ref_point, Y=Y.unsqueeze(0))
272272

273273

274-
class TestBoxDecomposition_Hypervolume(BotorchTestCase):
274+
class TestBoxDecomposition_no_set_up(BotorchTestCase):
275275
def helper_hypervolume(self, Box_Decomp_cls: type) -> None:
276276
"""
277277
This test should be run for each non-abstract subclass of `BoxDecomposition`.
@@ -292,7 +292,6 @@ def helper_hypervolume(self, Box_Decomp_cls: type) -> None:
292292

293293
box_decomp = Box_Decomp_cls(ref_point=ref_point, Y=Y)
294294
hv = box_decomp.compute_hypervolume()
295-
296295
self.assertEqual(hv.shape, ())
297296
self.assertTrue(torch.allclose(hv, torch.tensor(1.0)))
298297

@@ -316,3 +315,11 @@ def test_hypervolume(self) -> None:
316315
FastNondominatedPartitioning,
317316
]:
318317
self.helper_hypervolume(cl)
318+
319+
def test_uninitialized_y(self) -> None:
320+
ref_point = torch.zeros(2)
321+
box_decomp = NondominatedPartitioning(ref_point=ref_point)
322+
with self.assertRaises(BotorchError):
323+
box_decomp.Y
324+
with self.assertRaises(BotorchError):
325+
box_decomp._compute_pareto_Y()

0 commit comments

Comments
 (0)