Commit 8677b1a

tidy

1 parent a9642c1

1 file changed

src/softmax.jl

Lines changed: 16 additions & 13 deletions
@@ -8,19 +8,21 @@ export softmax, softmax!, ∇softmax, ∇softmax!,
 log-probabilities (any real vector) and returns a probability distribution that
 sums to 1.
 
-If given a matrix it will treat it as a batch of vectors, with each column
-independent.
+If given a matrix it will by default (`dims=1`) treat it as a batch of vectors,
+with each column independent. Keyword `dims=2` will instead treat rows independently, etc.
 
-    julia> softmax([1,2,3.])
-    3-element Array{Float64,1}:
-     0.0900306
-     0.244728
-     0.665241
+```
+julia> softmax([1,2,3.])
+3-element Array{Float64,1}:
+ 0.0900306
+ 0.244728
+ 0.665241
+```
 """
-function softmax(xs::AbstractArray{T}; dims=1) where {T}
-  temp = maximum(xs, dims=dims)
-  out = exp.(xs .- temp)
-  out .= out ./ sum!(temp, out)
+function softmax(xs::AbstractArray; dims=1)
+  max_ = maximum(xs, dims=dims)
+  out = exp.(xs .- max_)
+  out .= out ./ sum!(max_, out)
 end
 
 function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
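The `dims` keyword documented above selects which slices are normalised. A minimal sketch of the two cases, assuming this is NNlib.jl's `src/softmax.jl` and `softmax` is exported as the hunk header shows:

```julia
using NNlib

x = [1.0 2.0; 3.0 4.0]

sum(softmax(x), dims=1)          # default dims=1: each column sums to 1 -> 1×2 matrix [1.0 1.0]
sum(softmax(x; dims=2), dims=2)  # dims=2: each row sums to 1 -> 2×1 matrix [1.0; 1.0]
```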
@@ -64,8 +66,8 @@ end
 """
     logsoftmax(xs) = log.(exp.(xs) ./ sum(exp.(xs)))
 
-`logsoftmax(xs)` computes the log of `softmax(xs)`, but in a more numerically stable
-way than directly taking the log of the softmax function, which is commonly used in
+Computes the log of softmax in a more numerically stable
+way than directly taking `log.(softmax(xs))`. Commonly used in
 computing cross entropy loss.
 """
 function logsoftmax(xs::AbstractArray; dims=1)
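The stability claim in the rewritten docstring shows up once a probability underflows; a small sketch, assuming `logsoftmax` is exported alongside `softmax` (the export list in the first hunk header is truncated):

```julia
using NNlib

x = [0.0, 1000.0]

log.(softmax(x))   # [-Inf, 0.0]: the small probability underflows to exactly 0.0
logsoftmax(x)      # [-1000.0, 0.0]: subtracting the maximum first keeps the result finite
```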
@@ -93,5 +95,6 @@ function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
   end
   return out
 end
+
 ∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs)
 ∇logsoftmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
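The context line `∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs)` is the vector-Jacobian product of `logsoftmax`: since ∂logsoftmax(x)ᵢ/∂xⱼ = δᵢⱼ - softmax(x)ⱼ, pulling back a cotangent Δ gives Δⱼ - softmax(x)ⱼ · ΣᵢΔᵢ. A quick finite-difference check of that formula, assuming `logsoftmax` and `∇logsoftmax` are exported (the export list shown is truncated):

```julia
using NNlib

x, Δ = randn(5), randn(5)
ϵ = 1e-6

analytic = ∇logsoftmax(Δ, x)          # Δ .- sum(Δ) .* softmax(x)

# Central difference of t ↦ sum(Δ .* logsoftmax(x .+ t*e_j)) in each coordinate j.
numeric = map(1:5) do j
    e = zeros(5); e[j] = ϵ
    (sum(Δ .* logsoftmax(x .+ e)) - sum(Δ .* logsoftmax(x .- e))) / (2ϵ)
end

maximum(abs, analytic .- numeric) < 1e-6   # true, up to floating-point noise
```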

0 commit comments