Skip to content

Commit 98ca791

Browse files
committed
use broadcasting
1 parent 9dee1cc commit 98ca791

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

src/softmax.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,9 @@ independent.
1919
"""
2020
softmax(xs; dims=1) = softmax!(similar(xs), xs, dims)
2121

22-
function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}, dims) where {T}
23-
# First, store column-wise maximum in the last element of `out`
24-
maxdims = ntuple(d -> (d in dims) ? (1:1) : (:), ndims(xs))
25-
out[maxdims...] = maximum!(out[maxdims...], xs)
26-
27-
# Subtract the column-wise maximums to normalize, take exp()
28-
out .= exp.(xs.-out[maxdims...])
29-
30-
# Normalize by sum of the entire thing
22+
function softmax!(out::AbstractArray{T}, xs::AbstractArray{T}, dims) where {T}
23+
max = maximum(xs, dims)
24+
out .= exp.(xs.-max)
3125
out ./= sum(out, dims=dims)
3226
return out
3327
end

0 commit comments

Comments
 (0)