Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2070,6 +2070,74 @@ defmodule Axon do
)
end

@doc ~S"""
Adds a root mean square (RMS) normalization layer to the network.

RMS normalization scales the input by the reciprocal of its root
mean square along the feature dimension. Unlike layer normalization
it does not subtract the mean first, which makes it computationally
simpler while achieving similar results.

See `Axon.Layers.rms_norm/3` for more details.

$$y = \frac{x}{\sqrt{E[x^2] + \epsilon}} * (\text{shift} + \gamma)$$

## Options

  * `:name` - layer name.

  * `:gamma_initializer` - gamma parameter initializer. Defaults
    to `:ones`.

  * `:channel_index` - input feature index used for calculating
    the root mean square. Defaults to `-1`.

  * `:epsilon` - numerical stability term. Defaults to `1.0e-6`.

  * `:shift` - numeric shift added to gamma before scaling.
    Defaults to `0.0`.

  * `:upcast` - adds explicit type casting to make sure the norm
    is computed in high numerical precision. Either of:

    * `:normalization` (default) - upcasts only the input normalization
      part

    * `:all` - upcasts both input normalization and the scaling
      expression

## References

  * [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)

"""
@doc type: :normalization
def rms_norm(%Axon{} = x, opts \\ []) do
  opts =
    Keyword.validate!(opts, [
      :name,
      :meta,
      gamma_initializer: :ones,
      channel_index: -1,
      epsilon: 1.0e-6,
      shift: 0.0,
      upcast: :normalization
    ])

  channel_idx = opts[:channel_index]

  # Gamma is a learnable scale with one entry per feature along the
  # normalized axis; its shape is resolved lazily from the input shape.
  gamma =
    param("gamma", &Axon.Shape.norm_param(&1, channel_idx),
      initializer: opts[:gamma_initializer]
    )

  layer(:rms_norm, [x, gamma],
    name: opts[:name],
    meta: opts[:meta],
    epsilon: opts[:epsilon],
    channel_index: channel_idx,
    shift: opts[:shift],
    upcast: opts[:upcast],
    op_name: :rms_norm
  )
end

@doc """
Applies the given `Nx` expression to the input.

Expand Down
95 changes: 95 additions & 0 deletions lib/axon/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,101 @@ defmodule Axon.Layers do
Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.shape(input))
end

@doc ~S"""
Functional implementation of RMS normalization.

Scales the input by the reciprocal of its root mean square computed
along the feature dimension given by `:channel_index`. In contrast to
layer normalization, the input is not centered by subtracting the
mean before scaling.

$$y = \frac{x}{\sqrt{E[x^2] + \epsilon}} * (\text{shift} + \gamma)$$

`gamma` is often a trainable parameter. This method does not maintain
an EMA of variance.

## Options

  * `:epsilon` - numerical stability term. $\epsilon$ in the above
    formulation. Defaults to `1.0e-6`.

  * `:channel_index` - channel index used to determine reduction
    axes for RMS calculation. Defaults to `-1`.

  * `:shift` - numeric shift added to gamma before scaling.
    Defaults to `0.0`.

  * `:upcast` - controls type casting for numerical precision.
    Either `:normalization` (default) to upcast only the normalization
    part, or `:all` to upcast the entire computation.

## References

  * [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)
