Skip to content

Commit 5de0867

Browse files
committed
make softmax! dimension-agnostic
1 parent 4ea355e commit 5de0867

File tree

1 file changed

+7
-21
lines changed

1 file changed

+7
-21
lines changed

src/softmax.jl

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,15 @@ independent.
2020
softmax(xs) = softmax!(similar(xs), xs)
2121

2222
function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
23-
@inbounds for j = 1:size(xs, 2)
24-
# First, store column-wise maximum in the last element of `out`
25-
out[end, j] = xs[end, j]
26-
@inbounds for i = 1:(size(xs, 1) - 1)
27-
out[end, j] = max(out[end, j], xs[i, j])
28-
end
23+
# First, store column-wise maximum in the last element of `out`
24+
maxdims = ntuple(d -> (d in dims) ? (1:1) : (:), ndims(xs))
25+
out[maxdims...] = maximum!(out[maxdims...], xs)
2926

30-
# Subtract the column-wise maximums to normalize, take exp()
31-
# out .= exp(xs .- out[end, :])
32-
@inbounds for i = 1:size(out, 1)
33-
out[i, j] = exp(xs[i, j] - out[end, j])
34-
end
27+
# Subtract the column-wise maximums to normalize, take exp()
28+
out .= exp.(xs.-out[maxdims...])
3529

36-
# Normalize by sum of the entire thing
37-
# out ./= sum(out, 1)
38-
s = T(0)
39-
@inbounds for i = 1:size(out, 1)
40-
s += out[i, j]
41-
end
42-
@inbounds for i = 1:size(out, 1)
43-
out[i, j] /= s
44-
end
45-
end
30+
# Normalize by sum of the entire thing
31+
out ./= sum(out, dims=dims)
4632
return out
4733
end
4834

0 commit comments

Comments
 (0)