@@ -3,29 +3,63 @@ export softmax, softmax!, ∇softmax, ∇softmax!,

"""
    softmax(xs) = exp.(xs) ./ sum(exp.(xs))
-
[Softmax](https://en.wikipedia.org/wiki/Softmax_function) takes
log-probabilities (any real vector) and returns a probability distribution that
sums to 1.
-
If given a matrix it will treat it as a batch of vectors, with each column
independent.
-
    julia> softmax([1,2,3.])
    3-element Array{Float64,1}:
     0.0900306
     0.244728
     0.665241
"""
+softmax(xs) = softmax!(similar(xs), xs)
+
function softmax(xs::AbstractArray{T}; dims=1) where {T}
    max = maximum(xs, dims=dims)
    out = exp.(xs .- max)
    out = out ./ sum(out, dims=dims)
end

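As a sanity check on the column-wise behavior the docstring describes, here is a hypothetical REPL session (not part of the diff; it assumes the definitions in this file are loaded):

    julia> xs = [1.0 2.0; 3.0 4.0];

    julia> softmax(xs)    # each column is normalized independently
    2×2 Array{Float64,2}:
     0.119203  0.119203
     0.880797  0.880797

    julia> all(sum(softmax(xs), dims=1) .≈ 1)
    true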
+function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
+    @inbounds for j = 1:size(xs, 2)
+        # First, store the column-wise maximum in the last element of `out`
+        out[end, j] = xs[end, j]
+        @inbounds for i = 1:(size(xs, 1) - 1)
+            out[end, j] = max(out[end, j], xs[i, j])
+        end
+
+        # Subtract the column-wise maximum for numerical stability, then take exp()
+        # out .= exp.(xs .- out[end, :])
+        @inbounds for i = 1:size(out, 1)
+            out[i, j] = exp(xs[i, j] - out[end, j])
+        end
+
+        # Normalize each column by its sum
+        # out ./= sum(out, dims=1)
+        s = T(0)
+        @inbounds for i = 1:size(out, 1)
+            s += out[i, j]
+        end
+        @inbounds for i = 1:size(out, 1)
+            out[i, j] /= s
+        end
+    end
+    return out
+end
+
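The point of the new in-place method is that a caller can preallocate one output buffer and reuse it across calls, e.g. inside a training loop. A usage sketch, assuming the definitions in this diff (the buffer name `out` is illustrative):

    julia> xs = rand(10, 32);    # batch of 32 column vectors

    julia> out = similar(xs);

    julia> softmax!(out, xs);    # fills `out` in place, no per-call allocation

    julia> out ≈ softmax(xs)
    true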
+function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
+    sf = softmax(xs)
+    out .= sf .* (Δ .- sum(Δ .* sf, dims = 1))
+end
+
+∇softmax(Δ, xs) = ∇softmax!(similar(Δ), Δ, xs)
+∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
+
+
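Since every output column sums to 1, the softmax Jacobian annihilates constant vectors, so a constant upstream gradient Δ should give a gradient of (numerically) zero. A quick check, again assuming the definitions above:

    julia> xs = [1, 2, 3.];

    julia> isapprox(∇softmax(ones(3), xs), zeros(3), atol=1e-12)
    true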
"""
    logsoftmax(xs) = log.(exp.(xs) ./ sum(exp.(xs)))
-
`logsoftmax(xs)` computes the log of `softmax(xs)`, but in a more numerically stable
way than directly taking the log of the softmax function, which is commonly used in
computing cross entropy loss.
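The stability claim is easiest to see when one probability underflows to zero: taking `log.(softmax(xs))` then yields `-Inf`, while the fused form stays finite. A sketch of the difference, assuming the `logsoftmax` defined below this hunk uses the same max-subtraction trick as `softmax!`:

    julia> xs = [0.0, 1000.0];

    julia> log.(softmax(xs))    # softmax underflows to [0.0, 1.0], so log gives -Inf
    2-element Array{Float64,1}:
     -Inf
       0.0

    julia> logsoftmax(xs)       # xs .- maximum(xs) .- log(sum(exp.(xs .- maximum(xs))))
    2-element Array{Float64,1}:
     -1000.0
         0.0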