28 changes: 14 additions & 14 deletions keras/src/layers/normalization/rms_normalization.py
@@ -13,12 +13,10 @@ class RMSNormalization(Layer):
[Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467)
by Biao Zhang et al.

The layer scales the normalized outputs via a learnable scaling factor
(`scale`).

If `scale` is enabled, the layer will scale the normalized outputs via
a learnable scaling factor.

So, with scaling enabled, the normalization equations
are as follows:
The normalization equations are as follows:

Let the intermediate activations for a mini-batch to be the `inputs`.

@@ -39,7 +37,18 @@ class RMSNormalization(Layer):

Args:
axis: int. The axis on which to perform the normalization.
Contributor review comment (medium priority):

The type hint for the `axis` argument is not entirely accurate. While it's often an integer, this layer also supports a list or tuple of integers for normalization over multiple axes, as shown in `rms_normalization_test.py`. To be more precise and align with the implementation, please update the type to reflect this.

Suggested change:
- axis: int. The axis on which to perform the normalization.
+ axis: int or list of ints. The axis or axes on which to perform the normalization.

Collaborator reply:

Please address this comment.
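
To make the distinction in the review comment concrete, here is a minimal pure-Python sketch of how normalizing over a single axis differs from normalizing over several axes. The helpers `rms_norm_last_axis` and `rms_norm_all_axes` are hypothetical names invented for this illustration, not Keras APIs. They operate on one 2-D sample given as nested lists and omit the learned `scale`, and the placement of `epsilon` inside the square root is an assumption.

```python
import math


def rms_norm_last_axis(sample, epsilon=1e-6):
    """Hypothetical helper: normalize each row independently (like axis=-1)."""
    result = []
    for row in sample:
        mean_square = sum(v * v for v in row) / len(row)
        inv_rms = 1.0 / math.sqrt(mean_square + epsilon)
        result.append([v * inv_rms for v in row])
    return result


def rms_norm_all_axes(sample, epsilon=1e-6):
    """Hypothetical helper: one statistic over both axes (like axis=(-2, -1))."""
    flat = [v for row in sample for v in row]
    mean_square = sum(v * v for v in flat) / len(flat)
    inv_rms = 1.0 / math.sqrt(mean_square + epsilon)
    return [[v * inv_rms for v in row] for row in sample]


sample = [[1.0, 1.0], [3.0, 3.0]]
per_row = rms_norm_last_axis(sample, epsilon=0.0)
joint = rms_norm_all_axes(sample, epsilon=0.0)
print(per_row)  # each row is rescaled by its own root mean square
print(joint)    # rows share one root mean square, so their ratio is kept
```

Per-row normalization removes the magnitude difference between the two rows, while the joint variant preserves it. That behavioral difference is why the docstring should say whether `axis` accepts more than one axis.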

Typically, this is the features axis. `-1` is the last dimension
in the input. Defaults to `-1`.
epsilon: float. A small number to add to avoid division by zero.
Defaults to `1e-6`.

Input shape:
Arbitrary. Use the keyword argument `input_shape` (tuple of integers,
does not include the samples axis) when using this layer as the first
layer in a model.

Output shape:
Same shape as input.
"""

def __init__(self, axis=-1, epsilon=1e-6, **kwargs):
@@ -62,15 +71,6 @@ def build(self, input_shape):
self.built = True

def call(self, x):
"""Applies RMS normalization to the input tensor.

Args:
x: Input tensor of shape (batch_size, input_dim).

Returns:
The RMS-normalized tensor of the same shape (batch_size, input_dim),
scaled by the learned `scale` parameter.
"""
return ops.rms_normalization(
x, scale=self.scale, axis=self.axis, epsilon=self.epsilon
)
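
For readers unfamiliar with the operation, the computation that `call` delegates to `ops.rms_normalization` can be sketched in plain Python. This is a minimal illustration assuming the formulation from the cited paper: divide by the root mean square over the normalized axis, then multiply by the learned `scale`. The helper `rms_normalize` is a hypothetical name and is not the Keras implementation. Adding `epsilon` inside the square root is also an assumption.

```python
import math


def rms_normalize(values, scale, epsilon=1e-6):
    """Hypothetical helper: RMS-normalize one feature vector, then scale it."""
    # Mean of squares over the normalized axis (here, the whole vector).
    mean_square = sum(v * v for v in values) / len(values)
    # Assumption: epsilon is added inside the square root for stability.
    inv_rms = 1.0 / math.sqrt(mean_square + epsilon)
    # Element-wise product with the per-feature scale.
    return [v * inv_rms * s for v, s in zip(values, scale)]


features = [3.0, 4.0]
normalized = rms_normalize(features, scale=[1.0, 1.0], epsilon=0.0)
# Root mean square of the output, computed the same way as for the input.
output_rms = math.sqrt(sum(v * v for v in normalized) / len(normalized))
print(normalized, output_rms)
```

With unit `scale` and `epsilon=0.0`, the output has a root mean square of 1 up to floating-point error, which is the property the layer is named for.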