Skip to content

Commit e1cb934

Browse files
Qing Fengfacebook-github-bot
authored andcommitted
modify PenalizedMCObjective to support non-batch eval (#2073)
Summary: Pull Request resolved: #2073 To match penalty term with MC objective, we current unsqueeze the first dim which corresponds to the dimension of MC samples. However, when a `qxd`-dim X tensor is evaluated e.g. computing feasibility, it causes shape mismatch. As one would expect `q`-dim tensor returned, it will return `1xq`-dim tensor instead. To fix, we check the dims of obj; if it is non-mc samples, we will sequeeze the first dim back. Reviewed By: bletham Differential Revision: D49305807 fbshipit-source-id: c8e2dedbdb38d79b5910aca393cf2d3d4e1e311f
1 parent d81a674 commit e1cb934

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

botorch/acquisition/penalized.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,12 @@ def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
366366
if self.expand_dim is not None:
367367
# reshape penalty_obj to match the dim
368368
penalty_obj = penalty_obj.unsqueeze(self.expand_dim)
369+
# this happens when samples is a `q x m`-dim tensor and X is a `q x d`-dim
370+
# tensor; obj returned from GenericMCObjective is a `q`-dim tensor and
371+
# penalty_obj is a `1 x q`-dim tensor.
372+
if obj.ndim == 1:
373+
assert penalty_obj.shape == torch.Size([1, samples.shape[-2]])
374+
penalty_obj = penalty_obj.squeeze(dim=0)
369375
return obj - self.regularization_parameter * penalty_obj
370376

371377

test/acquisition/test_penalized.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def test_penalized_mc_objective(self):
291291
samples = torch.randn(4, 3, device=self.device, dtype=dtype)
292292
X = torch.randn(4, 5, device=self.device, dtype=dtype)
293293
penalized_obj = generic_obj(samples) - 0.1 * l1_penalty_obj(X)
294-
self.assertTrue(torch.equal(obj(samples, X), penalized_obj))
294+
self.assertTrue(torch.equal(obj(samples, X), penalized_obj.squeeze(0)))
295295
# test 'q x d' Tensor X
296296
samples = torch.randn(4, 2, 3, device=self.device, dtype=dtype)
297297
X = torch.randn(2, 5, device=self.device, dtype=dtype)

0 commit comments

Comments
 (0)