We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9dee1cc commit 98ca791Copy full SHA for 98ca791
src/softmax.jl
@@ -19,15 +19,9 @@ independent.
19
"""
20
softmax(xs; dims=1) = softmax!(similar(xs), xs, dims)
21
22
-function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}, dims) where {T}
23
- # First, store column-wise maximum in the last element of `out`
24
- maxdims = ntuple(d -> (d in dims) ? (1:1) : (:), ndims(xs))
25
- out[maxdims...] = maximum!(out[maxdims...], xs)
26
-
27
- # Subtract the column-wise maximums to normalize, take exp()
28
- out .= exp.(xs.-out[maxdims...])
29
30
- # Normalize by sum of the entire thing
+function softmax!(out::AbstractArray{T}, xs::AbstractArray{T}, dims) where {T}
+ max = maximum(xs, dims)
+ out .= exp.(xs.-max)
31
out ./= sum(out, dims=dims)
32
return out
33
end
0 commit comments