
Commit 2d10d77

vishwakftw authored and facebook-github-bot committed
Support passing a list of Tensors and floats to FixedFeatureAcquisitionFunction. (#836)
Summary:
Pull Request resolved: #836

This diff introduces an API change. Previously, one would have to instantiate a FixedFeatureAcquisitionFunction (FFACQF) using 1) a full Tensor or 2) a list of floats. Here, we allow passing Tensors in the list as long as they are broadcastable.

Reviewed By: Balandat

Differential Revision: D28962037

fbshipit-source-id: 526031857f969f1a9348c83a9cf1bbb47eef90f2
1 parent 53074d8 · commit 2d10d77

2 files changed: 98 additions & 33 deletions

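Based on the diff and tests below, a minimal usage sketch of the new calling convention (the model setup mirrors test_fixed_feature.py; shapes and hyperparameters are illustrative only):

    import torch
    from botorch.acquisition import qExpectedImprovement
    from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
    from botorch.models import SingleTaskGP

    # Toy model, as in the test file below.
    train_X = torch.rand(5, 3)
    train_Y = train_X.norm(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y).eval()
    qEI = qExpectedImprovement(model, best_f=0.0)

    test_X = torch.rand(2, 3)
    # New in this commit: `values` may mix Tensors and floats; entries are
    # broadcast against each other and concatenated along the trailing dim.
    qEI_ff = FixedFeatureAcquisitionFunction(
        qEI, d=3, columns=[0, 2], values=[test_X[..., [0]], 0.5]
    )
    # Evaluate on the remaining free column only (column 1).
    val = qEI_ff(test_X[..., [1]])

As before, passing a full `d`-dimensional input to the fixed-feature wrapper raises a ValueError, which the test file exercises at the end.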

botorch/acquisition/fixed_feature.py

Lines changed: 38 additions & 6 deletions
@@ -11,7 +11,8 @@
 
 from __future__ import annotations
 
-from typing import List, Union
+from numbers import Number
+from typing import List, Sequence, Union
 
 import torch
 from botorch.acquisition.acquisition import AcquisitionFunction

@@ -36,7 +37,7 @@ def __init__(
         acq_function: AcquisitionFunction,
         d: int,
         columns: List[int],
-        values: Union[Tensor, List[float]],
+        values: Union[Tensor, Sequence[Union[Tensor, float]]],
     ) -> None:
         r"""Derived Acquisition Function by fixing a subset of input features.
 
@@ -51,16 +52,47 @@ def __init__(
                 different for each of the `q` input points), or an array-like of
                 values that is broadcastable to the input across `t`-batch and
                 `q`-batch dimensions, e.g. a list of length `d_f` if values
-                are the same across all `t` and `q`-batch dimensions.
+                are the same across all `t` and `q`-batch dimensions, or a
+                combination of `Tensor`s and numbers which can be broadcasted
+                to form a tensor with trailing dimension size of `d_f`.
         """
         Module.__init__(self)
         self.acq_func = acq_function
         self.d = d
-        values = torch.as_tensor(values).detach().clone()
-        self.register_buffer("values", values)
+        if isinstance(values, Tensor):
+            new_values = values.detach().clone()
+        else:
+            new_values = []
+            for value in values:
+                if isinstance(value, Number):
+                    new_values.append(torch.tensor([float(value)]))
+                else:
+                    new_values.append(value.detach().clone())
+
+            # There are 3 cases for when `values` is a `Sequence`.
+            # 1) `values` == list of floats as earlier.
+            # 2) `values` == combination of floats and `Tensor`s.
+            # 3) `values` == a list of `Tensor`s.
+            # For 1), the below step creates a vector of length `len(values)`.
+            # For 2), the below step creates a `Tensor` of shape `batch_shape x q x d_f`
+            # with the broadcasting functionality.
+            # For 3), this is simply a concatenation, yielding a `Tensor` with the
+            # same shape as in 2).
+            # The key difference arises when `_construct_X_full` is invoked.
+            # In 1), the expansion (`self.values.expand`) will expand the `Tensor` to
+            # size `batch_shape x q x d_f`.
+            # In 2) and 3), this expansion is a no-op because they are already of the
+            # required size. However, 2) and 3) _cannot_ support varying `batch_shape`,
+            # which means that all calls to `FixedFeatureAcquisitionFunction` have
+            # to have the same size throughout when `values` contains a `Tensor`.
+            # This is consistent with the scenario when a singular `Tensor` is passed
+            # as the `values` argument.
+            new_values = torch.cat(torch.broadcast_tensors(*new_values), dim=-1)
+
+        self.register_buffer("values", new_values)
         # build selector for _construct_X_full
         self._selector = []
-        idx_X, idx_f = 0, d - values.shape[-1]
+        idx_X, idx_f = 0, d - new_values.shape[-1]
         for i in range(self.d):
             if i in columns:
                 self._selector.append(idx_f)
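A standalone sketch of the `torch.cat(torch.broadcast_tensors(...))` step above, illustrating cases 1) and 2) from the inline comment (shapes here are hypothetical, not taken from the diff):

    import torch

    # Case 1): a list of floats, each wrapped as a 1-element tensor by
    # __init__, concatenates into a plain vector of length `d_f` (here 2).
    floats = [torch.tensor([0.5]), torch.tensor([0.25])]
    print(torch.cat(torch.broadcast_tensors(*floats), dim=-1).shape)
    # torch.Size([2])

    # Case 2): mixing a `batch_shape x q x 1` Tensor with a wrapped float;
    # the float is broadcast up before concatenation, yielding
    # `batch_shape x q x d_f`.
    t = torch.rand(4, 2, 1)
    f = torch.tensor([0.5])
    vals = torch.cat(torch.broadcast_tensors(t, f), dim=-1)
    print(vals.shape)  # torch.Size([4, 2, 2])

    # As the comment notes, a Tensor-valued result pins the batch shape:
    # a later expansion to the same size is a no-op, but a different
    # batch_shape cannot be formed.
    vals.expand(4, 2, 2)      # no-op, fine
    # vals.expand(5, 2, 2)    # would raise a RuntimeError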

test/acquisition/test_fixed_feature.py

Lines changed: 60 additions & 27 deletions
@@ -17,33 +17,52 @@ def test_fixed_features(self):
         train_Y = train_X.norm(dim=-1, keepdim=True)
         model = SingleTaskGP(train_X, train_Y).to(device=self.device).eval()
         qEI = qExpectedImprovement(model, best_f=0.0)
-        # test single point
-        test_X = torch.rand(1, 3, device=self.device)
-        qEI_ff = FixedFeatureAcquisitionFunction(
-            qEI, d=3, columns=[2], values=test_X[..., -1:]
-        )
-        qei = qEI(test_X)
-        qei_ff = qEI_ff(test_X[..., :-1])
-        self.assertTrue(torch.allclose(qei, qei_ff))
-        # test list input
-        qEI_ff = FixedFeatureAcquisitionFunction(qEI, d=3, columns=[2], values=[0.5])
-        qei_ff = qEI_ff(test_X[..., :-1])
-        # test q-batch
-        test_X = torch.rand(2, 3, device=self.device)
-        qEI_ff = FixedFeatureAcquisitionFunction(
-            qEI, d=3, columns=[1], values=test_X[..., [1]]
-        )
-        qei = qEI(test_X)
-        qei_ff = qEI_ff(test_X[..., [0, 2]])
-        self.assertTrue(torch.allclose(qei, qei_ff))
-        # test t-batch with broadcasting
-        test_X = torch.rand(2, 3, device=self.device).expand(4, 2, 3)
-        qEI_ff = FixedFeatureAcquisitionFunction(
-            qEI, d=3, columns=[2], values=test_X[0, :, -1:]
-        )
-        qei = qEI(test_X)
-        qei_ff = qEI_ff(test_X[..., :-1])
-        self.assertTrue(torch.allclose(qei, qei_ff))
+        for q in [1, 2]:
+            # test single point
+            test_X = torch.rand(q, 3, device=self.device)
+            qEI_ff = FixedFeatureAcquisitionFunction(
+                qEI, d=3, columns=[2], values=test_X[..., -1:]
+            )
+            qei = qEI(test_X)
+            qei_ff = qEI_ff(test_X[..., :-1])
+            self.assertTrue(torch.allclose(qei, qei_ff))
+
+            # test list input with float
+            qEI_ff = FixedFeatureAcquisitionFunction(
+                qEI, d=3, columns=[2], values=[0.5]
+            )
+            qei_ff = qEI_ff(test_X[..., :-1])
+            test_X_clone = test_X.clone()
+            test_X_clone[..., 2] = 0.5
+            qei = qEI(test_X_clone)
+            self.assertTrue(torch.allclose(qei, qei_ff))
+
+            # test list input with Tensor and float
+            qEI_ff = FixedFeatureAcquisitionFunction(
+                qEI, d=3, columns=[0, 2], values=[test_X[..., [0]], 0.5]
+            )
+            qei_ff = qEI_ff(test_X[..., [1]])
+            self.assertTrue(torch.allclose(qei, qei_ff))
+
+            # test t-batch with broadcasting and list of floats
+            test_X = torch.rand(q, 3, device=self.device).expand(4, q, 3)
+            qEI_ff = FixedFeatureAcquisitionFunction(
+                qEI, d=3, columns=[2], values=test_X[0, :, -1:]
+            )
+            qei = qEI(test_X)
+            qei_ff = qEI_ff(test_X[..., :-1])
+            self.assertTrue(torch.allclose(qei, qei_ff))
+
+            # test t-batch with broadcasting and list of floats and Tensor
+            qEI_ff = FixedFeatureAcquisitionFunction(
+                qEI, d=3, columns=[0, 2], values=[test_X[0, :, [0]], 0.5]
+            )
+            test_X_clone = test_X.clone()
+            test_X_clone[..., 2] = 0.5
+            qei = qEI(test_X_clone)
+            qei_ff = qEI_ff(test_X[..., [1]])
+            self.assertTrue(torch.allclose(qei, qei_ff))
+
         # test gradient
         test_X = torch.rand(1, 3, device=self.device, requires_grad=True)
         test_X_ff = test_X[..., :-1].detach().clone().requires_grad_(True)

@@ -56,6 +75,20 @@ def test_fixed_features(self):
         qei.backward()
         qei_ff.backward()
         self.assertTrue(torch.allclose(test_X.grad[..., :-1], test_X_ff.grad))
+
+        test_X = test_X.detach().clone()
+        test_X_ff = test_X[..., [1]].detach().clone().requires_grad_(True)
+        test_X[..., 2] = 0.5
+        test_X.requires_grad_(True)
+        qei = qEI(test_X)
+        qEI_ff = FixedFeatureAcquisitionFunction(
+            qEI, d=3, columns=[0, 2], values=[test_X[..., [0]].detach(), 0.5]
+        )
+        qei_ff = qEI_ff(test_X_ff)
+        qei.backward()
+        qei_ff.backward()
+        self.assertTrue(torch.allclose(test_X.grad[..., [1]], test_X_ff.grad))
+
         # test error b/c of incompatible input shapes
         with self.assertRaises(ValueError):
             qEI_ff(test_X)
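The new gradient check above verifies that backpropagation reaches only the free column while the detached fixed values receive no gradient. A model-free sketch of the same detach pattern (shapes hypothetical, not from the test):

    import torch

    free = torch.rand(2, 1, requires_grad=True)   # the column being optimized
    fixed = torch.rand(2, 1).detach()             # fixed feature values, no grad
    X_full = torch.cat([fixed, free], dim=-1)     # column placement illustrative
    X_full.norm().backward()
    print(free.grad is not None)  # True: the free column receives gradients
    print(fixed.grad)             # None: detached values are outside the graph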
