Commit 77a75a8

Author: xingzihai
docs: Improve docstring for RMSNormalization layer
- Fixed misleading statement about 'scale' parameter (layer always uses scale)
- Added default values for 'axis' and 'epsilon' parameters
- Added more detailed description for 'axis' parameter
- Added Input shape and Output shape sections
- Removed redundant call method docstring for consistency with other normalization layers
1 parent 4244745 commit 77a75a8

1 file changed: +14 -14 lines changed

keras/src/layers/normalization/rms_normalization.py

Lines changed: 14 additions & 14 deletions
@@ -13,12 +13,10 @@ class RMSNormalization(Layer):
     [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467)
     by Biao Zhang et al.
 
+    The layer scales the normalized outputs via a learnable scaling factor
+    (`scale`).
 
-    If `scale` is enabled, the layer will scale the normalized outputs via
-    a learnable scaling factor.
-
-    So, with scaling enabled, the normalization equations
-    are as follows:
+    The normalization equations are as follows:
 
     Let the intermediate activations for a mini-batch to be the `inputs`.
 
@@ -39,7 +37,18 @@ class RMSNormalization(Layer):
 
     Args:
         axis: int. The axis on which to perform the normalization.
+            Typically, this is the features axis. `-1` is the last dimension
+            in the input. Defaults to `-1`.
         epsilon: float. A small number to add to avoid division by zero.
+            Defaults to `1e-6`.
+
+    Input shape:
+        Arbitrary. Use the keyword argument `input_shape` (tuple of integers,
+        does not include the samples axis) when using this layer as the first
+        layer in a model.
+
+    Output shape:
+        Same shape as input.
     """
 
     def __init__(self, axis=-1, epsilon=1e-6, **kwargs):
@@ -62,15 +71,6 @@ def build(self, input_shape):
         self.built = True
 
     def call(self, x):
-        """Applies RMS normalization to the input tensor.
-
-        Args:
-            x: Input tensor of shape (batch_size, input_dim).
-
-        Returns:
-            The RMS-normalized tensor of the same shape (batch_size, input_dim),
-            scaled by the learned `scale` parameter.
-        """
         return ops.rms_normalization(
             x, scale=self.scale, axis=self.axis, epsilon=self.epsilon
         )
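
For readers who want to see what the documented `axis` and `epsilon` arguments do numerically, the following is a minimal NumPy sketch of RMS normalization as formulated in the linked paper. It is an illustration under stated assumptions, not the Keras `ops.rms_normalization` implementation: the helper name `rms_normalize` is hypothetical, and `scale` is assumed to have one entry per feature on the last axis so that it broadcasts.

```python
import numpy as np


def rms_normalize(x, scale, axis=-1, epsilon=1e-6):
    """Illustrative RMS normalization (hypothetical helper, not Keras code)."""
    x = np.asarray(x, dtype=np.float64)
    # Mean of squares along the normalized axis; keepdims allows broadcasting.
    mean_square = np.mean(np.square(x), axis=axis, keepdims=True)
    # epsilon keeps the denominator non-zero for all-zero inputs.
    # Assumption: scale broadcasts against x (e.g. shape (features,) for axis=-1).
    return x / np.sqrt(mean_square + epsilon) * scale


x = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
scale = np.ones(3)  # stands in for the learnable scaling factor
y = rms_normalize(x, scale)
```

With `scale` set to ones, each non-zero row of `y` has a root mean square of approximately 1, the all-zero row stays zero instead of producing NaN, and the output has the same shape as the input, which matches the "Output shape" section added in this commit.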
