Skip to content

Commit d846c97

Browse files
Fix logic error in normalize. (#622)
* Fix logic error in normalize.
* Fix formatting.
1 parent 496fd2e commit d846c97

File tree

2 files changed: 27 additions, 6 deletions

lib/axon/shared.ex

Lines changed: 1 addition & 6 deletions
```diff
@@ -245,12 +245,7 @@ defmodule Axon.Shared do

   defn normalize(input, mean, variance, gamma, bias, opts \\ []) do
     [epsilon: epsilon] = keyword!(opts, epsilon: 1.0e-6)
-
-    # The select is so that we improve numerical stability by clipping
-    # both insignificant values of variance and NaNs to epsilon.
-    scale =
-      gamma * Nx.select(variance >= epsilon, Nx.rsqrt(variance + epsilon), Nx.rsqrt(epsilon))
-
+    scale = gamma * Nx.rsqrt(variance + epsilon)
     scale * (input - mean) + bias
   end
```

test/axon/layers_test.exs

Lines changed: 26 additions & 0 deletions
```diff
@@ -1722,4 +1722,30 @@ defmodule Axon.LayersTest do
       assert_all_close(expected, actual, atol: 1.0e-3)
     end
   end
+
+  describe "batch_norm" do
+    test "matches pytorch when variance < epsilon" do
+      input_val = -0.002805
+      mean = -0.008561
+      variance = 0.000412
+      weight = 1.0
+      bias = -0.144881
+      epsilon = 0.001
+
+      expected = Nx.tensor([0.0083])
+
+      actual =
+        Axon.Layers.batch_norm(
+          Nx.tensor([[[[input_val]]]]),
+          Nx.tensor([weight]),
+          Nx.tensor([bias]),
+          Nx.tensor([mean]),
+          Nx.tensor([variance]),
+          mode: :inference,
+          epsilon: epsilon
+        )
+
+      assert_all_close(expected, actual, atol: 1.0e-3)
+    end
+  end
 end
```

0 commit comments

Comments (0)