
Commit e3d7f77

James Wilson authored and facebook-github-bot committed
Fix batch_shape handling in Normalize and InputStandardize (#1360)
Summary:
Pull Request resolved: #1360

Follow up to #1078.

Reviewed By: saitcakmak

Differential Revision: D38881646

fbshipit-source-id: 50f127ccd7699e760058609e7ed0904568864ab7
1 parent 993ff39 · commit e3d7f77

File tree

2 files changed (+54, -10):
  botorch/models/transforms/input.py
  test/models/transforms/test_input.py


botorch/models/transforms/input.py

Lines changed: 38 additions & 8 deletions
@@ -341,7 +341,7 @@ def __init__(
                 take all dimensions of the inputs into account.
             bounds: If provided, use these bounds to normalize the inputs. If
                 omitted, learn the bounds in train mode.
-            batch_shape: The batch shape of the inputs (asssuming input tensors
+            batch_shape: The batch shape of the inputs (assuming input tensors
                 of shape `batch_shape x n x d`). If provided, perform individual
                 normalization per batch, otherwise uses a single normalization.
             transform_on_train: A boolean indicating whether to apply the
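
For context on the `batch_shape x n x d` convention in the docstring above, a minimal usage sketch (the concrete shapes here are illustrative assumptions, not part of the diff): with a non-trivial batch_shape, separate bounds are learned per batch.

    import torch
    from botorch.models.transforms.input import Normalize

    nlz = Normalize(d=2, batch_shape=torch.Size([3]))  # illustrative shapes
    X = torch.rand(3, 5, 2)  # batch_shape x n x d
    X_nlzd = nlz(X)  # in train mode, bounds are learned per batch
    print(nlz.mins.shape)  # torch.Size([3, 1, 2]): one (1 x d) min per batch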
@@ -410,10 +410,27 @@ def _transform(self, X: Tensor) -> Tensor:
                     f"Wrong input dimension. Received {X.size(-1)}, "
                     f"expected {self.mins.size(-1)}."
                 )
-            self.mins = X.min(dim=-2, keepdim=True)[0]
-            ranges = X.max(dim=-2, keepdim=True)[0] - self.mins
-            ranges[torch.where(ranges <= self.min_range)] = self.min_range
-            self.ranges = ranges
+
+            n = len(self.batch_shape) + 2
+            if X.ndim < n:
+                raise ValueError(
+                    f"`X` must have at least {n} dimensions, {n - 2} batch and 2 innate"
+                    f" , but has {X.ndim}."
+                )
+
+            # Move extra batch and innate batch (i.e. marginal) dimensions to the right
+            batch_ndim = min(len(self.batch_shape), X.ndim - 2)  # batch rank of `X`
+            _X = X.permute(
+                *range(X.ndim - batch_ndim - 2, X.ndim - 2),  # module batch dims
+                X.ndim - 1,  # input dim
+                *range(X.ndim - batch_ndim - 2),  # other dims, to be reduced over
+                X.ndim - 2,  # marginal dim
+            ).reshape(*self.batch_shape, 1, X.shape[-1], -1)
+
+            # Extract minimums and ranges
+            self.mins = _X.min(dim=-1).values  # batch_shape x (1, d)
+            self.ranges = (_X.max(dim=-1).values - self.mins).clip(min=self.min_range)
+
         if hasattr(self, "indices"):
             X_new = X.clone()
             X_new[..., self.indices] = (
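
To unpack the permute/reshape bookkeeping added above, a hedged standalone sketch (the example tensor and its shapes are assumptions for illustration): any extra leading batch dims plus the marginal dim `n` get flattened into one trailing dim, so the min/max reductions pool all points that share a module batch index.

    import torch

    batch_shape = torch.Size([3])  # module batch shape
    X = torch.randn(2, 3, 5, 4)  # extra batch (2) x batch_shape (3) x n (5) x d (4)

    batch_ndim = min(len(batch_shape), X.ndim - 2)  # = 1
    _X = X.permute(
        *range(X.ndim - batch_ndim - 2, X.ndim - 2),  # module batch dims -> front
        X.ndim - 1,  # input dim d
        *range(X.ndim - batch_ndim - 2),  # extra dims, to be reduced over
        X.ndim - 2,  # marginal dim n
    ).reshape(*batch_shape, 1, X.shape[-1], -1)  # shape: 3 x 1 x 4 x 10

    mins = _X.min(dim=-1).values
    print(mins.shape)  # torch.Size([3, 1, 4]), i.e. batch_shape x 1 x d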
@@ -551,10 +568,23 @@ def _transform(self, X: Tensor) -> Tensor:
                     f"Wrong input. dimension. Received {X.size(-1)}, "
                     f"expected {self.means.size(-1)}"
                 )
-            self.means = X.mean(dim=-2, keepdim=True)
-            self.stds = X.std(dim=-2, keepdim=True)

-            self.stds = torch.clamp(self.stds, min=self.min_std)
+            n = len(self.batch_shape) + 2
+            if X.ndim < n:
+                raise ValueError(
+                    f"`X` must have at least {n} dimensions, {n - 2} batch and 2 innate"
+                    f" , but has {X.ndim}."
+                )
+
+            # Aggregate means and standard deviations over extra batch and marginal dims
+            batch_ndim = min(len(self.batch_shape), X.ndim - 2)  # batch rank of `X`
+            reduce_dims = (*range(X.ndim - batch_ndim - 2), X.ndim - 2)
+            self.stds, self.means = (
+                values.unsqueeze(-2)
+                for values in torch.std_mean(X, dim=reduce_dims, unbiased=True)
+            )
+            self.stds.clamp_(min=self.min_std)
+
         if hasattr(self, "indices"):
             X_new = X.clone()
             X_new[..., self.indices] = (
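
The InputStandardize branch reaches the same shape without a permute, since torch.std_mean accepts a tuple of reduction dims directly. A minimal sketch under the same illustrative shapes as above:

    import torch

    batch_shape = torch.Size([3])
    X = torch.randn(2, 3, 5, 4)  # extra batch (2) x batch_shape (3) x n (5) x d (4)

    batch_ndim = min(len(batch_shape), X.ndim - 2)  # = 1
    reduce_dims = (*range(X.ndim - batch_ndim - 2), X.ndim - 2)  # = (0, 2)
    stds, means = (
        values.unsqueeze(-2)  # restore the reduced marginal dim as size 1
        for values in torch.std_mean(X, dim=reduce_dims, unbiased=True)
    )
    print(means.shape, stds.shape)  # torch.Size([3, 1, 4]) for both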

test/models/transforms/test_input.py

Lines changed: 16 additions & 2 deletions
@@ -187,15 +187,22 @@ def test_normalize(self):
             X = torch.cat((torch.randn(4, 1), torch.zeros(4, 1)), dim=-1)
             X = X.to(self.device)
             self.assertEqual(torch.isfinite(nlz(X)).sum(), X.numel())
+            with self.assertRaisesRegex(ValueError, r"must have at least \d+ dim"):
+                nlz(torch.randn(X.shape[-1], dtype=dtype))

             # basic usage
             for batch_shape in (torch.Size(), torch.Size([3])):
                 # learned bounds
                 nlz = Normalize(d=2, batch_shape=batch_shape)
                 X = torch.randn(*batch_shape, 4, 2, device=self.device, dtype=dtype)
-                X_nlzd = nlz(X)
+                for _X in (torch.stack((X, X)), X):  # check batch_shape is obeyed
+                    X_nlzd = nlz(_X)
+                    self.assertEqual(nlz.mins.shape, batch_shape + (1, X.shape[-1]))
+                    self.assertEqual(nlz.ranges.shape, batch_shape + (1, X.shape[-1]))
+
                 self.assertEqual(X_nlzd.min().item(), 0.0)
                 self.assertEqual(X_nlzd.max().item(), 1.0)
+
                 nlz.eval()
                 X_unnlzd = nlz.untransform(X_nlzd)
                 self.assertTrue(torch.allclose(X, X_unnlzd, atol=1e-4, rtol=1e-4))
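
Restating what the new assertions check, as a hedged usage sketch: an extra leading batch dim (here created with torch.stack) is accepted and reduced over, and the learned statistics keep shape batch_shape x 1 x d.

    import torch
    from botorch.models.transforms.input import Normalize

    nlz = Normalize(d=2, batch_shape=torch.Size([3]))
    X = torch.randn(3, 4, 2)
    nlz(torch.stack((X, X)))  # extra leading batch dim of size 2 is pooled
    assert nlz.mins.shape == torch.Size([3, 1, 2])
    assert nlz.ranges.shape == torch.Size([3, 1, 2])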
@@ -356,15 +363,22 @@ def test_standardize(self):
             X = torch.cat((torch.randn(4, 1), torch.zeros(4, 1)), dim=-1)
             X = X.to(self.device, dtype=dtype)
             self.assertEqual(torch.isfinite(stdz(X)).sum(), X.numel())
+            with self.assertRaisesRegex(ValueError, r"must have at least \d+ dim"):
+                stdz(torch.randn(X.shape[-1], dtype=dtype))

             # basic usage
             for batch_shape in (torch.Size(), torch.Size([3])):
                 stdz = InputStandardize(d=2, batch_shape=batch_shape)
                 torch.manual_seed(42)
                 X = torch.randn(*batch_shape, 4, 2, device=self.device, dtype=dtype)
-                X_stdz = stdz(X)
+                for _X in (torch.stack((X, X)), X):  # check batch_shape is obeyed
+                    X_stdz = stdz(_X)
+                    self.assertEqual(stdz.means.shape, batch_shape + (1, X.shape[-1]))
+                    self.assertEqual(stdz.stds.shape, batch_shape + (1, X.shape[-1]))
+
                 self.assertTrue(torch.all(X_stdz.mean(dim=-2).abs() < 1e-4))
                 self.assertTrue(torch.all((X_stdz.std(dim=-2) - 1.0).abs() < 1e-4))
+
                 stdz.eval()
                 X_unstdz = stdz.untransform(X_stdz)
                 self.assertTrue(torch.allclose(X, X_unstdz, atol=1e-4, rtol=1e-4))
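
The analogous sketch for InputStandardize, including the new too-few-dimensions error path (illustrative shapes; the message text comes from the diff above):

    import torch
    from botorch.models.transforms.input import InputStandardize

    stdz = InputStandardize(d=2, batch_shape=torch.Size([3]))
    X = torch.randn(3, 4, 2)
    stdz(torch.stack((X, X)))  # stats pooled over the extra batch dim
    assert stdz.means.shape == torch.Size([3, 1, 2])

    try:
        stdz(torch.randn(2))  # 1-dim input, but at least 3 dims are required
    except ValueError as err:
        print(err)  # `X` must have at least 3 dimensions, 1 batch and 2 innate ...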
