Skip to content

Commit ce106cd

Browse files
authored
Merge pull request #1 from merckxiaan/softmax-derivative
make softmax out-of-place; remove gradient code
2 parents cfa3201 + 599856e commit ce106cd

File tree

1 file changed

+3
-15
lines changed

1 file changed

+3
-15
lines changed

src/softmax.jl

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,12 @@ independent.
1717
0.244728
1818
0.665241
1919
"""
20-
softmax(xs; dims=1) = softmax!(similar(xs), xs, dims)
21-
22-
function softmax!(out::AbstractArray{T}, xs::AbstractArray{T}, dims) where {T}
20+
function softmax(xs::AbstractArray{T}; dims=1) where {T}
2321
max = maximum(xs, dims=dims)
24-
out .= exp.(xs.-max)
25-
out ./= sum(out, dims=dims)
26-
return out
27-
end
28-
29-
function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
30-
sf = softmax(xs)
31-
out .= sf .* (Δ .- sum(Δ .*sf, dims = 1))
22+
out = exp.(xs .- max)
23+
out = out ./ sum(out, dims=dims)
3224
end
3325

34-
∇softmax(Δ, xs) = ∇softmax!(similar(Δ), Δ, xs)
35-
∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
36-
37-
3826
"""
3927
logsoftmax(xs) = log.(exp.(xs) ./ sum(exp.(xs)))
4028

0 commit comments

Comments
 (0)