diff --git a/keras/src/layers/normalization/rms_normalization.py b/keras/src/layers/normalization/rms_normalization.py index f97ee6c3685d..f769c3c5359b 100644 --- a/keras/src/layers/normalization/rms_normalization.py +++ b/keras/src/layers/normalization/rms_normalization.py @@ -13,12 +13,10 @@ class RMSNormalization(Layer): [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467) by Biao Zhang et al. + The layer scales the normalized outputs via a learnable scaling factor + (`scale`). - If `scale` is enabled, the layer will scale the normalized outputs via - a learnable scaling factor. - - So, with scaling enabled, the normalization equations - are as follows: + The normalization equations are as follows: Let the intermediate activations for a mini-batch to be the `inputs`. @@ -39,7 +37,18 @@ class RMSNormalization(Layer): Args: axis: int. The axis on which to perform the normalization. + Typically, this is the features axis. `-1` is the last dimension + in the input. Defaults to `-1`. epsilon: float. A small number to add to avoid division by zero. + Defaults to `1e-6`. + + Input shape: + Arbitrary. Use the keyword argument `input_shape` (tuple of integers, + does not include the samples axis) when using this layer as the first + layer in a model. + + Output shape: + Same shape as input. """ def __init__(self, axis=-1, epsilon=1e-6, **kwargs): @@ -62,15 +71,6 @@ def build(self, input_shape): self.built = True def call(self, x): - """Applies RMS normalization to the input tensor. - - Args: - x: Input tensor of shape (batch_size, input_dim). - - Returns: - The RMS-normalized tensor of the same shape (batch_size, input_dim), - scaled by the learned `scale` parameter. - """ return ops.rms_normalization( x, scale=self.scale, axis=self.axis, epsilon=self.epsilon )