Skip to content

Commit 16262e8

Browse files
committed
grad function for softmax
1 parent eec6d74 commit 16262e8

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/softmax.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,10 @@ function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVe
5353
sf = softmax(xs)
5454
out .= sf .*.- sum.*sf, dims = 1))
5555
end
56-
57-
∇softmax(Δ, xs) = ∇softmax!(similar(Δ), Δ, xs)
56+
function ∇softmax(Δ, xs; dims=1)
57+
sf = softmax(xs, dims=dims)
58+
out = sf .*.- sum.* sf, dims=dims))
59+
end
5860
∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
5961

6062

0 commit comments

Comments
 (0)