diff --git a/lib/axon/shared.ex b/lib/axon/shared.ex
index 87eff5ae..f31d847d 100644
--- a/lib/axon/shared.ex
+++ b/lib/axon/shared.ex
@@ -245,12 +245,7 @@ defmodule Axon.Shared do
 
   defn normalize(input, mean, variance, gamma, bias, opts \\ []) do
     [epsilon: epsilon] = keyword!(opts, epsilon: 1.0e-6)
-
-    # The select is so that we improve numerical stability by clipping
-    # both insignificant values of variance and NaNs to epsilon.
-    scale =
-      gamma * Nx.select(variance >= epsilon, Nx.rsqrt(variance + epsilon), Nx.rsqrt(epsilon))
-
+    scale = gamma * Nx.rsqrt(variance + epsilon)
     scale * (input - mean) + bias
   end
 
diff --git a/test/axon/layers_test.exs b/test/axon/layers_test.exs
index 3637e89d..17a61631 100644
--- a/test/axon/layers_test.exs
+++ b/test/axon/layers_test.exs
@@ -1722,4 +1722,30 @@ defmodule Axon.LayersTest do
       assert_all_close(expected, actual, atol: 1.0e-3)
     end
   end
+
+  describe "batch_norm" do
+    test "matches pytorch when variance < epsilon" do
+      input_val = -0.002805
+      mean = -0.008561
+      variance = 0.000412
+      weight = 1.0
+      bias = -0.144881
+      epsilon = 0.001
+
+      expected = Nx.tensor([0.0083])
+
+      actual =
+        Axon.Layers.batch_norm(
+          Nx.tensor([[[[input_val]]]]),
+          Nx.tensor([weight]),
+          Nx.tensor([bias]),
+          Nx.tensor([mean]),
+          Nx.tensor([variance]),
+          mode: :inference,
+          epsilon: epsilon
+        )
+
+      assert_all_close(expected, actual, atol: 1.0e-3)
+    end
+  end
 end