Skip to content

Commit 9e5de81

Browse files
devmotionnalimilan
andauthored
Allow general indices in softmax! (#17)
Co-authored-by: Milan Bouchet-Valat <[email protected]>
1 parent 8f65aef commit 9e5de81

File tree

5 files changed

+41
-26
lines changed

5 files changed

+41
-26
lines changed

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.2.2"
4+
version = "0.2.3"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
8+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89

910
[compat]
10-
julia = "1"
1111
DocStringExtensions = "0.8"
12+
julia = "1"
1213

1314
[extras]
15+
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
1416
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1517

1618
[targets]
17-
test = ["Test"]
19+
test = ["OffsetArrays", "Test"]

src/LogExpFunctions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module LogExpFunctions
22

33
using DocStringExtensions: SIGNATURES
44
using Base: Math.@horner, @irrational
5+
import LinearAlgebra
56

67
export loghalf, logtwo, logπ, log2π, log4π
78
export xlogx, xlogy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1,

src/basicfuns.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -233,17 +233,12 @@ That is, `r` is overwritten with `exp.(x)`, normalized to sum to 1.
233233
See the [Wikipedia entry](https://en.wikipedia.org/wiki/Softmax_function)
234234
"""
235235
function softmax!(r::AbstractArray{<:Real}, x::AbstractArray{<:Real})
236-
n = length(x)
237-
length(r) == n || throw(DimensionMismatch("Inconsistent array lengths."))
236+
length(r) == length(x) || throw(DimensionMismatch("inconsistent array lengths"))
238237
u = maximum(x)
239-
s = zero(eltype(r))
240-
@inbounds for i = 1:n
241-
s += (r[i] = exp(x[i] - u))
242-
end
243-
invs = inv(s)
244-
@inbounds for i = 1:n
245-
r[i] *= invs
238+
map!(r, x) do xi
239+
return exp(xi - u)
246240
end
241+
LinearAlgebra.lmul!(inv(sum(r)), r)
247242
return r
248243
end
249244

test/basicfuns.jl

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -149,22 +149,37 @@ end
149149
@testset "softmax" begin
150150
x = [1.0, 2.0, 3.0]
151151
r = exp.(x) ./ sum(exp.(x))
152-
@test softmax(x) r
152+
153+
# in-place versions
154+
for T in (Float32, Float64)
155+
s = Vector{T}(undef, 3)
156+
softmax!(s, x)
157+
@test s r
158+
159+
s = Matrix{T}(undef, 1, 3)
160+
softmax!(s, x)
161+
@test s permutedims(r)
162+
end
153163
softmax!(x)
154164
@test x r
155-
156-
x = [1, 2, 3]
157-
r = exp.(x) ./ sum(exp.(x))
158-
@test softmax(x) r
159-
@test eltype(softmax(x)) == Float64
160-
165+
166+
for (S, T) in ((Int, Float64), (Float64, Float64), (Float32, Float32))
167+
x = S[1, 2, 3]
168+
s = softmax(x)
169+
@test s r
170+
@test eltype(s) === T
171+
end
172+
161173
x = [1//2, 2//3, 3//4]
162174
r = exp.(x) ./ sum(exp.(x))
163-
@test softmax(x) r
164-
@test eltype(softmax(x)) == Float64
175+
s = softmax(x)
176+
@test s r
177+
@test eltype(s) === Float64
165178

166-
x = Float32[1, 2, 3]
167-
r = exp.(x) ./ sum(exp.(x))
168-
@test softmax(x) r
169-
@test eltype(softmax(x)) == Float32
179+
# non-standard indices: #12
180+
x = OffsetArray(1:3, -2:0)
181+
s = softmax(x)
182+
@test s isa OffsetArray{Float64}
183+
@test axes(s, 1) == OffsetArrays.IdOffsetRange(-2:0)
184+
@test collect(s) softmax(1:3)
170185
end

test/runtests.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
using LogExpFunctions, Test
1+
using LogExpFunctions
2+
using OffsetArrays
3+
using Test
24

35
include("basicfuns.jl")

0 commit comments

Comments
 (0)