Skip to content

Commit 20aa700

Browse files
authored
Merge pull request #75 from pevnak/softmax
Softmax
2 parents e896670 + 6fd9132 commit 20aa700

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

src/logsoftmax.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ end
2323
logsoftmax!(xs) = logsoftmax!(xs, xs)
2424
logsoftmax(xs) = logsoftmax!(similar(xs), xs)
2525

26-
∇logsoftmax(Δ, xs) = ∇softmax./ softmax(xs), xs)
26+
∇logsoftmax(Δ, xs) = ∇softmax./ max.(eps(eltype(xs)),softmax(xs)), xs)
2727

2828
"""
2929
logsoftmax(xs) = log.(exp.(xs) ./ sum(exp.(xs)))

src/softmax.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ softmax!(xs) = softmax!(xs, xs)
2929
softmax(xs) = softmax!(similar(xs), xs)
3030

3131
function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
32-
s = sum(exp, xs, dims=1)
33-
out .= exp.(xs)./s.*.- sum.* exp.(xs), dims=1)./s)
32+
sf = softmax(xs)
33+
out .= sf .* .- sum.*sf, dims = 1))
3434
end
3535

3636
∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)

test/runtests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,11 @@ xs = Float32[1, 2, 3000.]
2929

3030
xs = Float32[1 2 3; 1000 2000 3000]
3131
@test logsoftmax(xs) [-999 -1998 -2997; 0 0 0.]
32+
33+
@test NNlib.∇logsoftmax(ones(size(xs)), xs) zeros(Float32, size(xs))
34+
@test NNlib.∇softmax(ones(size(xs)), xs) zeros(Float32, size(xs))
35+
36+
xs = [-0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842; 0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663; -1.14637 -0.577988 0.718952 0.91972 -0.620773 0.929977]
37+
@test isapprox(NNlib.∇logsoftmax(ones(size(xs)), xs), [0.237703 -0.621474 0.448193 0.546047 0.564185 0.632273; -0.930163 0.0519798 0.0549979 0.3799 -0.477112 0.437428; 0.69246 0.569494 -0.503191 -0.925947 -0.0870738 -1.0697]; rtol = 1e-6)
38+
@test isapprox(NNlib.∇softmax(ones(size(xs)), xs), zeros(size(xs)); atol = 1e-6)
3239
end

0 commit comments

Comments
 (0)