@@ -8,18 +8,21 @@ export softmax, softmax!, ∇softmax, ∇softmax!,
 log-probabilities (any real vector) and returns a probability distribution that
 sums to 1.
 
-If given a matrix it will treat it as a batch of vectors, with each column
-independent.
+If given a matrix it will by default (`dims=1`) treat it as a batch of vectors,
+with each column independent. The keyword `dims=2` instead treats rows independently, and so on.
 
-    julia> softmax([1,2,3.])
-    3-element Array{Float64,1}:
-     0.0900306
-     0.244728
-     0.665241
+```
+julia> softmax([1,2,3.])
+3-element Array{Float64,1}:
+ 0.0900306
+ 0.244728
+ 0.665241
+```
 
"""
20
- function softmax (xs:: AbstractArray{T} ; dims= 1 ) where {T}
21
- max = maximum (xs, dims= dims)
22
- out = exp .(xs .- max) ./ sum (exp .(xs .- max), dims= dims)
22
+ function softmax (xs:: AbstractArray ; dims= 1 )
23
+ max_ = maximum (xs, dims= dims)
24
+ exp_ = exp .(xs .- max_)
25
+ exp_ ./ sum (exp_, dims= dims)
23
26
end
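As a side note, here is a small sketch of what the new `dims` keyword buys you, assuming this revision of `softmax`; the matrix is purely illustrative:

```julia
julia> x = [1.0 2.0; 3.0 4.0];

julia> softmax(x)           # dims=1 (default): each column sums to 1
2×2 Array{Float64,2}:
 0.119203  0.119203
 0.880797  0.880797

julia> softmax(x, dims=2)   # each row sums to 1
2×2 Array{Float64,2}:
 0.268941  0.731059
 0.268941  0.731059
```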
 
 function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
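The body of the in-place variant is collapsed in this view; per Julia's bang convention its contract is to write the result into `out`. A hedged usage sketch, assuming that contract:

```julia
x = rand(5)
out = similar(x)
softmax!(out, x)      # fills `out` with softmax(x), avoiding a fresh allocation
@assert sum(out) ≈ 1
```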
@@ -51,23 +54,29 @@ function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
 
 function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
     sf = softmax(xs)
-    out .= sf .* (Δ .- sum(Δ .* sf, dims = 1))
+    out .= sf .* (Δ .- sum(Δ .* sf, dims = 1))
 end
-function ∇softmax(Δ, xs; dims=1)
+function ∇softmax(Δ, xs; dims=1)
     sf = softmax(xs, dims=dims)
-    out = sf .* (Δ .- sum(Δ .* sf, dims=dims))
+    sf .* (Δ .- sum(Δ .* sf, dims=dims))
 end
 ∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
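As a sanity check on the vector-Jacobian form used above (`sf .* (Δ .- sum(Δ .* sf))`), one can compare against a central finite difference; this is an illustrative sketch, not part of the commit:

```julia
# Compare ∇softmax against a central finite difference of Δ'·softmax(x).
x = [1.0, 2.0, 3.0]
Δ = [0.5, -1.0, 2.0]              # an arbitrary cotangent vector
g = ∇softmax(Δ, x)

ϵ = 1e-6
g_fd = map(1:length(x)) do i
    e = zeros(length(x)); e[i] = ϵ
    sum(Δ .* (softmax(x .+ e) .- softmax(x .- e))) / 2ϵ
end
@assert isapprox(g, g_fd; atol=1e-6)
```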
 
 
 """
     logsoftmax(xs) = log.(exp.(xs) ./ sum(exp.(xs)))
 
-`logsoftmax(xs)` computes the log of `softmax(xs)`, but in a more numerically stable
-way than directly taking the log of the softmax function, which is commonly used in
+Computes the log of the softmax in a more numerically stable
+way than directly taking `log.(softmax(xs))`. Commonly used in
 computing cross entropy loss.
 """
-logsoftmax(xs) = logsoftmax!(similar(xs), xs)
+function logsoftmax(xs::AbstractArray; dims=1)
+    max_ = maximum(xs, dims=dims)
+    exp_ = exp.(xs .- max_)
+    log_ = log.(sum(exp_, dims=dims))
+    (xs .- max_) .- log_
+end
+
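To see why the shift by `max_` matters, compare the naive composition with the new `logsoftmax` on an input whose small probability underflows; an illustrative sketch, not part of the commit:

```julia
julia> x = [0.0, -1000.0];

julia> log.(softmax(x))   # exp(-1000) underflows to 0.0, so its log is -Inf
2-element Array{Float64,1}:
   0.0
 -Inf

julia> logsoftmax(x)      # the shifted form returns the exact log-probability
2-element Array{Float64,1}:
     0.0
 -1000.0
```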
 function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
     for j = 1:size(xs, 2)
         @inbounds begin
@@ -86,5 +95,6 @@ function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
     end
     return out
 end
-∇logsoftmax(Δ, xs) = Δ - sum(Δ, dims=1) .* softmax(xs)
+
+∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs, dims=dims)
 ∇logsoftmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
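Since `logsoftmax` is `log ∘ softmax`, the two backward rules above must agree under the chain rule; a quick consistency sketch (illustrative only, not part of the commit):

```julia
# ∇logsoftmax(Δ, x) should equal ∇softmax(Δ ./ softmax(x), x):
# dividing the cotangent by softmax(x) is exactly the backward rule for log.
x = randn(4, 3)
Δ = randn(4, 3)
@assert ∇logsoftmax(Δ, x) ≈ ∇softmax(Δ ./ softmax(x), x)
```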