@@ -20,29 +20,15 @@ independent.
20
20
softmax (xs) = softmax! (similar (xs), xs)
21
21
22
22
function softmax! (out:: AbstractVecOrMat{T} , xs:: AbstractVecOrMat{T} ) where {T}
23
- @inbounds for j = 1 : size (xs, 2 )
24
- # First, store column-wise maximum in the last element of `out`
25
- out[end , j] = xs[end , j]
26
- @inbounds for i = 1 : (size (xs, 1 ) - 1 )
27
- out[end , j] = max (out[end , j], xs[i, j])
28
- end
23
+ # First, store column-wise maximum in the last element of `out`
24
+ maxdims = ntuple (d -> (d in dims) ? (1 : 1 ) : (:), ndims (xs))
25
+ out[maxdims... ] = maximum! (out[maxdims... ], xs)
29
26
30
- # Subtract the column-wise maximums to normalize, take exp()
31
- # out .= exp(xs .- out[end, :])
32
- @inbounds for i = 1 : size (out, 1 )
33
- out[i, j] = exp (xs[i, j] - out[end , j])
34
- end
27
+ # Subtract the column-wise maximums to normalize, take exp()
28
+ out .= exp .(xs.- out[maxdims... ])
35
29
36
- # Normalize by sum of the entire thing
37
- # out ./= sum(out, 1)
38
- s = T (0 )
39
- @inbounds for i = 1 : size (out, 1 )
40
- s += out[i, j]
41
- end
42
- @inbounds for i = 1 : size (out, 1 )
43
- out[i, j] /= s
44
- end
45
- end
30
+ # Normalize by sum of the entire thing
31
+ out ./= sum (out, dims= dims)
46
32
return out
47
33
end
48
34
0 commit comments