"""
@doc type: :normalization
defn rms_norm(input, gamma, opts \\ []) do
  # :mode is validated here but not read by the implementation below;
  # presumably accepted for parity with the other norm layers - confirm.
  opts =
    keyword!(opts,
      mode: :inference,
      upcast: :normalization,
      shift: 0.0,
      channel_index: -1,
      epsilon: 1.0e-6
    )

  rms_norm_impl(input, gamma, opts)
end

# Dispatches on the :upcast option at trace time. Runs outside of defn
# so an invalid option raises a plain ArgumentError.
deftransformp rms_norm_impl(input, gamma, opts) do
  case opts[:upcast] do
    :all ->
      rms_norm_upcast_all(input, gamma, opts)

    :normalization ->
      rms_norm_upcast_normalization(input, gamma, opts)

    other ->
      raise ArgumentError,
            "expected :upcast to be either :all or :normalization, got: #{inspect(other)}"
  end
end

# Upcasts only the normalization itself to f32; the normalized result
# is cast back, so gamma scaling happens in the input's original type.
defnp rms_norm_upcast_normalization(input, gamma, opts) do
  axis = opts[:channel_index]
  channels = Nx.axis_size(input, axis)
  gamma = Nx.reshape(gamma, norm_parameter_reshape(input, channels, axis))

  normalized =
    input
    |> Nx.as_type(:f32)
    |> rms_normalize(opts)
    |> Nx.as_type(Nx.type(input))

  normalized * (opts[:shift] + gamma)
end

# Upcasts both the input and gamma to f32, so normalization and the
# gamma scaling expression are computed in high precision.
defnp rms_norm_upcast_all(input, gamma, opts) do
  axis = opts[:channel_index]
  channels = Nx.axis_size(input, axis)

  gamma =
    gamma
    |> Nx.reshape(norm_parameter_reshape(input, channels, axis))
    |> Nx.as_type(:f32)

  input = Nx.as_type(input, :f32)

  rms_normalize(input, opts) * (opts[:shift] + gamma)
end

# Divides the input by its root mean square along the channel axis.
# Note this is the mean of squares without centering, not a variance.
defnp rms_normalize(input, opts) do
  mean_square =
    Nx.mean(Nx.pow(input, 2), axes: [opts[:channel_index]], keep_axes: true)

  input * Nx.rsqrt(mean_square + opts[:epsilon])
end

@doc ~S"""
Functional implementation of instance normalization.

Expand Down
92 changes: 87 additions & 5 deletions test/axon/layers_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,7 @@ defmodule Axon.LayersTest do
[[6.1974, 7.1974, 8.1974], [7.8947, 8.8947, 9.8947], [9.5921, 10.5921, 11.5921]]
]
]),
atol: 1.0e-4
atol: 1.0e-3
)

assert_all_close(
Expand All @@ -1000,7 +1000,7 @@ defmodule Axon.LayersTest do
[[6.3724, 7.3724, 8.3724], [8.2449, 9.2449, 10.2449], [10.1173, 11.1173, 12.1173]]
]
]),
atol: 1.0e-4
atol: 1.0e-3
)

assert_all_close(
Expand All @@ -1012,7 +1012,7 @@ defmodule Axon.LayersTest do
[[6.4508, 7.4508, 8.4508], [8.4016, 9.4016, 10.4016], [10.3525, 11.3525, 12.3525]]
]
]),
atol: 1.0e-4
atol: 1.0e-3
)
end

Expand All @@ -1036,7 +1036,7 @@ defmodule Axon.LayersTest do
]
]
]),
atol: 1.0e-4
atol: 1.0e-3
)

# Downscaling (no effect)
Expand All @@ -1052,7 +1052,7 @@ defmodule Axon.LayersTest do
[[6.1974, 7.1974, 8.1974], [7.8947, 8.8947, 9.8947], [9.5921, 10.5921, 11.5921]]
]
]),
atol: 1.0e-4
atol: 1.0e-3
)
end
end
Expand Down Expand Up @@ -1723,6 +1723,88 @@ defmodule Axon.LayersTest do
end
end

