
Commit 671f155

revert things and add simple broadcast version
1 parent ce106cd commit 671f155

File tree: 1 file changed, +38 −4 lines


src/softmax.jl

Lines changed: 38 additions & 4 deletions
@@ -3,29 +3,63 @@ export softmax, softmax!, ∇softmax, ∇softmax!,

"""
    softmax(xs) = exp.(xs) ./ sum(exp.(xs))
[Softmax](https://en.wikipedia.org/wiki/Softmax_function) takes
log-probabilities (any real vector) and returns a probability distribution that
sums to 1.
If given a matrix it will treat it as a batch of vectors, with each column
independent.
    julia> softmax([1,2,3.])
    3-element Array{Float64,1}:
     0.0900306
     0.244728
     0.665241
"""
softmax(xs) = softmax!(similar(xs), xs)

function softmax(xs::AbstractArray{T}; dims=1) where {T}
    max = maximum(xs, dims=dims)
    out = exp.(xs .- max)
    out = out ./ sum(out, dims=dims)
end

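# Usage sketch (illustrative, not part of this commit; the matrix below is an
# arbitrary example): a matrix is treated as a batch of column vectors, so each
# column of the result is its own probability distribution.
example_xs = [1.0 2.0; 2.0 0.5; 3.0 1.0]
softmax(example_xs)               # 3×2 matrix, columns independent
sum(softmax(example_xs), dims=1)  # each entry ≈ 1.0
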
function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
    @inbounds for j = 1:size(xs, 2)
        # First, store the column-wise maximum in the last element of `out`
        out[end, j] = xs[end, j]
        @inbounds for i = 1:(size(xs, 1) - 1)
            out[end, j] = max(out[end, j], xs[i, j])
        end

        # Subtract the column-wise maximum, then take exp()
        # out .= exp.(xs .- out[end, :])
        @inbounds for i = 1:size(out, 1)
            out[i, j] = exp(xs[i, j] - out[end, j])
        end

        # Normalize by the column sum
        # out ./= sum(out, dims=1)
        s = T(0)
        @inbounds for i = 1:size(out, 1)
            s += out[i, j]
        end
        @inbounds for i = 1:size(out, 1)
            out[i, j] /= s
        end
    end
    return out
end
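
# Why the column-wise maximum is subtracted (illustrative, not part of this
# commit; the inputs below are arbitrary): a naive exp/sum overflows for large
# inputs, while the shifted computation stays finite. The result matches the
# docstring example because softmax is invariant to a constant shift.
big_xs = [1000.0, 1001.0, 1002.0]
exp.(big_xs) ./ sum(exp.(big_xs))   # exp(1000) == Inf, so this is all NaN
softmax!(similar(big_xs), big_xs)   # ≈ [0.0900306, 0.244728, 0.665241]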

function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
    sf = softmax(xs)
    out .= sf .* (Δ .- sum(Δ .* sf, dims = 1))
end
∇softmax(Δ, xs) = ∇softmax!(similar(Δ), Δ, xs)
∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)

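# Finite-difference check of the gradient formula above (illustrative, not part
# of this commit; test_xs, test_Δ, and ϵ are arbitrary test values). For the
# scalar loss L(x) = sum(test_Δ .* softmax(x)), ∇softmax(test_Δ, test_xs) is the
# gradient of L at test_xs, so it should agree with central differences.
test_xs, test_Δ, ϵ = [1.0, 2.0, 3.0], [0.5, -1.0, 2.0], 1e-6
L(x) = sum(test_Δ .* softmax(x))
basis(i) = [k == i ? 1.0 : 0.0 for k in 1:3]
fd = [(L(test_xs .+ ϵ .* basis(i)) - L(test_xs .- ϵ .* basis(i))) / (2ϵ) for i in 1:3]
fd ≈ ∇softmax(test_Δ, test_xs)      # expected to hold up to finite-difference error
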
"""
    logsoftmax(xs) = log.(exp.(xs) ./ sum(exp.(xs)))
`logsoftmax(xs)` computes the log of `softmax(xs)`, but in a more numerically stable
way than directly taking the log of the softmax function, which is commonly used in
computing cross entropy loss.
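
# A quick illustration of the stability claim (not taken from this commit; the
# shifted expression below is the standard max-subtraction formulation and the
# inputs are arbitrary): log.(softmax(xs)) underflows to -Inf for widely spread
# inputs, while the shifted form keeps the exact value.
wide_xs = [1000.0, 0.0]
log.(softmax(wide_xs))   # [0.0, -Inf]: exp(-1000) underflows to zero before the log
wide_xs .- maximum(wide_xs) .- log(sum(exp.(wide_xs .- maximum(wide_xs))))   # [0.0, -1000.0]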
