Skip to content

Commit 0773796

Browse files
committed
numerically stable logsoftmax gradient
1 parent f5fce7a commit 0773796

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/softmax.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,5 +81,5 @@ function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
8181
end
8282
return out
8383
end
84-
∇logsoftmax(Δ, xs) = ∇softmax./ max.(eps(eltype(xs)),softmax(xs)), xs)
84+
∇logsoftmax(Δ, xs) = Δ - sum(Δ, dims=1) .* softmax(xs)
8585
∇logsoftmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)

0 commit comments

Comments
 (0)