Commit 5998b47

Merge pull request #135 from mcabbott/softmax
logsoftmax with dims
2 parents 342928e + 283496e · commit 5998b47

2 files changed: +30 −18 lines

src/softmax.jl

Lines changed: 27 additions & 17 deletions
@@ -8,18 +8,21 @@ export softmax, softmax!, ∇softmax, ∇softmax!,
 log-probabilities (any real vector) and returns a probability distribution that
 sums to 1.
 
-If given a matrix it will treat it as a batch of vectors, with each column
-independent.
+If given a matrix it will by default (`dims=1`) treat it as a batch of vectors,
+with each column independent. Keyword `dims=2` will instead treat rows independently, etc.
 
-    julia> softmax([1,2,3.])
-    3-element Array{Float64,1}:
-      0.0900306
-      0.244728
-      0.665241
+```
+julia> softmax([1,2,3.])
+3-element Array{Float64,1}:
+ 0.0900306
+ 0.244728
+ 0.665241
+```
 """
-function softmax(xs::AbstractArray{T}; dims=1) where {T}
-  max = maximum(xs, dims=dims)
-  out = exp.(xs .- max) ./ sum(exp.(xs .- max), dims=dims)
+function softmax(xs::AbstractArray; dims=1)
+  max_ = maximum(xs, dims=dims)
+  exp_ = exp.(xs .- max_)
+  exp_ ./ sum(exp_, dims=dims)
 end
 
 function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
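A quick usage sketch of the `dims` keyword documented in the hunk above. This is an illustrative example, not part of the commit; it assumes a build of NNlib that includes these changes, and the comments describe expected invariants rather than captured REPL output.

```julia
using NNlib  # assumes a version of NNlib containing this commit

x = randn(3, 4)

# Default dims=1: each column is treated as an independent distribution.
sum(softmax(x), dims=1)           # 1×4 matrix, every entry ≈ 1.0

# dims=2: rows are normalised instead.
sum(softmax(x; dims=2), dims=2)   # 3×1 matrix, every entry ≈ 1.0

# Vectors behave as before (values from the docstring above):
softmax([1, 2, 3.])               # ≈ [0.0900306, 0.244728, 0.665241]
```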
@@ -51,23 +54,29 @@ end
 
 function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
   sf = softmax(xs)
-  out .= sf .* (Δ .- sum(Δ .*sf, dims = 1))
+  out .= sf .* (Δ .- sum(Δ .* sf, dims = 1))
 end
-function ∇softmax(Δ, xs; dims=1)
+function ∇softmax(Δ, xs; dims=1)
   sf = softmax(xs, dims=dims)
-  out = sf .* (Δ .- sum(Δ .* sf, dims=dims))
+  sf .* (Δ .- sum(Δ .* sf, dims=dims))
 end
 ∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
 
 
 """
     logsoftmax(xs) = log.(exp.(xs) ./ sum(exp.(xs)))
 
-`logsoftmax(xs)` computes the log of `softmax(xs)`, but in a more numerically stable
-way than directly taking the log of the softmax function, which is commonly used in
+Computes the log of softmax in a more numerically stable
+way than directly taking `log.(softmax(xs))`. Commonly used in
 computing cross entropy loss.
 """
-logsoftmax(xs) = logsoftmax!(similar(xs), xs)
+function logsoftmax(xs::AbstractArray; dims=1)
+  max_ = maximum(xs, dims=dims)
+  exp_ = exp.(xs .- max_)
+  log_ = log.(sum(exp_, dims=dims))
+  (xs .- max_) .- log_
+end
+
 function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
   for j = 1:size(xs, 2)
     @inbounds begin
@@ -86,5 +95,6 @@ function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
   end
   return out
 end
-∇logsoftmax(Δ, xs) = Δ - sum(Δ, dims=1) .* softmax(xs)
+
+∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs)
 ∇logsoftmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
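The rewritten `logsoftmax` subtracts the per-slice maximum before exponentiating, so it stays finite where the naive `log.(softmax(xs))` underflows. A small sketch of the difference, again assuming a build of NNlib that includes this commit and using the same matrix as the test below:

```julia
using NNlib  # assumes a version of NNlib containing this commit

xs = Float32[1 2 3; 1000 2000 3000]

logsoftmax(xs)     # ≈ Float32[-999 -1998 -2997; 0 0 0], finite everywhere
log.(softmax(xs))  # naive route: exp(1f0 - 1000f0) underflows to 0f0,
                   # so the first row comes out as -Inf
```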

test/activation.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ end
 @testset "softmax" begin
   xs = rand(5,5)
   @test all(sum(softmax(xs), dims = 1) .≈ 1)
+  @test all(sum(softmax(xs; dims=2), dims = 2) .≈ 1)
   @test sum(softmax(vec(xs))) ≈ 1
+  @test log.(softmax(xs; dims=2)) ≈ logsoftmax(xs; dims=2)
 
   xs = [-100_000, -100_000.]
   @test softmax(xs) ≈ [0.5, 0.5]
@@ -100,7 +102,7 @@ end
   xs = Float32[1 2 3; 1000 2000 3000]
   @test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.]
 
-  @test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1]
+  @test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1]
   @test NNlib.∇softmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs))
 
   # These values precalculated using PyTorch's nn.LogSoftmax
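As a sanity check on the gradient rule introduced above, `∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs)`, here is a hedged finite-difference sketch. It is not part of the commit; it assumes a build of NNlib that includes these changes, and `g_fd` and `ε` are names local to the example.

```julia
using NNlib  # assumes a version of NNlib containing this commit

x = randn(4, 3)
Δ = randn(4, 3)

# Analytic gradient of L(x) = sum(Δ .* logsoftmax(x)) via the new rule (dims=1).
g = NNlib.∇logsoftmax(Δ, x)

# Central finite differences of the same scalar loss.
ε = 1e-6
g_fd = similar(x)
for i in eachindex(x)
    xp = copy(x); xp[i] += ε
    xm = copy(x); xm[i] -= ε
    g_fd[i] = (sum(Δ .* logsoftmax(xp)) - sum(Δ .* logsoftmax(xm))) / (2ε)
end

maximum(abs.(g .- g_fd))  # should be tiny, roughly 1e-8 or below
```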
