Skip to content

Commit 8102cd9

Browse files
authored
fix(grad): Nx.stack grad should remove the added axis (unbroadcast) (#1536)
1 parent 93e4383 commit 8102cd9

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

nx/lib/nx/defn/grad.ex

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ defmodule Nx.Defn.Grad do
2323
expr = to_grad |> fun.()
2424

2525
transformed_expr = transform.(expr) |> validate_expr!() |> Nx.devectorize(keep_names: false)
26+
2627
{parents, nodes} = parents_tree(transformed_expr, ids)
2728

2829
to_grad_ids = {to_grad, ids}
@@ -623,7 +624,9 @@ defmodule Nx.Defn.Grad do
623624
current_limit = 1 + limit
624625
start = List.replace_at(zero_axes, axis, limit)
625626
len = List.replace_at(ans_shape_list, axis, 1)
626-
{{t, Nx.slice(g, start, len)}, current_limit}
627+
g = Nx.slice(g, start, len)
628+
g = Nx.squeeze(g, axes: [axis])
629+
{{t, g}, current_limit}
627630
end)
628631

629632
pairs

nx/test/nx/defn/grad_test.exs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1773,6 +1773,24 @@ defmodule Nx.Defn.GradTest do
17731773
end
17741774
end
17751775

1776+
describe "stack" do
1777+
test "works on compound functions for more than 1 axis" do
1778+
# This is a test that ensures that the added axis from the
1779+
# stack operation is correctly squeezed back out by
1780+
# the gradient computation.
1781+
x = 2.0
1782+
1783+
assert grad(Nx.tensor([[x]]), fn t ->
1784+
a = Nx.pow(t, 2)
1785+
b = Nx.pow(t, 3)
1786+
c = Nx.pow(t, 4)
1787+
1788+
Nx.stack([a, b, c], axis: 1)
1789+
|> Nx.sum()
1790+
end) == Nx.tensor([[2 * x + 3 * x ** 2 + 4 * x ** 3]])
1791+
end
1792+
end
1793+
17761794
describe "cholesky" do
17771795
defn cholesky_grad(t) do
17781796
grad(t, fn x -> x |> Nx.LinAlg.cholesky() |> Nx.sum() end)

0 commit comments

Comments
 (0)