Skip to content

Commit adc6cd0

Browse files
authored
fix: stop_grad for logsumexp (#1470)
1 parent bb056ce commit adc6cd0

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

nx/lib/nx.ex

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17332,6 +17332,16 @@ defmodule Nx do
1733217332
axes = opts[:axes]
1733317333
keep_axes = opts[:keep_axes]
1733417334
max = reduce_max(tensor, axes: axes, keep_axes: true)
17335+
17336+
max =
17337+
case max do
17338+
%T{data: %Nx.Defn.Expr{}} = t ->
17339+
Nx.Defn.Kernel.stop_grad(t)
17340+
17341+
t ->
17342+
t
17343+
end
17344+
1733517345
infinity_mask = is_infinity(max)
1733617346
max = select(infinity_mask, Nx.tensor(0, type: type), max)
1733717347
exponentials = tensor |> subtract(max) |> exp()

nx/test/nx/defn/expr_test.exs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,16 @@ defmodule Nx.Defn.ExprTest do
188188
named = Nx.tensor([4], names: [:dim])
189189
assert %T{type: {:f, 32}, names: [:dim]} = Nx.multiply(Expr.tensor(named), Expr.tensor(1.0))
190190
end
191+
192+
test "logsumexp" do
193+
expr = Nx.logsumexp(Expr.tensor(Nx.tensor([1, 2, 3, 4, 5, 6])))
194+
195+
assert inspect(expr) =~ """
196+
tensor a s64[6]
197+
b = reduce_max a, axes: [0], keep_axes: true s64[1]
198+
c = metadata b, :stop_grad s64[1]
199+
"""
200+
end
191201
end
192202

193203
describe "inspect" do

0 commit comments

Comments
 (0)