Commit a80bdff

Merge pull request #126 from ornithos/logsoftmaxgrad

Improve numerical stability of logsoftmax gradient

2 parents f5fce7a + d219cdb

2 files changed: +2 −2

src/softmax.jl

Lines changed: 1 addition & 1 deletion

@@ -81,5 +81,5 @@ function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
     end
     return out
 end
-∇logsoftmax(Δ, xs) = ∇softmax(Δ ./ max.(eps(eltype(xs)), softmax(xs)), xs)
+∇logsoftmax(Δ, xs) = Δ - sum(Δ, dims=1) .* softmax(xs)
 ∇logsoftmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
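
The change replaces a formula that divided the incoming gradient by softmax(xs) — clamped below by eps to dodge division by zero — with the algebraically equivalent closed form Δ − sum(Δ) · softmax(xs), which involves no division at all. A minimal standalone sketch (not NNlib's actual code; the `my`-prefixed helper names are made up for illustration) shows why that matters on inputs extreme enough for softmax to underflow:

# Standalone helpers mirroring NNlib's column-wise convention.
mysoftmax(xs) = (e = exp.(xs .- maximum(xs, dims=1)); e ./ sum(e, dims=1))
my∇softmax(Δ, xs) = (s = mysoftmax(xs); s .* (Δ .- sum(Δ .* s, dims=1)))

# Old formula: divide Δ by softmax(xs), clamped below by eps.
old_∇logsoftmax(Δ, xs) = my∇softmax(Δ ./ max.(eps(eltype(xs)), mysoftmax(xs)), xs)

# New closed form: no division, so softmax underflow is harmless.
new_∇logsoftmax(Δ, xs) = Δ .- sum(Δ, dims=1) .* mysoftmax(xs)

xs = Float32[1 2 3; 1000 2000 3000]  # softmax(xs) underflows to [0 0 0; 1 1 1]
Δ  = ones(Float32, size(xs))
old_∇logsoftmax(Δ, xs)  # [0 0 0; 0 0 0] — the underflowed zeros wipe the gradient out
new_∇logsoftmax(Δ, xs)  # [1 1 1; -1 -1 -1] — the exact gradient

That all-zero result from the old formula is exactly what the old test below asserted, so the fix makes the gradient correct as well as numerically stable.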

test/activation.jl

Lines changed: 1 addition & 1 deletion

@@ -100,7 +100,7 @@ end
     xs = Float32[1 2 3; 1000 2000 3000]
     @test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.]

-    @test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs))
+    @test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1]
     @test NNlib.∇softmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs))

     # These values precalculated using PyTorch's nn.LogSoftmax
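
The new expected value follows directly from the closed form; a quick check (assuming NNlib's column-wise softmax convention):

using NNlib

xs = Float32[1 2 3; 1000 2000 3000]
s  = softmax(xs)            # second row dominates every column: ≈ [0 0 0; 1 1 1]
Δ  = ones(Float32, size(xs))
Δ .- sum(Δ, dims=1) .* s    # sum(Δ, dims=1) = [2 2 2], giving [1 1 1; -1 -1 -1]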
