Skip to content

Commit 32bdfda

Browse files
esantorellafacebook-github-bot
authored andcommitted
Add test case for fantasization with observation noise and empty data (#2389)
Summary: ## Motivation Working on getting BoTorch test coverage back up to 100%. - Add test for empty data (n=0) - Give ints names for clarity Pull Request resolved: #2389 Test Plan: Is a test Reviewed By: sdaulton Differential Revision: D58930023 Pulled By: esantorella fbshipit-source-id: ef9508feb86dd2387224c0d2cbd5d001848d9bc5
1 parent 42dfd09 commit 32bdfda

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

botorch/models/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,6 @@ def transform_inputs(
337337
# 'Self', but at this point the verbose 'T...' syntax is needed.
338338
def fantasize(
339339
self: TFantasizeMixin,
340-
# TODO: see if any of these can be imported only if TYPE_CHECKING
341340
X: Tensor,
342341
sampler: MCSampler,
343342
observation_noise: Optional[Tensor] = None,

test/models/test_gpytorch.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -164,18 +164,23 @@ def test_gpytorch_model(self):
164164
# test subset_output
165165
with self.assertRaises(NotImplementedError):
166166
model.subset_output([0])
167+
167168
# test fantasize
168-
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2]))
169-
cm = model.fantasize(torch.rand(2, 1, **tkwargs), sampler=sampler)
170-
self.assertIsInstance(cm, SimpleGPyTorchModel)
171-
self.assertEqual(cm.train_targets.shape, torch.Size([2, 7]))
172-
cm = model.fantasize(
173-
torch.rand(2, 1, **tkwargs),
174-
sampler=sampler,
175-
observation_noise=torch.rand(2, 1, **tkwargs),
176-
)
177-
self.assertIsInstance(cm, SimpleGPyTorchModel)
178-
self.assertEqual(cm.train_targets.shape, torch.Size([2, 7]))
169+
n_samps = 2
170+
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([n_samps]))
171+
for n in [0, 2]:
172+
x = torch.rand(n, 1, **tkwargs)
173+
cm = model.fantasize(X=x, sampler=sampler)
174+
self.assertIsInstance(cm, SimpleGPyTorchModel)
175+
self.assertEqual(cm.train_targets.shape, torch.Size([n_samps, 5 + n]))
176+
cm = model.fantasize(
177+
X=x,
178+
sampler=sampler,
179+
observation_noise=torch.rand(n, 1, **tkwargs),
180+
)
181+
self.assertIsInstance(cm, SimpleGPyTorchModel)
182+
self.assertEqual(cm.train_targets.shape, torch.Size([n_samps, 5 + n]))
183+
179184
# test that boolean observation noise is deprecated
180185
msg = "`fantasize` no longer accepts a boolean for `observation_noise`."
181186
with self.assertRaisesRegex(DeprecationError, msg):

0 commit comments

Comments
 (0)