Skip to content

Commit 342928e

Browse files
authored
Merge pull request #130 from merckxiaan/master
Make softmax! dimension-agnostic
2 parents 4ea355e + 16262e8 commit 342928e

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/softmax.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ independent.
1717
0.244728
1818
0.665241
1919
"""
20-
softmax(xs) = softmax!(similar(xs), xs)
20+
function softmax(xs::AbstractArray{T}; dims=1) where {T}
21+
max = maximum(xs, dims=dims)
22+
out = exp.(xs .- max) ./ sum(exp.(xs .- max), dims=dims)
23+
end
2124

2225
function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
2326
@inbounds for j = 1:size(xs, 2)
@@ -50,8 +53,10 @@ function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVe
5053
sf = softmax(xs)
5154
out .= sf .*.- sum.*sf, dims = 1))
5255
end
53-
54-
∇softmax(Δ, xs) = ∇softmax!(similar(Δ), Δ, xs)
56+
function ∇softmax(Δ, xs; dims=1)
57+
sf = softmax(xs, dims=dims)
58+
out = sf .*.- sum.* sf, dims=dims))
59+
end
5560
∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
5661

5762

0 commit comments

Comments
 (0)