describe "rms_norm" do
  # Reference outputs below appear to be generated with PyTorch's RMSNorm
  # (eps=1e-6) per the test names - TODO confirm generation script.
  test "matches pytorch 2D input" do
    input =
      Nx.tensor([
        [1.9269, 1.4873, 0.9007, -2.1055, 0.6784, -1.2345, -0.0431, -1.6047],
        [-0.7521, 1.6487, -0.3925, -1.4036, -0.7279, -0.5594, -0.7688, 0.7624]
      ])

    # Per-channel scale, one entry per feature along the last axis.
    gamma =
      Nx.tensor([0.4617, 0.2674, 0.5349, 0.8094, 1.1103, -1.6898, -0.9890, 0.9580])

    expected =
      Nx.tensor([
        [0.6344, 0.2836, 0.3436, -1.2153, 0.5372, 1.4877, 0.0304, -1.0962],
        [-0.3605, 0.4576, -0.2179, -1.1793, -0.8390, 0.9814, 0.7893, 0.7582]
      ])

    actual = Axon.Layers.rms_norm(input, gamma, epsilon: 1.0e-6, channel_index: -1)
    assert_all_close(expected, actual, atol: 1.0e-3)
  end

  test "matches pytorch 3D input" do
    # Batched input exercises broadcasting of gamma over leading axes.
    input =
      Nx.tensor([
        [
          [-1.3847, -0.8712, -0.2234, 1.7174, 0.3189, -0.4245],
          [0.3057, -0.7746, -1.5576, 0.9956, -0.8798, -0.6011],
          [-1.2742, 2.1228, -1.2347, -0.4879, -0.9138, -0.6581],
          [0.0780, 0.5258, -0.4880, 1.1914, -0.8140, -0.7360]
        ],
        [
          [-1.4032, 0.0360, -0.0635, 0.6756, -0.0978, 1.8446],
          [-1.1845, 1.3835, 1.4451, 0.8564, 2.2181, 0.5232],
          [0.3466, -0.1973, -1.0546, 1.2780, -0.1722, 0.5238],
          [0.0566, 0.4263, 0.5750, -0.6417, -2.2064, -0.7508]
        ]
      ])

    # Gamma includes negative entries to catch sign errors in scaling.
    gamma =
      Nx.tensor([0.4679, -0.2049, -0.7409, 0.3618, 1.9199, -0.2254])

    expected =
      Nx.tensor([
        [
          [-0.6502, 0.1792, 0.1661, 0.6236, 0.6144, 0.0960],
          [0.1530, 0.1698, 1.2341, 0.3853, -1.8064, 0.1449],
          [-0.4825, -0.3521, 0.7403, -0.1429, -1.4199, 0.1201],
          [0.0504, -0.1488, 0.4994, 0.5955, -2.1588, 0.2291]
        ],
        [
          [-0.6653, -0.0075, 0.0477, 0.2477, -0.1903, -0.4213],
          [-0.4033, -0.2063, -0.7791, 0.2255, 3.0986, -0.0858],
          [0.2218, 0.0553, 1.0685, 0.6324, -0.4521, -0.1614],
          [0.0257, -0.0849, -0.4138, -0.2255, -4.1147, 0.1644]
        ]
      ])

    actual = Axon.Layers.rms_norm(input, gamma, epsilon: 1.0e-6, channel_index: -1)
    assert_all_close(expected, actual, atol: 1.0e-3)
  end

  test "matches pytorch with ones weight" do
    # With gamma = 1 the output is the bare x / rms(x) normalization,
    # isolating the normalization math from the scaling.
    input =
      Nx.tensor([
        [0.6127, -1.1754, -0.7646, -0.6666],
        [0.7444, -0.6453, -1.3890, -0.2730]
      ])

    gamma =
      Nx.tensor([1.0000, 1.0000, 1.0000, 1.0000])

    expected =
      Nx.tensor([
        [0.7342, -1.4084, -0.9163, -0.7987],
        [0.8632, -0.7483, -1.6108, -0.3165]
      ])

    actual = Axon.Layers.rms_norm(input, gamma, epsilon: 1.0e-6, channel_index: -1)
    assert_all_close(expected, actual, atol: 1.0e-3)
  end
end

describe "batch_norm" do
test "matches pytorch when variance < epsilon" do
input_val = -0.002805
Expand Down
Loading