@@ -18,8 +18,9 @@ independent.
  0.665241
 """
 function softmax(xs::AbstractArray{T}; dims=1) where {T}
-    max = maximum(xs, dims=dims)
-    out = exp.(xs .- max) ./ sum(exp.(xs .- max), dims=dims)
+    temp = maximum(xs, dims=dims)
+    out = exp.(xs .- temp)
+    out .= out ./ sum!(temp, out)
 end
 
 function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
@@ -51,9 +52,9 @@
 
 function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
     sf = softmax(xs)
-    out .= sf .* (Δ .- sum(Δ .* sf, dims = 1))
+    out .= sf .* (Δ .- sum(Δ .* sf, dims=1))
 end
-function ∇softmax(Δ, xs; dims = 1)
+function ∇softmax(Δ, xs; dims=1)
     sf = softmax(xs, dims=dims)
     out = sf .* (Δ .- sum(Δ .* sf, dims=dims))
 end
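For reference, the `∇softmax` expression is the softmax vector-Jacobian product: with `s = softmax(x)`, the Jacobian is `∂s_i/∂x_j = s_i*(δ_ij - s_j)`, and contracting it with `Δ` collapses to `s .* (Δ .- sum(Δ .* s))` per slice. A finite-difference sanity check for the vector case (self-contained sketch; `softmax1` is an illustrative helper, not package API):

```julia
# Check s .* (Δ .- sum(Δ .* s)) against central finite differences.
softmax1(x) = (e = exp.(x .- maximum(x)); e ./ sum(e))

x, Δ = randn(5), randn(5)
s = softmax1(x)
g = s .* (Δ .- sum(Δ .* s))           # closed-form VJP, as in the diff

ϵ = 1e-6
g_fd = map(eachindex(x)) do i
    e = zeros(length(x)); e[i] = 1.0  # i-th basis vector
    Δ' * (softmax1(x .+ ϵ .* e) .- softmax1(x .- ϵ .* e)) / 2ϵ
end
@assert maximum(abs.(g .- g_fd)) < 1e-6
```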
@@ -67,7 +68,13 @@
 way than directly taking the log of the softmax function, which is commonly used in
 computing cross entropy loss.
 """
-logsoftmax(xs) = logsoftmax!(similar(xs), xs)
+function logsoftmax(xs::AbstractArray{T}; dims=1) where {T}
+    max_ = maximum(xs, dims=dims)
+    out = exp.(xs .- max_)
+    log_ = log.(sum(out, dims=dims))
+    out .= (xs .- max_) .- log_
+end
+
 function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
     for j = 1:size(xs, 2)
         @inbounds begin
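The new `logsoftmax` applies the log-sum-exp shift directly, computing `(x .- m) .- log(sum(exp.(x .- m)))` with `m = maximum(x, dims=dims)`, rather than delegating to `logsoftmax!`. Unlike `log.(softmax(x))`, the result stays finite even when individual `exp(x_i - m)` terms underflow to zero, which is the stability property the docstring refers to. A small demonstration (helper names are illustrative):

```julia
naive_logsoftmax(x)  = log.(exp.(x .- maximum(x)) ./ sum(exp.(x .- maximum(x))))
stable_logsoftmax(x) = (x .- maximum(x)) .- log(sum(exp.(x .- maximum(x))))

x = [0.0, 1000.0]                 # extreme logits
naive_logsoftmax(x)               # [-Inf, 0.0]: exp(-1000) underflows, log(0) = -Inf
stable_logsoftmax(x)              # [-1000.0, 0.0]: exact and finite
```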
@@ -86,5 +93,5 @@ function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
     end
     return out
 end
-∇logsoftmax(Δ, xs) = Δ - sum(Δ, dims=1) .* softmax(xs)
+∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs, dims=dims)
 ∇logsoftmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
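Similarly, `∇logsoftmax` follows from `∂logsoftmax(x)_i/∂x_j = δ_ij - softmax(x)_j`, so the vector-Jacobian product is `Δ .- sum(Δ) .* softmax(x)` per slice. The same kind of finite-difference check confirms it (self-contained sketch; helper names are illustrative):

```julia
softmax1(x)    = (e = exp.(x .- maximum(x)); e ./ sum(e))
logsoftmax1(x) = (x .- maximum(x)) .- log(sum(exp.(x .- maximum(x))))

x, Δ = randn(4), randn(4)
g = Δ .- sum(Δ) .* softmax1(x)        # closed-form VJP, as in the diff

ϵ = 1e-6
g_fd = map(eachindex(x)) do i
    e = zeros(length(x)); e[i] = 1.0  # i-th basis vector
    Δ' * (logsoftmax1(x .+ ϵ .* e) .- logsoftmax1(x .- ϵ .* e)) / 2ϵ
end
@assert maximum(abs.(g .- g_fd)) < 1e-6
```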