@@ -8,19 +8,21 @@ export softmax, softmax!, ∇softmax, ∇softmax!,
log-probabilities (any real vector) and returns a probability distribution that
sums to 1.

- If given a matrix it will treat it as a batch of vectors, with each column
- independent.
+ If given a matrix it will by default (`dims=1`) treat it as a batch of vectors,
+ with each column independent. Keyword `dims=2` will instead treat rows independently, etc.

-     julia> softmax([1,2,3.])
-     3-element Array{Float64,1}:
-       0.0900306
-       0.244728
-       0.665241
+ ```
+ julia> softmax([1,2,3.])
+ 3-element Array{Float64,1}:
+  0.0900306
+  0.244728
+  0.665241
+ ```

"""
- function softmax(xs::AbstractArray{T}; dims=1) where {T}
-     temp = maximum(xs, dims=dims)
-     out = exp.(xs .- temp)
-     out .= out ./ sum!(temp, out)
+ function softmax(xs::AbstractArray; dims=1)
+     max_ = maximum(xs, dims=dims)
+     out = exp.(xs .- max_)
+     out .= out ./ sum!(max_, out)
end

function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
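A minimal sketch of what the new `dims` keyword enables, assuming this branch of NNlib is loaded:

```julia
julia> x = [1 2; 3 4.0];

julia> softmax(x; dims=1)   # default: each column sums to 1
2×2 Array{Float64,2}:
 0.119203  0.119203
 0.880797  0.880797

julia> softmax(x; dims=2)   # each row sums to 1 instead
2×2 Array{Float64,2}:
 0.268941  0.731059
 0.268941  0.731059
```

Note also that `sum!(max_, out)` writes the per-slice sums into the existing `max_` buffer, saving an allocation over a separate `sum(out, dims=dims)`.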
@@ -64,8 +66,8 @@
"""
    logsoftmax(xs) = log.(exp.(xs) ./ sum(exp.(xs)))

- `logsoftmax(xs)` computes the log of `softmax(xs)`, but in a more numerically stable
- way than directly taking the log of the softmax function, which is commonly used in
+ Computes the log of softmax in a more numerically stable
+ way than directly taking `log.(softmax(xs))`. Commonly used in
computing cross entropy loss.
"""
function logsoftmax(xs::AbstractArray; dims=1)
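Why the reworded docstring stresses stability, in a minimal sketch (assumes Float64 and the log-sum-exp implementation of `logsoftmax` in this file): for widely spread inputs, `log.(softmax(x))` underflows to `-Inf` while `logsoftmax(x)` stays finite.

```julia
julia> x = [0.0, 1000.0];

julia> log.(softmax(x))   # exp(-1000) underflows to 0.0, so its log is -Inf
2-element Array{Float64,1}:
 -Inf
    0.0

julia> logsoftmax(x)      # log-sum-exp form keeps the small entry finite
2-element Array{Float64,1}:
 -1000.0
      0.0
```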
@@ -93,5 +95,6 @@ function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
    end
    return out
end
+
∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs)
∇logsoftmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
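A hypothetical finite-difference sanity check (not part of the commit) that the `∇logsoftmax` formula matches the gradient of `sum(Δ .* logsoftmax(x))`; it sticks to the default `dims=1`, since `∇logsoftmax` as written does not forward `dims` to `softmax`:

```julia
using Test

x, Δ = randn(5, 3), randn(5, 3)
g = ∇logsoftmax(Δ, x)   # analytic gradient of sum(Δ .* logsoftmax(x))
ε = 1e-6
for i in eachindex(x)
    xp = copy(x); xp[i] += ε   # bump one entry up and down
    xm = copy(x); xm[i] -= ε
    fd = (sum(Δ .* logsoftmax(xp)) - sum(Δ .* logsoftmax(xm))) / (2ε)
    @test isapprox(g[i], fd; atol=1e-4)
end
```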