@@ -13,12 +13,10 @@ class RMSNormalization(Layer):
1313 [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467)
1414 by Biao Zhang et al.
1515
16+ The layer scales the normalized outputs via a learnable scaling factor
17+ (`scale`).
1618
17- If `scale` is enabled, the layer will scale the normalized outputs via
18- a learnable scaling factor.
19-
20- So, with scaling enabled, the normalization equations
21- are as follows:
19+ The normalization equations are as follows:
2220
2321 Let the intermediate activations for a mini-batch to be the `inputs`.
2422
@@ -39,7 +37,18 @@ class RMSNormalization(Layer):
3937
4038 Args:
4139 axis: int. The axis on which to perform the normalization.
40+ Typically, this is the features axis. `-1` is the last dimension
41+ in the input. Defaults to `-1`.
4242 epsilon: float. A small number to add to avoid division by zero.
43+ Defaults to `1e-6`.
44+
45+ Input shape:
46+ Arbitrary. Use the keyword argument `input_shape` (tuple of integers,
47+ does not include the samples axis) when using this layer as the first
48+ layer in a model.
49+
50+ Output shape:
51+ Same shape as input.
4352 """
4453
4554 def __init__ (self , axis = - 1 , epsilon = 1e-6 , ** kwargs ):
@@ -62,15 +71,6 @@ def build(self, input_shape):
6271 self .built = True
6372
def call(self, x):
    """Apply root-mean-square normalization to `x`.

    Delegates to `ops.rms_normalization`, normalizing along
    `self.axis` with `self.epsilon` added for numerical stability,
    and scaling the result by the learnable `self.scale` weight
    created in `build`.

    Args:
        x: Input tensor to normalize.

    Returns:
        The RMS-normalized tensor, same shape as `x`.
    """
    return ops.rms_normalization(
        x, scale=self.scale, axis=self.axis, epsilon=self.epsilon
    )
0 commit comments