Skip to content

Commit e7c1611

Browse files
committed
made softmax and logsoftmax more stable with respect to overflow
1 parent b653dc1 commit e7c1611

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

src/logsoftmax.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ end
2323
logsoftmax!(xs) = logsoftmax!(xs, xs)
2424
logsoftmax(xs) = logsoftmax!(similar(xs), xs)
2525

26-
∇logsoftmax(Δ, xs) = ∇softmax(Δ ./ softmax(xs), xs)
26+
∇logsoftmax(Δ, xs) = ∇softmax(Δ ./ max.(eps(eltype(xs)),softmax(xs)), xs)
2727

2828
"""
2929
logsoftmax(xs) = log.(exp.(xs) ./ sum(exp.(xs)))

src/softmax.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ softmax!(xs) = softmax!(xs, xs)
2929
softmax(xs) = softmax!(similar(xs), xs)
3030

3131
function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
32-
s = sum(exp, xs, dims=1)
33-
out .= exp.(xs)./s.*(Δ .- sum(Δ .* exp.(xs), dims=1)./s)
32+
sf = softmax(xs)
33+
out .= sf .* (Δ .- sum(Δ .*sf, dims = 1))
3434
end
3535

3636
∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,6 @@ xs = Float32[1, 2, 3000.]
2929

3030
xs = Float32[1 2 3; 1000 2000 3000]
3131
@test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.]
32+
@test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs))
33+
@test NNlib.∇softmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs))
3234
end

0 commit comments

Comments
 (0)