Skip to content

Commit 4cb3587

Browse files
Balandatfacebook-github-bot
authored andcommitted
Fix test failure in test_evaluate_qMFKG (#618)
Summary: Caused by a PR race condition between #594 and #588 Pull Request resolved: #618 Reviewed By: qingfeng10 Differential Revision: D25386475 Pulled By: Balandat fbshipit-source-id: 95f445c5a03e07b2c35eab7134b43b39ec25c472
1 parent 1dca562 commit 4cb3587

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

botorch/utils/testing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,11 @@ def num_outputs(self) -> int:
169169
event_shape = self._posterior.event_shape
170170
return event_shape[-1] if len(event_shape) > 0 else 0
171171

172+
@property
173+
def batch_shape(self) -> torch.Size:
174+
event_shape = self._posterior.event_shape
175+
return event_shape[:-2]
176+
172177
def state_dict(self) -> None:
173178
pass
174179

test/acquisition/test_knowledge_gradient.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,13 +406,11 @@ def test_evaluate_qMFKG(self):
406406
for dtype in (torch.float, torch.double):
407407
mean = torch.zeros(1, 1, 1, device=self.device, dtype=dtype)
408408
mm = MockModel(MockPosterior(mean=mean))
409-
mm._input_batch_shape = torch.Size([1])
410409
cau = GenericCostAwareUtility(mock_util)
411410
n_f = 4
412411
mean = torch.rand(n_f, 2, 1, 1, device=self.device, dtype=dtype)
413412
variance = torch.rand(n_f, 2, 1, 1, device=self.device, dtype=dtype)
414413
mfm = MockModel(MockPosterior(mean=mean, variance=variance))
415-
mfm._input_batch_shape = torch.Size([n_f, 2])
416414
with ExitStack() as es:
417415
patch_f = es.enter_context(
418416
mock.patch.object(MockModel, "fantasize", return_value=mfm)

0 commit comments

Comments
 (0)