Skip to content

Commit 3289f12

Browse files
bdeonovicdevmotion
andauthored
Update softmax with parameterization (#16)
Co-authored-by: David Widmann <[email protected]>
1 parent 30b8bbf commit 3289f12

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

src/basicfuns.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -232,19 +232,19 @@ That is, `r` is overwritten with `exp.(x)`, normalized to sum to 1.
232232
233233
See the [Wikipedia entry](https://en.wikipedia.org/wiki/Softmax_function)
234234
"""
235-
function softmax!(r::AbstractArray{R}, x::AbstractArray{T}) where {R<:AbstractFloat,T<:Real}
235+
function softmax!(r::AbstractArray{<:Real}, x::AbstractArray{<:Real})
236236
n = length(x)
237237
length(r) == n || throw(DimensionMismatch("Inconsistent array lengths."))
238238
u = maximum(x)
239-
s = 0.
239+
s = zero(eltype(r))
240240
@inbounds for i = 1:n
241241
s += (r[i] = exp(x[i] - u))
242242
end
243-
invs = convert(R, inv(s))
243+
invs = inv(s)
244244
@inbounds for i = 1:n
245245
r[i] *= invs
246246
end
247-
r
247+
return r
248248
end
249249

250250
"""
@@ -261,4 +261,4 @@ $(SIGNATURES)
261261
Return the [`softmax transformation`](https://en.wikipedia.org/wiki/Softmax_function)
262262
applied to `x`.
263263
"""
264-
softmax(x::AbstractArray{<:Real}) = softmax!(similar(x, Float64), x)
264+
softmax(x::AbstractArray{<:Real}) = softmax!(similar(x, float(eltype(x))), x)

test/basicfuns.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,22 @@ end
149149
@testset "softmax" begin
150150
x = [1.0, 2.0, 3.0]
151151
r = exp.(x) ./ sum(exp.(x))
152-
@test softmax([1.0, 2.0, 3.0]) r
152+
@test softmax(x) r
153153
softmax!(x)
154154
@test x r
155+
156+
x = [1, 2, 3]
157+
r = exp.(x) ./ sum(exp.(x))
158+
@test softmax(x) r
159+
@test eltype(softmax(x)) == Float64
160+
161+
x = [1//2, 2//3, 3//4]
162+
r = exp.(x) ./ sum(exp.(x))
163+
@test softmax(x) r
164+
@test eltype(softmax(x)) == Float64
165+
166+
x = Float32[1, 2, 3]
167+
r = exp.(x) ./ sum(exp.(x))
168+
@test softmax(x) r
169+
@test eltype(softmax(x)) == Float32
155170
end

0 commit comments

Comments
 (0)