Skip to content

Commit e07f5bd

Browse files
Daniel Jiangfacebook-github-bot
authored andcommitted
mves sampled costs (#352)
Summary: Pull Request resolved: #352 make sure MVES can support sampled costs like KG (ports over the logic from KG) Reviewed By: Balandat Differential Revision: D19203233 fbshipit-source-id: 23bd742753fd78f8539409b18c9b74d4b83c6402
1 parent 86b90f0 commit e07f5bd

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

botorch/acquisition/max_value_entropy_search.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
arXiv:1901.08275v1, 2019
2222
"""
2323

24+
from copy import deepcopy
2425
from math import log
2526
from typing import Callable, Optional
2627

@@ -297,7 +298,6 @@ def _compute_information_gain(
297298
H1_hat = H1_bar - beta * (H0_bar - H0)
298299
ig = H0 - H1_hat # batch_shape x num_fantasies
299300
ig = ig.permute(-1, *range(ig.dim() - 1)) # num_fantasies x batch_shape
300-
301301
return ig
302302

303303

@@ -382,13 +382,29 @@ def __init__(
382382
self.cost_aware_utility = cost_aware_utility
383383
self.expand = expand
384384
self.project = project
385+
self._cost_sampler = None
386+
385387
# @TODO make sure fidelity_dims align in project, expand & cost_aware_utility
386388
# It seems very difficult due to the current way of handling project/expand
387389

388390
# resample max values after initializing self.project
389391
# so that the max value samples are at the highest fidelity
390392
self._sample_max_values()
391393

394+
@property
395+
def cost_sampler(self):
396+
if self._cost_sampler is None:
397+
# Note: Using the deepcopy here is essential. Removing this poses a
398+
# problem if the base model and the cost model have a different number
399+
# of outputs or test points (this would be caused by expand), as this
400+
# would trigger re-sampling the base samples in the fantasy sampler.
401+
# By cloning the sampler here, the right thing will happen if the
402+
# the sizes are compatible, if they are not this will result in
403+
# samples being drawn using different base samples, but it will at
404+
# least avoid changing state of the fantasy sampler.
405+
self._cost_sampler = deepcopy(self.fantasies_sampler)
406+
return self._cost_sampler
407+
392408
@t_batch_mode_transform(expected_q=1)
393409
def forward(self, X: Tensor) -> Tensor:
394410
r"""Evaluates `qMultifidelityMaxValueEntropy` at the design points `X`
@@ -422,8 +438,8 @@ def forward(self, X: Tensor) -> Tensor:
422438
ig = self._compute_information_gain(
423439
X=X_expand, mean_M=mean_M, variance_M=variance_M, covar_mM=covar_mM
424440
)
425-
426-
return self.cost_aware_utility(X, ig).mean(dim=0) # average over the fantasies
441+
ig = self.cost_aware_utility(X=X, deltas=ig, sampler=self.cost_sampler)
442+
return ig.mean(dim=0) # average over the fantasies
427443

428444

429445
def _sample_max_value_Thompson(

test/acquisition/test_max_value_entropy_search.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from unittest import mock
99

1010
import torch
11+
from botorch.acquisition.cost_aware import InverseCostWeightedUtility
1112
from botorch.acquisition.max_value_entropy_search import (
1213
_sample_max_value_Gumbel,
1314
_sample_max_value_Thompson,
@@ -124,13 +125,17 @@ def test_q_multi_fidelity_max_value_entropy(self):
124125
self.assertEqual(qMF_MVE.num_fantasies, 16)
125126
self.assertEqual(qMF_MVE.num_mv_samples, 10)
126127
self.assertIsInstance(qMF_MVE.sampler, SobolQMCNormalSampler)
128+
self.assertIsInstance(qMF_MVE.cost_sampler, SobolQMCNormalSampler)
127129
self.assertEqual(qMF_MVE.sampler.sample_shape, torch.Size([128]))
128130
self.assertIsInstance(qMF_MVE.fantasies_sampler, SobolQMCNormalSampler)
129131
self.assertEqual(qMF_MVE.fantasies_sampler.sample_shape, torch.Size([16]))
130132
self.assertIsInstance(qMF_MVE.expand, Callable)
131133
self.assertIsInstance(qMF_MVE.project, Callable)
132134
self.assertIsNone(qMF_MVE.X_pending)
133135
self.assertEqual(qMF_MVE.posterior_max_values.shape, torch.Size([10, 1]))
136+
self.assertIsInstance(
137+
qMF_MVE.cost_aware_utility, InverseCostWeightedUtility
138+
)
134139

135140
# test evaluation
136141
X = torch.rand(1, 2, device=self.device, dtype=dtype)

0 commit comments

Comments
 (0)