Skip to content

Commit 91b45b7

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Remove UnstandardizeMCMultiOutputObjective and UnstandardizePosteriorTransform (#2362)
Summary: Pull Request resolved: #2362 There is no reason to use these instead of the `Standardize` outcome transform. We could add a deprecation warning and remove it at a later release, but it is hidden deep in the code base that I think it's ok to remove it outright. I couldn't find any usage anywhere. Reviewed By: Balandat Differential Revision: D58164130 fbshipit-source-id: 02be841ca0c4f7bf8c701a5e7beadaa68ab582be
1 parent 8248b78 commit 91b45b7

File tree

5 files changed

+10
-178
lines changed

5 files changed

+10
-178
lines changed

botorch/acquisition/multi_objective/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from botorch.acquisition.multi_objective.objective import (
2525
IdentityMCMultiOutputObjective,
2626
MCMultiOutputObjective,
27-
UnstandardizeMCMultiOutputObjective,
2827
WeightedMCMultiOutputObjective,
2928
)
3029
from botorch.acquisition.multi_objective.utils import (
@@ -47,6 +46,5 @@
4746
"MCMultiOutputObjective",
4847
"MultiObjectiveAnalyticAcquisitionFunction",
4948
"MultiObjectiveMCAcquisitionFunction",
50-
"UnstandardizeMCMultiOutputObjective",
5149
"WeightedMCMultiOutputObjective",
5250
]

botorch/acquisition/multi_objective/objective.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -207,49 +207,3 @@ def apply_feasibility_weights(
207207

208208
def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
209209
return self.objective(self.apply_feasibility_weights(samples), X=X)
210-
211-
212-
class UnstandardizeMCMultiOutputObjective(IdentityMCMultiOutputObjective):
213-
r"""Objective that unstandardizes the samples.
214-
215-
TODO: remove this when MultiTask models support outcome transforms.
216-
217-
Example:
218-
>>> unstd_objective = UnstandardizeMCMultiOutputObjective(Y_mean, Y_std)
219-
>>> samples = sampler(posterior)
220-
>>> objective = unstd_objective(samples)
221-
"""
222-
223-
def __init__(
224-
self, Y_mean: Tensor, Y_std: Tensor, outcomes: Optional[List[int]] = None
225-
) -> None:
226-
r"""Initialize objective.
227-
228-
Args:
229-
Y_mean: `m`-dim tensor of outcome means.
230-
Y_std: `m`-dim tensor of outcome standard deviations.
231-
outcomes: A list of `m' <= m` indices that specifies which of the `m` model
232-
outputs should be considered as the outcomes for MOO. If omitted, use
233-
all model outcomes. Typically used for constrained optimization.
234-
"""
235-
if Y_mean.ndim > 1 or Y_std.ndim > 1:
236-
raise BotorchTensorDimensionError(
237-
"Y_mean and Y_std must both be 1-dimensional, but got "
238-
f"{Y_mean.ndim} and {Y_std.ndim}"
239-
)
240-
elif outcomes is not None and len(outcomes) > Y_mean.shape[-1]:
241-
raise BotorchTensorDimensionError(
242-
f"Cannot specify more ({len(outcomes)}) outcomes than are present in "
243-
f"the normalization inputs ({Y_mean.shape[-1]})."
244-
)
245-
super().__init__(outcomes=outcomes, num_outcomes=Y_mean.shape[-1])
246-
if outcomes is not None:
247-
Y_mean = Y_mean.index_select(-1, self.outcomes.to(Y_mean.device))
248-
Y_std = Y_std.index_select(-1, self.outcomes.to(Y_mean.device))
249-
250-
self.register_buffer("Y_mean", Y_mean)
251-
self.register_buffer("Y_std", Y_std)
252-
253-
def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
254-
samples = super().forward(samples=samples)
255-
return samples * self.Y_std + self.Y_mean

botorch/acquisition/objective.py

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@
1313
from typing import Callable, List, Optional, TYPE_CHECKING, Union
1414

1515
import torch
16-
from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError
16+
from botorch.exceptions.errors import UnsupportedError
1717
from botorch.exceptions.warnings import InputDataWarning
1818
from botorch.models.model import Model
19-
from botorch.models.transforms.outcome import Standardize
2019
from botorch.posteriors.gpytorch import GPyTorchPosterior, scalarize_posterior
2120
from botorch.sampling import IIDNormalSampler
2221
from botorch.utils import apply_constraints
@@ -234,45 +233,6 @@ def forward(self, posterior: GPyTorchPosterior) -> GPyTorchPosterior:
234233
return GPyTorchPosterior(distribution=new_mvn)
235234

236235

237-
class UnstandardizePosteriorTransform(PosteriorTransform):
238-
r"""Posterior transform that unstandardizes the posterior.
239-
240-
TODO: remove this when MultiTask models support outcome transforms.
241-
242-
Example:
243-
>>> unstd_transform = UnstandardizePosteriorTransform(Y_mean, Y_std)
244-
>>> unstd_posterior = unstd_transform(posterior)
245-
"""
246-
247-
def __init__(self, Y_mean: Tensor, Y_std: Tensor) -> None:
248-
r"""Initialize objective.
249-
250-
Args:
251-
Y_mean: `m`-dim tensor of outcome means
252-
Y_std: `m`-dim tensor of outcome standard deviations
253-
254-
"""
255-
if Y_mean.ndim > 1 or Y_std.ndim > 1:
256-
raise BotorchTensorDimensionError(
257-
"Y_mean and Y_std must both be 1-dimensional, but got "
258-
f"{Y_mean.ndim} and {Y_std.ndim}"
259-
)
260-
super().__init__()
261-
self.outcome_transform = Standardize(m=Y_mean.shape[0]).to(Y_mean)
262-
Y_std_unsqueezed = Y_std.unsqueeze(0)
263-
self.outcome_transform.means = Y_mean.unsqueeze(0)
264-
self.outcome_transform.stdvs = Y_std_unsqueezed
265-
self.outcome_transform._stdvs_sq = Y_std_unsqueezed.pow(2)
266-
self.outcome_transform._is_trained = torch.tensor(True)
267-
self.outcome_transform.eval()
268-
269-
def evaluate(self, Y: Tensor) -> Tensor:
270-
return self.outcome_transform.untransform(Y)[0]
271-
272-
def forward(self, posterior: GPyTorchPosterior) -> Tensor:
273-
return self.outcome_transform.untransform_posterior(posterior)
274-
275-
276236
class MCAcquisitionObjective(Module, ABC):
277237
r"""Abstract base class for MC-based objectives.
278238

test/acquisition/multi_objective/test_objective.py

Lines changed: 5 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,10 @@
1414
FeasibilityWeightedMCMultiOutputObjective,
1515
IdentityMCMultiOutputObjective,
1616
MCMultiOutputObjective,
17-
UnstandardizeMCMultiOutputObjective,
1817
WeightedMCMultiOutputObjective,
1918
)
20-
from botorch.acquisition.objective import (
21-
IdentityMCObjective,
22-
UnstandardizePosteriorTransform,
23-
)
19+
from botorch.acquisition.objective import IdentityMCObjective
2420
from botorch.exceptions.errors import BotorchError, BotorchTensorDimensionError
25-
from botorch.models.transforms.outcome import Standardize
2621
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
2722

2823

@@ -37,14 +32,17 @@ def test_identity_mc_multi_output_objective(self):
3732
objective = IdentityMCMultiOutputObjective()
3833
with self.assertRaises(BotorchTensorDimensionError):
3934
IdentityMCMultiOutputObjective(outcomes=[0])
40-
# test negative outcome without specifying num_outcomes
35+
# Test negative outcome without specifying num_outcomes.
4136
with self.assertRaises(BotorchError):
4237
IdentityMCMultiOutputObjective(outcomes=[0, -1])
4338
for batch_shape, m, dtype in itertools.product(
4439
([], [3]), (2, 3), (torch.float, torch.double)
4540
):
4641
samples = torch.rand(*batch_shape, 2, m, device=self.device, dtype=dtype)
4742
self.assertTrue(torch.equal(objective(samples), samples))
43+
# Test negative outcome with num_outcomes.
44+
objective = IdentityMCMultiOutputObjective(outcomes=[0, -1], num_outcomes=3)
45+
self.assertEqual(objective.outcomes.tolist(), [0, 2])
4846

4947

5048
class TestWeightedMCMultiOutputObjective(BotorchTestCase):
@@ -138,73 +136,3 @@ def test_feasibility_weighted_mc_multi_output_objective(self):
138136
X_baseline=X,
139137
constraint_idcs=[1, -1],
140138
)
141-
142-
143-
class TestUnstandardizeMultiOutputObjective(BotorchTestCase):
144-
def test_unstandardize_mo_objective(self):
145-
Y_mean = torch.ones(2)
146-
Y_std = torch.ones(2)
147-
with self.assertRaises(BotorchTensorDimensionError):
148-
UnstandardizeMCMultiOutputObjective(
149-
Y_mean=Y_mean, Y_std=Y_std, outcomes=[0, 1, 2]
150-
)
151-
for objective_class in (
152-
UnstandardizeMCMultiOutputObjective,
153-
UnstandardizePosteriorTransform,
154-
):
155-
with self.assertRaises(BotorchTensorDimensionError):
156-
objective_class(Y_mean=Y_mean.unsqueeze(0), Y_std=Y_std)
157-
with self.assertRaises(BotorchTensorDimensionError):
158-
objective_class(Y_mean=Y_mean, Y_std=Y_std.unsqueeze(0))
159-
objective = objective_class(Y_mean=Y_mean, Y_std=Y_std)
160-
for batch_shape, m, outcomes, dtype in itertools.product(
161-
([], [3]), (2, 3), (None, [-2, -1]), (torch.float, torch.double)
162-
):
163-
Y_mean = torch.rand(m, dtype=dtype, device=self.device)
164-
Y_std = torch.rand(m, dtype=dtype, device=self.device).clamp_min(1e-3)
165-
kwargs = {}
166-
if objective_class == UnstandardizeMCMultiOutputObjective:
167-
kwargs["outcomes"] = outcomes
168-
objective = objective_class(Y_mean=Y_mean, Y_std=Y_std, **kwargs)
169-
if objective_class == UnstandardizePosteriorTransform:
170-
objective = objective_class(Y_mean=Y_mean, Y_std=Y_std)
171-
if outcomes is None:
172-
# passing outcomes is not currently supported
173-
mean = torch.rand(2, m, dtype=dtype, device=self.device)
174-
variance = variance = torch.rand(
175-
2, m, dtype=dtype, device=self.device
176-
)
177-
mock_posterior = MockPosterior(mean=mean, variance=variance)
178-
tf_posterior = objective(mock_posterior)
179-
tf = Standardize(m=m)
180-
tf.means = Y_mean
181-
tf.stdvs = Y_std
182-
tf._stdvs_sq = Y_std.pow(2)
183-
tf._is_trained = torch.tensor(True)
184-
tf.eval()
185-
expected_posterior = tf.untransform_posterior(mock_posterior)
186-
self.assertTrue(
187-
torch.equal(tf_posterior.mean, expected_posterior.mean)
188-
)
189-
self.assertTrue(
190-
torch.equal(
191-
tf_posterior.variance, expected_posterior.variance
192-
)
193-
)
194-
# testing evaluate specifically
195-
if objective_class == UnstandardizePosteriorTransform:
196-
Y = torch.randn_like(Y_mean) + Y_mean
197-
val = objective.evaluate(Y)
198-
val_expected = Y_mean + Y * Y_std
199-
self.assertTrue(torch.allclose(val, val_expected))
200-
else:
201-
202-
samples = torch.rand(
203-
*batch_shape, 2, m, dtype=dtype, device=self.device
204-
)
205-
obj_expected = samples * Y_std.to(dtype=dtype) + Y_mean.to(
206-
dtype=dtype
207-
)
208-
if outcomes is not None:
209-
obj_expected = obj_expected[..., outcomes]
210-
self.assertTrue(torch.equal(objective(samples), obj_expected))

test/acquisition/multi_objective/test_utils.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111
import torch
1212
from botorch.acquisition.multi_objective.objective import (
13+
IdentityMCMultiOutputObjective,
1314
MCMultiOutputObjective,
14-
UnstandardizeMCMultiOutputObjective,
1515
)
1616
from botorch.acquisition.multi_objective.utils import (
1717
compute_sample_box_decomposition,
@@ -92,16 +92,8 @@ def test_prune_inferior_points_multi_objective(self):
9292
X_pruned = prune_inferior_points_multi_objective(
9393
model=mm, X=X, ref_point=ref_point
9494
)
95-
self.assertTrue(torch.equal(X_pruned, X[[-1]]))
96-
# test unstd objective
97-
unstd_obj = UnstandardizeMCMultiOutputObjective(
98-
Y_mean=samples.mean(dim=0), Y_std=samples.std(dim=0), outcomes=[0, 1]
99-
)
100-
X_pruned = prune_inferior_points_multi_objective(
101-
model=mm, X=X, ref_point=ref_point, objective=unstd_obj
102-
)
103-
self.assertTrue(torch.equal(X_pruned, X[[-1]]))
10495
# test constraints
96+
objective = IdentityMCMultiOutputObjective(outcomes=[0, 1])
10597
samples_constrained = torch.tensor(
10698
[[1.0, 2.0, -1.0], [2.0, 1.0, -1.0], [3.0, 4.0, 1.0]], **tkwargs
10799
)
@@ -110,7 +102,7 @@ def test_prune_inferior_points_multi_objective(self):
110102
model=mm_constrained,
111103
X=X,
112104
ref_point=ref_point,
113-
objective=unstd_obj,
105+
objective=objective,
114106
constraints=[lambda Y: Y[..., -1]],
115107
)
116108
self.assertTrue(torch.equal(X_pruned, X[:2]))
@@ -161,7 +153,7 @@ def test_prune_inferior_points_multi_objective(self):
161153
model=mm,
162154
X=X,
163155
ref_point=ref_point,
164-
objective=unstd_obj,
156+
objective=objective,
165157
constraints=[lambda Y: Y[..., -1] - 3.0],
166158
marginalize_dim=-3,
167159
)

0 commit comments

Comments
 (0)