Skip to content

Commit 958a6d4

Browse files
committed
fixed ∇logsoftmax! implementation
1 parent 8c3e994 commit 958a6d4

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

src/softmax.jl

Lines changed: 8 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -3,8 +3,8 @@ export softmax, softmax!, ∇softmax, ∇softmax!,
33

44
"""
55
softmax(x; dims=1)
6-
7-
[Softmax](https://en.wikipedia.org/wiki/Softmax_function) turns input array `x`
6+
7+
[Softmax](https://en.wikipedia.org/wiki/Softmax_function) turns input array `x`
88
into probability distributions that sum to 1 along the dimensions specified by `dims`.
99
It is semantically equivalent to the following:
1010
@@ -13,7 +13,7 @@ It is semantically equivalent to the following:
1313
with additional manipulations enhancing numerical stability.
1414
1515
For a matrix input `x` it will by default (`dims=1`) treat it as a batch of vectors,
16-
with each column independent. Keyword `dims=2` will instead treat rows independently,
16+
with each column independent. Keyword `dims=2` will instead treat rows independently,
1717
etc...
1818
```julia-repl
1919
julia> softmax([1, 2, 3])
@@ -108,5 +108,9 @@ function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
108108
return out
109109
end
110110

111+
function ∇logsoftmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
112+
out .= Δ .- sum(Δ, dims=1) .* softmax(xs, dims=1)
113+
end
114+
111115
∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs, dims=dims)
112-
∇logsoftmax!(Δ, xs) = softmax!(Δ, Δ, xs)
116+
∇logsoftmax!(Δ, xs) = logsoftmax!(Δ, Δ, xs)

test/activation.jl

Lines changed: 4 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -265,14 +265,14 @@ end
265265
@test isapprox(out, NNlib.∇softmax(zeros(size(xs)), xs); rtol=1e-6)
266266
out = zeros(Float64, size(xs))
267267
NNlib.∇logsoftmax!(out, xs)
268-
@test isapprox(out, NNlib.softmax(zeros(size(xs)), xs); rtol=1e-6)
268+
@test isapprox(out, NNlib.logsoftmax(zeros(size(xs)), xs); rtol=1e-6)
269269

270270
out = ones(Float64, size(xs))
271271
NNlib.∇softmax!(out, xs)
272272
@test isapprox(out, NNlib.∇softmax(ones(size(xs)), xs); rtol=1e-6)
273273
out = ones(Float64, size(xs))
274274
NNlib.∇logsoftmax!(out, xs)
275-
@test isapprox(out, NNlib.softmax(ones(size(xs)), xs); rtol=1e-6)
275+
@test isapprox(out, NNlib.logsoftmax(ones(size(xs)), xs); rtol=1e-6)
276276

277277
xs = [
278278
-0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842;
@@ -297,14 +297,14 @@ end
297297
@test isapprox(out, NNlib.∇softmax(zeros(size(xs)), xs); rtol=1e-6)
298298
out = zeros(Float64, size(xs))
299299
NNlib.∇logsoftmax!(out, xs)
300-
@test isapprox(out, NNlib.softmax(zeros(size(xs)), xs); rtol=1e-6)
300+
@test isapprox(out, NNlib.logsoftmax(zeros(size(xs)), xs); rtol=1e-6)
301301

302302
out = ones(Float64, size(xs))
303303
NNlib.∇softmax!(out, xs)
304304
@test isapprox(out, NNlib.∇softmax(ones(size(xs)), xs); rtol=1e-6)
305305
out = ones(Float64, size(xs))
306306
NNlib.∇logsoftmax!(out, xs)
307-
@test isapprox(out, NNlib.softmax(ones(size(xs)), xs); rtol=1e-6)
307+
@test isapprox(out, NNlib.logsoftmax(ones(size(xs)), xs); rtol=1e-6)
308308
end
309309

310310
end

0 commit comments

Comments (0)