
Commit b43fbe4

softmax + logsoftmax dims
1 parent 342928e commit b43fbe4

2 files changed, +16 -7 lines changed


src/softmax.jl

Lines changed: 13 additions & 6 deletions
@@ -18,8 +18,9 @@ independent.
 0.665241
 """
 function softmax(xs::AbstractArray{T}; dims=1) where {T}
-  max = maximum(xs, dims=dims)
-  out = exp.(xs .- max) ./ sum(exp.(xs .- max), dims=dims)
+  temp = maximum(xs, dims=dims)
+  out = exp.(xs .- temp)
+  out .= out ./ sum!(temp, out)
 end

 function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
@@ -51,9 +52,9 @@ end

 function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
   sf = softmax(xs)
-  out .= sf .* (Δ .- sum(Δ.*sf, dims = 1))
+  out .= sf .* (Δ .- sum(Δ .* sf, dims = 1))
 end
-function ∇softmax(Δ, xs; dims=1)
+function ∇softmax(Δ, xs; dims=1)
   sf = softmax(xs, dims=dims)
   out = sf .* (Δ .- sum(Δ .* sf, dims=dims))
 end
@@ -67,7 +68,13 @@ end
 way than directly taking the log of the softmax function, which is commonly used in
 computing cross entropy loss.
 """
-logsoftmax(xs) = logsoftmax!(similar(xs), xs)
+function logsoftmax(xs::AbstractArray{T}; dims=1) where {T}
+  max_ = maximum(xs, dims=dims)
+  out = exp.(xs .- max_)
+  log_ = log.(sum(out, dims=dims))
+  out .= (xs .- max_) .- log_
+end
+
 function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
   for j = 1:size(xs, 2)
     @inbounds begin
@@ -86,5 +93,5 @@ function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
   end
   return out
 end
-∇logsoftmax(Δ, xs) = Δ - sum(Δ, dims=1) .* softmax(xs)
+∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs)
 ∇logsoftmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
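
In short, softmax and logsoftmax now normalize along the dimension given by the dims keyword, defaulting to dims=1 (per column). A minimal usage sketch, assuming this branch of NNlib is loaded; the 2x3 matrix below is illustrative and not part of the commit:

using NNlib

x = [1.0 2.0 3.0;
     4.0 5.0 6.0]               # illustrative input, not from the commit

p_cols = softmax(x)             # default dims=1: every column sums to 1
p_rows = softmax(x; dims=2)     # normalize along rows instead

@assert all(sum(p_cols, dims=1) .≈ 1)
@assert all(sum(p_rows, dims=2) .≈ 1)

# logsoftmax follows the same convention and agrees with log.(softmax(x; dims=...)).
@assert log.(softmax(x; dims=2)) ≈ logsoftmax(x; dims=2)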

test/activation.jl

Lines changed: 3 additions & 1 deletion
@@ -84,7 +84,9 @@ end
 @testset "softmax" begin
   xs = rand(5,5)
   @test all(sum(softmax(xs), dims = 1) .≈ 1)
+  @test all(sum(softmax(xs; dims=2), dims = 2) .≈ 1)
   @test sum(softmax(vec(xs))) ≈ 1
+  @test log.(softmax(xs; dims=2)) ≈ logsoftmax(xs; dims=2)

   xs = [-100_000, -100_000.]
   @test softmax(xs) ≈ [0.5, 0.5]
@@ -100,7 +102,7 @@ end
   xs = Float32[1 2 3; 1000 2000 3000]
   @test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.]

-  @test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1]
+  @test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1]
   @test NNlib.∇softmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs))

   # These values precalculated using PyTorch's nn.LogSoftmax
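
The Float32[1 2 3; 1000 2000 3000] test above exercises the numerical-stability claim in the logsoftmax docstring. A small sketch of why the subtract-the-maximum formulation matters; naive_logsoftmax is a hypothetical helper written only for this comparison and is not part of NNlib:

using NNlib

# Naive formulation: exp overflows to Inf for large inputs, producing NaN and -Inf.
naive_logsoftmax(x; dims=1) = log.(exp.(x) ./ sum(exp.(x), dims=dims))

x = Float32[1 2 3; 1000 2000 3000]
naive_logsoftmax(x)   # exp(1000f0) == Inf32, so entries degrade to NaN and -Inf
logsoftmax(x)         # stable; the test above expects ≈ Float32[-999 -1998 -2997; 0 0 0]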
