Skip to content

Commit d9dc1c9

Browse files
authored
Fix logsumexp! with output arrays of abstract eltype (#40)
1 parent 239b2ad commit d9dc1c9

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogExpFunctions"
22
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
33
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"]
4-
version = "0.3.9"
4+
version = "0.3.10"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/logsumexp.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ See also [`logsumexp`](@ref).
4242
4343
[Sebastian Nowozin: Streaming Log-sum-exp Computation](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
4444
"""
45-
function logsumexp!(out::AbstractArray{<:Number}, X::AbstractArray{<:Number})
46-
FT = eltype(out)
45+
function logsumexp!(out::AbstractArray, X::AbstractArray{<:Number})
46+
FT = float(eltype(X))
4747
xmax_r = fill!(similar(out, Tuple{FT,FT}), (FT(-Inf), zero(FT)))
4848
Base.reducedim!(_logsumexp_onepass_op, xmax_r, X)
4949
return @. out = first(xmax_r) + log1p(last(xmax_r))

test/basicfuns.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,13 @@ end
298298
@test @inferred(logsumexp(xs; dims=2)) log.(sum(exp.(xs); dims=2))
299299
@test @inferred(logsumexp(xs; dims=[1, 2])) log(sum(exp.(xs); dims=[1, 2]))
300300
@test @inferred(logsumexp(x for x in xs)) == logsumexp(xs)
301+
302+
# output arrays with abstract eltype
303+
xs = randn(2, 4)
304+
out = [missing, 1.0]
305+
expected = logsumexp(xs; dims=2)
306+
@test logsumexp!(out, xs) expected
307+
@test out expected
301308
end
302309

303310
@testset "softmax" begin

0 commit comments

Comments
 (0)