Skip to content

Commit 3b85e46

Browse files
committed
fix dimension arguments
1 parent 5de0867 commit 3b85e46

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/softmax.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ independent.
1717
0.244728
1818
0.665241
1919
"""
20-
softmax(xs) = softmax!(similar(xs), xs)
20+
softmax(xs; dims=2) = softmax!(similar(xs), xs, dims)
2121

22-
function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
22+
function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}, dims) where {T}
2323
# First, store column-wise maximum in the last element of `out`
2424
maxdims = ntuple(d -> (d in dims) ? (1:1) : (:), ndims(xs))
2525
out[maxdims...] = maximum!(out[maxdims...], xs)

0 commit comments

Comments
 (0)