Skip to content

Commit 0ee6cff

Browse files
Balandatfacebook-github-bot
authored andcommitted
Changes to resampling behavior in MCSamplers (#204)
Summary: Makes the following changes in case `resample=False`: 1. if `collapse_batch_dims=True` and the requested shape is different from the shape of the base samples, only resample if the last (q and o) dimensions are different, otherwise broadcast the base samples to the requested batch shape (will just be a different number of ones) 2. if `dtype` or `device` of the posterior are different from that of the base samples don't resample if not triggered otherwise - instead automatically move the base samples to the correct device/dtype if apppropriate Pull Request resolved: #204 Reviewed By: danielrjiang Differential Revision: D16161405 Pulled By: Balandat fbshipit-source-id: a7465fd9c4287e0fb9716c23cd10d24e1516e54a
1 parent 08ae7ae commit 0ee6cff

File tree

2 files changed

+45
-21
lines changed

2 files changed

+45
-21
lines changed

botorch/sampling/samplers.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> No
8787
8888
- `resample=True`
8989
- the MCSampler has no `base_samples` attribute.
90-
- `shape` is different than `self.base_samples.shape`.
91-
- device and/or dtype of posterior are different than those of
92-
`self.base_samples`.
90+
- `shape` is different than `self.base_samples.shape` (if
91+
`collapse_batch_dims=True`, then batch dimensions of will be
92+
automatically broadcasted as necessary)
9393
9494
Args:
9595
posterior: The Posterior for which to generate base samples.
@@ -137,11 +137,12 @@ def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> No
137137
138138
This function will generate a new set of base samples and set the
139139
`base_samples` buffer if one of the following is true:
140-
- `resample=True`
141-
- the MCSampler has no `base_samples` attribute.
142-
- `shape` is different than `self.base_samples.shape`.
143-
- device and/or dtype of posterior ar different than those of
144-
`self.base_samples`.
140+
141+
- `resample=True`
142+
- the MCSampler has no `base_samples` attribute.
143+
- `shape` is different than `self.base_samples.shape` (if
144+
`collapse_batch_dims=True`, then batch dimensions of will be
145+
automatically broadcasted as necessary)
145146
146147
Args:
147148
posterior: The Posterior for which to generate base samples.
@@ -150,16 +151,21 @@ def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> No
150151
if (
151152
self.resample
152153
or not hasattr(self, "base_samples")
153-
or self.base_samples.shape != shape
154-
or self.base_samples.device != posterior.device
155-
or self.base_samples.dtype != posterior.dtype
154+
or self.base_samples.shape[-2:] != shape[-2:]
155+
or (not self.collapse_batch_dims and shape != self.base_samples.shape)
156156
):
157157
with manual_seed(seed=self.seed):
158158
base_samples = torch.randn(
159159
shape, device=posterior.device, dtype=posterior.dtype
160160
)
161161
self.seed += 1
162162
self.register_buffer("base_samples", base_samples)
163+
elif self.collapse_batch_dims and shape != self.base_samples.shape:
164+
self.base_samples = self.base_samples.view(shape)
165+
if self.base_samples.device != posterior.device:
166+
self.to(device=posterior.device) # pragma: nocover
167+
if self.base_samples.dtype != posterior.dtype:
168+
self.to(dtype=posterior.dtype)
163169

164170

165171
class SobolQMCNormalSampler(MCSampler):
@@ -201,11 +207,12 @@ def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> No
201207
202208
This function will generate a new set of base samples and set the
203209
`base_samples` buffer if one of the following is true:
204-
- `resample=True`
205-
- the MCSampler has no `base_samples` attribute.
206-
- `self.sample_shape` is different than `self.base_samples.shape`.
207-
- device and/or dtype of posterior ar different than those of
208-
`self.base_samples`.
210+
211+
- `resample=True`
212+
- the MCSampler has no `base_samples` attribute.
213+
- `shape` is different than `self.base_samples.shape` (if
214+
`collapse_batch_dims=True`, then batch dimensions of will be
215+
automatically broadcasted as necessary)
209216
210217
Args:
211218
posterior: The Posterior for which to generate base samples.
@@ -214,9 +221,8 @@ def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> No
214221
if (
215222
self.resample
216223
or not hasattr(self, "base_samples")
217-
or self.base_samples.shape != shape
218-
or self.base_samples.device != posterior.device
219-
or self.base_samples.dtype != posterior.dtype
224+
or self.base_samples.shape[-2:] != shape[-2:]
225+
or (not self.collapse_batch_dims and shape != self.base_samples.shape)
220226
):
221227
output_dim = shape[-2:].numel()
222228
if output_dim > SobolEngine.MAXDIM:
@@ -234,3 +240,9 @@ def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> No
234240
self.seed += 1
235241
base_samples = base_samples.view(shape)
236242
self.register_buffer("base_samples", base_samples)
243+
elif self.collapse_batch_dims and shape != posterior.event_shape:
244+
self.base_samples = self.base_samples.view(shape)
245+
if self.base_samples.device != posterior.device:
246+
self.to(device=posterior.device) # pragma: nocover
247+
if self.base_samples.dtype != posterior.dtype:
248+
self.to(dtype=posterior.dtype)

test/sampling/test_sampler.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,13 @@ def test_forward(self, cuda=False):
8484
posterior_batched = _get_posterior_batched(cuda=cuda, dtype=dtype)
8585
samples_batched = sampler(posterior_batched)
8686
self.assertEqual(samples_batched.shape, torch.Size([4, 3, 2, 1]))
87-
self.assertEqual(sampler.seed, 1236)
87+
self.assertEqual(sampler.seed, 1235)
88+
# ensure this works when changing the dtype
89+
new_dtype = torch.float if dtype == torch.double else torch.double
90+
posterior_batched = _get_posterior_batched(cuda=cuda, dtype=new_dtype)
91+
samples_batched = sampler(posterior_batched)
92+
self.assertEqual(samples_batched.shape, torch.Size([4, 3, 2, 1]))
93+
self.assertEqual(sampler.seed, 1235)
8894

8995
# resample
9096
sampler = IIDNormalSampler(num_samples=4, resample=True, seed=None)
@@ -212,7 +218,13 @@ def test_forward(self, cuda=False):
212218
posterior_batched = _get_posterior_batched(cuda=cuda, dtype=dtype)
213219
samples_batched = sampler(posterior_batched)
214220
self.assertEqual(samples_batched.shape, torch.Size([4, 3, 2, 1]))
215-
self.assertEqual(sampler.seed, 1236)
221+
self.assertEqual(sampler.seed, 1235)
222+
# ensure this works when changing the dtype
223+
new_dtype = torch.float if dtype == torch.double else torch.double
224+
posterior_batched = _get_posterior_batched(cuda=cuda, dtype=new_dtype)
225+
samples_batched = sampler(posterior_batched)
226+
self.assertEqual(samples_batched.shape, torch.Size([4, 3, 2, 1]))
227+
self.assertEqual(sampler.seed, 1235)
216228

217229
# resample
218230
sampler = SobolQMCNormalSampler(num_samples=4, resample=True, seed=None)

0 commit comments

Comments
 (0)