Skip to content

Commit 8feb6c8

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Allow passing Y_samples directly in MARS.set_baseline_Y (#1364)
Summary: Pull Request resolved: #1364 We currently sample from the model to generate `Y` to compute the `MVaR` of, and set the Pareto subset of MVaR as the `baseline_Y`. This adds an option to skip the sampling bit and provide the `Y` directly. This saves us compute in cases where we may have access to pre-computed samples and may not want to re-sample. Reviewed By: esantorella Differential Revision: D38950165 fbshipit-source-id: 8eb24b086a03ad0035e4018f54975d057eaa1c95
1 parent 2dda2e6 commit 8feb6c8

File tree

2 files changed

+35
-10
lines changed

2 files changed

+35
-10
lines changed

botorch/acquisition/multi_objective/multi_output_risk_measures.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from botorch.acquisition.risk_measures import CVaR, RiskMeasureMCObjective, VaR
4040
from botorch.exceptions.errors import UnsupportedError
41+
from botorch.exceptions.warnings import BotorchWarning
4142
from botorch.models.model import Model
4243
from botorch.utils.multi_objective.pareto import is_non_dominated
4344
from botorch.utils.transforms import normalize
@@ -568,17 +569,34 @@ def __init__(
568569
self.mvar = MVaR(n_w=self.n_w, alpha=self.alpha)
569570
self._chebyshev_objective = None
570571

571-
def set_baseline_Y(self, model: Model, X_baseline: Tensor) -> None:
572+
def set_baseline_Y(
573+
self,
574+
model: Optional[Model],
575+
X_baseline: Optional[Tensor],
576+
Y_samples: Optional[Tensor] = None,
577+
) -> None:
572578
r"""Set the `baseline_Y` based on the MVaR predictions of the `model`
573579
for `X_baseline`.
574580
575581
Args:
576582
model: The model being used for MARS optimization. Must have a compatible
577-
`InputPerturbation` transform attached.
583+
`InputPerturbation` transform attached. Ignored if `Y_samples` is given.
578584
X_baseline: An `n x d`-dim tensor of previously evaluated points.
585+
Ignored if `Y_samples` is given.
586+
Y_samples: An optional `(n * n_w) x d`-dim tensor of predictions. If given,
587+
instead of sampling from the model, these are used.
579588
"""
580-
with torch.no_grad():
581-
Y = model.posterior(X_baseline.unsqueeze(-2)).mean.squeeze(-2)
589+
if Y_samples is None:
590+
with torch.no_grad():
591+
Y = model.posterior(X_baseline.unsqueeze(-2)).mean.squeeze(-2)
592+
else:
593+
if model is not None or X_baseline is not None:
594+
warnings.warn(
595+
"`model` and `X_baseline` are ignored when `Y_samples` is "
596+
"provided to `MARS.set_baseline_Y`.",
597+
BotorchWarning,
598+
)
599+
Y = Y_samples
582600
Y = self.preprocessing_function(Y)
583601
Y = self.mvar(Y).view(-1, Y.shape[-1])
584602
Y = Y[is_non_dominated(Y)]

test/acquisition/multi_objective/test_multi_output_risk_measures.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
# This source code is licensed under the MIT license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import warnings
78
from typing import Optional
89

910
import torch
11+
from botorch import settings
1012
from botorch.acquisition.multi_objective.multi_output_risk_measures import (
1113
IndependentCVaR,
1214
IndependentVaR,
@@ -18,6 +20,7 @@
1820
)
1921
from botorch.acquisition.multi_objective.objective import IdentityMCMultiOutputObjective
2022
from botorch.exceptions.errors import UnsupportedError
23+
from botorch.exceptions.warnings import BotorchWarning
2124
from botorch.models.deterministic import GenericDeterministicModel
2225
from botorch.models.transforms.input import InputPerturbation
2326
from botorch.utils.multi_objective.pareto import is_non_dominated
@@ -475,20 +478,24 @@ def test_set_baseline_Y(self):
475478
)
476479
model = GenericDeterministicModel(f=lambda X: X, num_outputs=2)
477480
model.input_transform = perturbation
478-
mars.set_baseline_Y(
479-
model=model, X_baseline=torch.tensor([[0.0, 0.0], [1.0, 1.0]])
480-
)
481+
X_baseline = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
482+
mars.set_baseline_Y(model=model, X_baseline=X_baseline)
483+
self.assertTrue(torch.equal(mars.baseline_Y, torch.tensor([[1.5, 1.5]])))
484+
# With Y_samples.
485+
mars._baseline_Y = None
486+
Y_samples = model.posterior(X_baseline).mean
487+
with warnings.catch_warnings(record=True) as ws, settings.debug(True):
488+
mars.set_baseline_Y(model=model, X_baseline=X_baseline, Y_samples=Y_samples)
481489
self.assertTrue(torch.equal(mars.baseline_Y, torch.tensor([[1.5, 1.5]])))
490+
self.assertTrue(any(w.category == BotorchWarning for w in ws))
482491
# With pre-processing function.
483492
mars = MARS(
484493
alpha=0.5,
485494
n_w=3,
486495
chebyshev_weights=[0.5, 0.5],
487496
preprocessing_function=lambda Y: -Y,
488497
)
489-
mars.set_baseline_Y(
490-
model=model, X_baseline=torch.tensor([[0.0, 0.0], [1.0, 1.0]])
491-
)
498+
mars.set_baseline_Y(model=model, X_baseline=X_baseline)
492499
self.assertTrue(torch.equal(mars.baseline_Y, torch.tensor([[-0.5, -0.5]])))
493500

494501
def test_get_Y_normalization_bounds(self):

0 commit comments

Comments
 (0)