
Commit 439c9ef

saitcakmak authored and facebook-github-bot committed
Modifies qMFKG.evaluate() to work with project, expand and cost aware utility (#594)
Summary:

## Motivation

Modifies `qMFKG.evaluate()` to work with `project`, `expand`, and `cost_aware_utility`. Partially fixes #587.

- Introduces a `ProjectedAcquisitionFunction` that wraps the `value_function` and applies the `project` operator on the `forward` call.
- Changes the `evaluate()` signature to use `X` instead of `X_actual`. The previous implementation raised an exception from the decorators when called with `evaluate(X_actual=...)`.

Note: The treatment of `cost_aware_utility` assumes that it is monotone non-decreasing in `deltas`. Otherwise, optimizing the inner problem and passing the result through `cost_aware_utility` may not produce the correct output.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/master/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #594

Test Plan: Added mock unit tests. Verified the expected behavior in additional offline tests.

Reviewed By: qingfeng10

Differential Revision: D25173613

Pulled By: Balandat

fbshipit-source-id: 3ba0f196a622a84c951fdc3526a53cb6905e85d2
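Below is a minimal sketch (not part of the commit) of how the updated `evaluate()` can be called once `project`, `expand`, and `cost_aware_utility` are wired through. The toy GP, the convention that the last input column is the fidelity, and the cost-model parameters are illustrative assumptions, not something prescribed by this change.

```python
import torch
from botorch.acquisition import qMultiFidelityKnowledgeGradient
from botorch.acquisition.cost_aware import InverseCostWeightedUtility
from botorch.models import SingleTaskGP
from botorch.models.cost import AffineFidelityCostModel

# toy 2-d problem in which the last column plays the role of the fidelity
train_X = torch.rand(20, 2)
train_Y = train_X.sum(dim=-1, keepdim=True).sin()
model = SingleTaskGP(train_X, train_Y)

# affine cost in the fidelity column, wrapped in the inverse-cost-weighted utility
cost_model = AffineFidelityCostModel(fidelity_weights={1: 1.0}, fixed_cost=5.0)
cost_aware_utility = InverseCostWeightedUtility(cost_model=cost_model)

def project_to_target_fidelity(X: torch.Tensor) -> torch.Tensor:
    # pin the fidelity column to its target value of 1.0
    X_proj = X.clone()
    X_proj[..., -1] = 1.0
    return X_proj

qmfkg = qMultiFidelityKnowledgeGradient(
    model=model,
    num_fantasies=4,
    cost_aware_utility=cost_aware_utility,
    project=project_to_target_fidelity,
)

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
X_eval = torch.rand(2, 1, 2)  # 2 t-batches of q=1 candidate points
# evaluate() now takes `X` (rather than `X_actual`) and internally applies
# expand/project and the cost-aware utility when they are configured
values = qmfkg.evaluate(X=X_eval, bounds=bounds, num_restarts=2, raw_samples=16)
print(values.shape)  # torch.Size([2])
```

The averaging over fantasy samples happens inside `evaluate()`, so the returned tensor holds one cost-weighted KG value per t-batch.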
1 parent 519b18b commit 439c9ef

File tree

4 files changed: +153, -15 lines

botorch/acquisition/__init__.py

Lines changed: 5 additions & 1 deletion

@@ -23,7 +23,10 @@
     InverseCostWeightedUtility,
 )
 from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
-from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
+from botorch.acquisition.knowledge_gradient import (
+    qKnowledgeGradient,
+    qMultiFidelityKnowledgeGradient,
+)
 from botorch.acquisition.max_value_entropy_search import (
     qMaxValueEntropy,
     qMultiFidelityMaxValueEntropy,
@@ -62,6 +65,7 @@
     "UpperConfidenceBound",
     "qExpectedImprovement",
     "qKnowledgeGradient",
+    "qMultiFidelityKnowledgeGradient",
     "qMaxValueEntropy",
     "qMultiFidelityMaxValueEntropy",
     "qNoisyExpectedImprovement",

botorch/acquisition/knowledge_gradient.py

Lines changed: 54 additions & 10 deletions

@@ -189,12 +189,12 @@ def forward(self, X: Tensor) -> Tensor:

     @concatenate_pending_points
     @t_batch_mode_transform()
-    def evaluate(self, X_actual: Tensor, bounds: Tensor, **kwargs: Any) -> Tensor:
+    def evaluate(self, X: Tensor, bounds: Tensor, **kwargs: Any) -> Tensor:
         r"""Evaluate qKnowledgeGradient on the candidate set `X_actual` by
         solving the inner optimization problem.

         Args:
-            X_actual: A `b x q x d` Tensor with `b` t-batches of `q` design points
+            X: A `b x q x d` Tensor with `b` t-batches of `q` design points
                 each. Unlike `forward()`, this does not include solutions of the
                 inner optimization problem.
             bounds: A `2 x d` tensor of lower and upper bounds for each column of
@@ -206,18 +206,24 @@ def evaluate(self, X_actual: Tensor, bounds: Tensor, **kwargs: Any) -> Tensor:

         Returns:
             A Tensor of shape `b`. For t-batch b, the q-KG value of the design
-                `X_actual[b]` is averaged across the fantasy models.
+                `X[b]` is averaged across the fantasy models.
                 NOTE: If `current_value` is not provided, then this is not the
-                true KG value of `X_actual[b]`.
+                true KG value of `X[b]`.
         """
+        if hasattr(self, "expand"):
+            X = self.expand(X)
+
         # construct the fantasy model of shape `num_fantasies x b`
         fantasy_model = self.model.fantasize(
-            X=X_actual, sampler=self.sampler, observation_noise=True
+            X=X, sampler=self.sampler, observation_noise=True
         )

         # get the value function
         value_function = _get_value_function(
-            model=fantasy_model, objective=self.objective, sampler=self.inner_sampler
+            model=fantasy_model,
+            objective=self.objective,
+            sampler=self.inner_sampler,
+            project=getattr(self, "project", None),
         )

         from botorch.generation.gen import gen_candidates_scipy
@@ -246,6 +252,10 @@ def evaluate(self, X_actual: Tensor, bounds: Tensor, **kwargs: Any) -> Tensor:
         if self.current_value is not None:
             values = values - self.current_value

+        if hasattr(self, "cost_aware_utility"):
+            values = self.cost_aware_utility(
+                X=X, deltas=values, sampler=self.cost_sampler
+            )
         # return average over the fantasy samples
         return values.mean(dim=0)

@@ -409,13 +419,16 @@ def forward(self, X: Tensor) -> Tensor:

         # get the value function
         value_function = _get_value_function(
-            model=fantasy_model, objective=self.objective, sampler=self.inner_sampler
+            model=fantasy_model,
+            objective=self.objective,
+            sampler=self.inner_sampler,
+            project=self.project,
         )

         # make sure to propagate gradients to the fantasy model train inputs
         # project the fantasy points
         with settings.propagate_grads(True):
-            values = value_function(X=self.project(X_fantasies))  # num_fantasies x b
+            values = value_function(X=X_fantasies)  # num_fantasies x b

         if self.current_value is not None:
             values = values - self.current_value
@@ -429,16 +442,47 @@ def forward(self, X: Tensor) -> Tensor:
         return values.mean(dim=0)


+class ProjectedAcquisitionFunction(AcquisitionFunction):
+    r"""
+    Defines a wrapper around an `AcquisitionFunction` that incorporates the project
+    operator. Typically used to handle value functions in look-ahead methods.
+    """
+
+    def __init__(
+        self,
+        base_value_function: AcquisitionFunction,
+        project: Callable[[Tensor], Tensor],
+    ) -> None:
+        super().__init__(base_value_function.model)
+        self.base_value_function = base_value_function
+        self.project = project
+        self.objective = base_value_function.objective
+        self.sampler = getattr(base_value_function, "sampler", None)
+
+    def forward(self, X: Tensor) -> Tensor:
+        return self.base_value_function(self.project(X))
+
+
 def _get_value_function(
     model: Model,
     objective: Optional[Union[MCAcquisitionObjective, ScalarizedObjective]] = None,
     sampler: Optional[MCSampler] = None,
+    project: Optional[Callable[[Tensor], Tensor]] = None,
 ) -> AcquisitionFunction:
     r"""Construct value function (i.e. inner acquisition function)."""
     if isinstance(objective, MCAcquisitionObjective):
-        return qSimpleRegret(model=model, sampler=sampler, objective=objective)
+        base_value_function = qSimpleRegret(
+            model=model, sampler=sampler, objective=objective
+        )
     else:
-        return PosteriorMean(model=model, objective=objective)
+        base_value_function = PosteriorMean(model=model, objective=objective)
+    if project is None:
+        return base_value_function
+    else:
+        return ProjectedAcquisitionFunction(
+            base_value_function=base_value_function,
+            project=project,
+        )


 def _split_fantasy_points(X: Tensor, n_f: int) -> Tuple[Tensor, Tensor]:
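As an aside, here is a small sketch (not from the diff) of what the new plumbing does: given a `project` callable, `_get_value_function` wraps the base value function in a `ProjectedAcquisitionFunction`, whose `forward` applies the projection before evaluating. The toy GP and the `pin_fidelity` projection below are assumptions for illustration only.

```python
import torch
from botorch.acquisition.knowledge_gradient import (
    ProjectedAcquisitionFunction,
    _get_value_function,
)
from botorch.models import SingleTaskGP

# toy single-output GP; the last input column plays the role of a fidelity
train_X = torch.rand(10, 2)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

def pin_fidelity(X: torch.Tensor) -> torch.Tensor:
    # hypothetical projection: evaluate everything at the target fidelity 1.0
    X_proj = X.clone()
    X_proj[..., -1] = 1.0
    return X_proj

# without `project` this returns a plain PosteriorMean; with `project` the
# base value function is wrapped in a ProjectedAcquisitionFunction
vf = _get_value_function(model=model, project=pin_fidelity)
assert isinstance(vf, ProjectedAcquisitionFunction)

X = torch.rand(4, 1, 2)  # 4 t-batches of q=1 points
# forward() applies the projection before calling the base value function,
# i.e. vf(X) equals vf.base_value_function(pin_fidelity(X))
print(vf(X).shape)  # torch.Size([4])
```

Because the wrapper copies `objective` and exposes `sampler` from the base value function, downstream code such as the initializers can treat it like any other value function.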

botorch/optim/initializers.py

Lines changed: 3 additions & 4 deletions

@@ -16,7 +16,6 @@
     _get_value_function,
     qKnowledgeGradient,
 )
-from botorch.acquisition.monte_carlo import MCAcquisitionFunction
 from botorch.acquisition.utils import is_nonnegative
 from botorch.exceptions.warnings import BadInitialCandidatesWarning, SamplingWarning
 from botorch.models.model import Model
@@ -210,6 +209,7 @@ def gen_one_shot_kg_initial_conditions(
         model=acq_function.model,
         objective=acq_function.objective,
         sampler=acq_function.inner_sampler,
+        project=getattr(acq_function, "project", None),
     )
     from botorch.optim.optimize import optimize_acqf

@@ -304,9 +304,8 @@ def gen_value_function_initial_conditions(
     value_function = _get_value_function(
         model=current_model,
         objective=acq_function.objective,
-        sampler=acq_function.sampler
-        if isinstance(acq_function, MCAcquisitionFunction)
-        else None,
+        sampler=getattr(acq_function, "sampler", None),
+        project=getattr(acq_function, "project", None),
     )
     from botorch.optim.optimize import optimize_acqf
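For completeness, a hedged sketch of the initializer path this file touches: `gen_one_shot_kg_initial_conditions` now forwards `project` to the inner value function. The untrained toy GP and the identity projection below are placeholders, not part of the change.

```python
import torch
from botorch.acquisition import qMultiFidelityKnowledgeGradient
from botorch.models import SingleTaskGP
from botorch.optim.initializers import gen_one_shot_kg_initial_conditions

# toy model; an untrained GP is enough to exercise the code path
model = SingleTaskGP(torch.rand(8, 2), torch.rand(8, 1))
qmfkg = qMultiFidelityKnowledgeGradient(
    model=model,
    num_fantasies=4,
    project=lambda X: X,  # identity projection, purely for illustration
)

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
# the initializer now threads `project` through to the inner value function
ics = gen_one_shot_kg_initial_conditions(
    acq_function=qmfkg,
    bounds=bounds,
    q=1,
    num_restarts=2,
    raw_samples=16,
)
print(ics.shape)  # torch.Size([2, 5, 2]): num_restarts x (q + num_fantasies) x d
```

The `getattr(acq_function, "project", None)` pattern keeps the plain `qKnowledgeGradient` path unchanged, since that class has no `project` attribute.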

test/acquisition/test_knowledge_gradient.py

Lines changed: 91 additions & 0 deletions

@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+from contextlib import ExitStack
 from unittest import mock

 import torch
@@ -14,6 +15,7 @@
     _split_fantasy_points,
     qKnowledgeGradient,
     qMultiFidelityKnowledgeGradient,
+    ProjectedAcquisitionFunction,
 )
 from botorch.acquisition.monte_carlo import qSimpleRegret
 from botorch.acquisition.objective import GenericMCObjective, ScalarizedObjective
@@ -399,6 +401,72 @@ def test_evaluate_q_multi_fidelity_knowledge_gradient(self):
         self.assertTrue(torch.allclose(val, val_exp, atol=1e-4))
         self.assertTrue(torch.equal(qMFKG.extract_candidates(X), X[..., :-n_f, :]))

+    def test_evaluate_qMFKG(self):
+        # mock test qMFKG.evaluate() with expand, project & cost aware utility
+        for dtype in (torch.float, torch.double):
+            mean = torch.zeros(1, 1, 1, device=self.device, dtype=dtype)
+            mm = MockModel(MockPosterior(mean=mean))
+            mm._input_batch_shape = torch.Size([1])
+            cau = GenericCostAwareUtility(mock_util)
+            n_f = 4
+            mean = torch.rand(n_f, 2, 1, 1, device=self.device, dtype=dtype)
+            variance = torch.rand(n_f, 2, 1, 1, device=self.device, dtype=dtype)
+            mfm = MockModel(MockPosterior(mean=mean, variance=variance))
+            mfm._input_batch_shape = torch.Size([n_f, 2])
+            with ExitStack() as es:
+                patch_f = es.enter_context(
+                    mock.patch.object(MockModel, "fantasize", return_value=mfm)
+                )
+                mock_num_outputs = es.enter_context(
+                    mock.patch(NO, new_callable=mock.PropertyMock)
+                )
+                es.enter_context(
+                    mock.patch(
+                        "botorch.optim.optimize.optimize_acqf",
+                        return_value=(
+                            torch.ones(1, 1, 1, device=self.device, dtype=dtype),
+                            torch.ones(1, device=self.device, dtype=dtype),
+                        ),
+                    ),
+                )
+                es.enter_context(
+                    mock.patch(
+                        "botorch.generation.gen.gen_candidates_scipy",
+                        return_value=(
+                            torch.ones(1, 1, 1, device=self.device, dtype=dtype),
+                            torch.ones(1, device=self.device, dtype=dtype),
+                        ),
+                    ),
+                )

+                mock_num_outputs.return_value = 1
+                qMFKG = qMultiFidelityKnowledgeGradient(
+                    model=mm,
+                    num_fantasies=n_f,
+                    X_pending=torch.rand(1, 1, 1, device=self.device, dtype=dtype),
+                    current_value=torch.zeros(1, device=self.device, dtype=dtype),
+                    cost_aware_utility=cau,
+                    project=lambda X: torch.zeros_like(X),
+                    expand=lambda X: torch.ones_like(X),
+                )
+                val = qMFKG.evaluate(
+                    X=torch.zeros(1, 1, 1, device=self.device, dtype=dtype),
+                    bounds=torch.tensor([[0.0], [1.0]]),
+                    num_restarts=1,
+                    raw_samples=1,
+                )
+                patch_f.assert_called_once()
+                cargs, ckwargs = patch_f.call_args
+                self.assertTrue(
+                    torch.equal(
+                        ckwargs["X"],
+                        torch.ones(1, 2, 1, device=self.device, dtype=dtype),
+                    )
+                )
+                self.assertEqual(
+                    val, cau(None, torch.ones(1, device=self.device, dtype=dtype))
+                )
+

 class TestKGUtils(BotorchTestCase):
     def test_get_value_function(self):
@@ -416,6 +484,29 @@ def test_get_value_function(self):
         self.assertIsInstance(vf, qSimpleRegret)
         self.assertEqual(vf.objective, obj)
         self.assertEqual(vf.sampler, sampler)
+        # test with project
+        mock_project = mock.Mock(
+            return_value=torch.ones(1, 1, 1, device=self.device)
+        )
+        vf = _get_value_function(
+            model=mm,
+            objective=obj,
+            sampler=sampler,
+            project=mock_project,
+        )
+        self.assertIsInstance(vf, ProjectedAcquisitionFunction)
+        self.assertEqual(vf.objective, obj)
+        self.assertEqual(vf.sampler, sampler)
+        self.assertEqual(vf.project, mock_project)
+        test_X = torch.rand(1, 1, 1, device=self.device)
+        with mock.patch.object(
+            vf, "base_value_function", __class__=torch.nn.Module, return_value=None
+        ) as patch_bvf:
+            vf(test_X)
+            mock_project.assert_called_once_with(test_X)
+            patch_bvf.assert_called_once_with(
+                torch.ones(1, 1, 1, device=self.device)
+            )

     def test_split_fantasy_points(self):
         for dtype in (torch.float, torch.double):
