@@ -341,7 +341,7 @@ def __init__(
                 take all dimensions of the inputs into account.
             bounds: If provided, use these bounds to normalize the inputs. If
                 omitted, learn the bounds in train mode.
-            batch_shape: The batch shape of the inputs (asssuming input tensors
+            batch_shape: The batch shape of the inputs (assuming input tensors
                 of shape `batch_shape x n x d`). If provided, perform individual
                 normalization per batch, otherwise uses a single normalization.
             transform_on_train: A boolean indicating whether to apply the
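
As context for the `batch_shape` docstring above, here is a minimal PyTorch sketch (illustrative shapes, not part of the commit) contrasting per-batch normalization statistics with a single set of statistics over all points:

```python
import torch

# Illustrative only: 3 batches of n=5 points in d=2 dimensions.
X = torch.rand(3, 5, 2)  # batch_shape x n x d

# With batch_shape=torch.Size([3]): one set of mins per batch element.
per_batch_mins = X.min(dim=-2, keepdim=True).values             # 3 x 1 x 2

# Without batch_shape: a single normalization over all 3 * 5 points.
single_mins = X.reshape(-1, 2).min(dim=0, keepdim=True).values  # 1 x 2
```
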
@@ -410,10 +410,27 @@ def _transform(self, X: Tensor) -> Tensor:
                     f"Wrong input dimension. Received {X.size(-1)}, "
                     f"expected {self.mins.size(-1)}."
                 )
-            self.mins = X.min(dim=-2, keepdim=True)[0]
-            ranges = X.max(dim=-2, keepdim=True)[0] - self.mins
-            ranges[torch.where(ranges <= self.min_range)] = self.min_range
-            self.ranges = ranges
+
+            n = len(self.batch_shape) + 2
+            if X.ndim < n:
+                raise ValueError(
+                    f"`X` must have at least {n} dimensions, {n - 2} batch and 2 innate"
+                    f", but has {X.ndim}."
+                )
+
+            # Move extra batch and innate batch (i.e. marginal) dimensions to the right
+            batch_ndim = min(len(self.batch_shape), X.ndim - 2)  # batch rank of `X`
+            _X = X.permute(
+                *range(X.ndim - batch_ndim - 2, X.ndim - 2),  # module batch dims
+                X.ndim - 1,  # input dim
+                *range(X.ndim - batch_ndim - 2),  # other dims, to be reduced over
+                X.ndim - 2,  # marginal dim
+            ).reshape(*self.batch_shape, 1, X.shape[-1], -1)
+
+            # Extract minimums and ranges
+            self.mins = _X.min(dim=-1).values  # batch_shape x (1, d)
+            self.ranges = (_X.max(dim=-1).values - self.mins).clip(min=self.min_range)
+
         if hasattr(self, "indices"):
             X_new = X.clone()
             X_new[..., self.indices] = (
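
The new `_transform` logic above permutes the module batch dimensions and the input dimension to the front and folds any extra leading batch dimensions together with `n`, so the min/range statistics are reduced over both. A standalone sketch of that reduction, assuming `batch_shape = torch.Size([3])` and one extra batch dimension of size 4 (the `amin` call is only a cross-check, not part of the commit):

```python
import torch

# Illustrative shapes only; `batch_shape` stands in for the module's batch shape.
batch_shape = torch.Size([3])
X = torch.rand(4, 3, 5, 2)  # extra batch x batch_shape x n x d

batch_ndim = min(len(batch_shape), X.ndim - 2)  # batch rank of `X`
_X = X.permute(
    *range(X.ndim - batch_ndim - 2, X.ndim - 2),  # module batch dims -> front
    X.ndim - 1,                                   # input dim `d`
    *range(X.ndim - batch_ndim - 2),              # extra batch dims, reduced over
    X.ndim - 2,                                   # marginal dim `n`
).reshape(*batch_shape, 1, X.shape[-1], -1)       # 3 x 1 x 2 x (4 * 5)

mins = _X.min(dim=-1).values  # batch_shape x 1 x d -> 3 x 1 x 2

# Same reduction expressed directly: minimum over the extra batch dim and `n`.
assert torch.equal(mins, X.amin(dim=(0, -2), keepdim=True).squeeze(0))
```
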
@@ -551,10 +568,23 @@ def _transform(self, X: Tensor) -> Tensor:
                     f"Wrong input. dimension. Received {X.size(-1)}, "
                     f"expected {self.means.size(-1)}"
                 )
-            self.means = X.mean(dim=-2, keepdim=True)
-            self.stds = X.std(dim=-2, keepdim=True)
 
-            self.stds = torch.clamp(self.stds, min=self.min_std)
+            n = len(self.batch_shape) + 2
+            if X.ndim < n:
+                raise ValueError(
+                    f"`X` must have at least {n} dimensions, {n - 2} batch and 2 innate"
+                    f", but has {X.ndim}."
+                )
+
+            # Aggregate means and standard deviations over extra batch and marginal dims
+            batch_ndim = min(len(self.batch_shape), X.ndim - 2)  # batch rank of `X`
+            reduce_dims = (*range(X.ndim - batch_ndim - 2), X.ndim - 2)
+            self.stds, self.means = (
+                values.unsqueeze(-2)
+                for values in torch.std_mean(X, dim=reduce_dims, unbiased=True)
+            )
+            self.stds.clamp_(min=self.min_std)
+
         if hasattr(self, "indices"):
             X_new = X.clone()
             X_new[..., self.indices] = (
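
The standardization statistics above are aggregated the same way, via `torch.std_mean` over the extra batch dimensions and `n` jointly. A short sketch under the same assumed shapes as before (not part of the commit):

```python
import torch

# Illustrative shapes only; `batch_shape` stands in for the module's batch shape.
batch_shape = torch.Size([3])
X = torch.rand(4, 3, 5, 2)  # extra batch x batch_shape x n x d

batch_ndim = min(len(batch_shape), X.ndim - 2)  # batch rank of `X`
reduce_dims = (*range(X.ndim - batch_ndim - 2), X.ndim - 2)  # (0, 2) here

# torch.std_mean returns (std, mean); unsqueeze restores the reduced `n` dim.
stds, means = (v.unsqueeze(-2) for v in torch.std_mean(X, dim=reduce_dims, unbiased=True))
assert means.shape == stds.shape == (*batch_shape, 1, 2)
```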