Skip to content

Commit 78636f1

Browse files
committed
Use pairwise summation and handle empty inputs
1 parent 97d3bfa commit 78636f1

File tree

2 files changed

+58
-3
lines changed

2 files changed

+58
-3
lines changed

src/scalarstats.jl

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,12 @@ Compute the entropy of a collection of probabilities `p`,
532532
optionally specifying a real number `b` such that the entropy is scaled by `1/log(b)`.
533533
Elements with probability 0 or 1 add 0 to the entropy.
534534
"""
535-
entropy(p) = -sum(xlogx, p)
535+
function entropy(p)
536+
if isempty(p)
537+
throw(ArgumentError("empty collections of probabilities are not supported"))
538+
end
539+
return -sum(xlogx, p)
540+
end
536541

537542
entropy(p, b::Real) = entropy(p) / log(b)
538543

@@ -586,7 +591,22 @@ number `b` such that the result is scaled by `1/log(b)`.
586591
"""
587592
function crossentropy(p::AbstractArray{<:Real}, q::AbstractArray{<:Real})
588593
length(p) == length(q) || throw(DimensionMismatch("Inconsistent array length."))
589-
return - sum(xlogy(pi, qi) for (pi, qi) in zip(p, q))
594+
595+
# handle empty collections
596+
if isempty(p)
597+
Base.depwarn(
598+
"support for empty collections of probabilities will be removed",
599+
:crossentropy,
600+
)
601+
# return zero for empty arrays
602+
return xlogy(zero(eltype(p)), zero(eltype(q)))
603+
end
604+
605+
# use pairwise summation (https://github.com/JuliaLang/julia/pull/31020)
606+
broadcasted = Broadcast.broadcasted(xlogy, vec(p), vec(q))
607+
result = - sum(Broadcast.instantiate(broadcasted))
608+
609+
return result
590610
end
591611

592612
crossentropy(p::AbstractArray{<:Real}, q::AbstractArray{<:Real}, b::Real) =
@@ -603,7 +623,28 @@ can be specified such that the divergence is scaled by `1/log(b)`.
603623
"""
604624
function kldivergence(p::AbstractArray{<:Real}, q::AbstractArray{<:Real})
605625
length(p) == length(q) || throw(DimensionMismatch("Inconsistent array length."))
606-
return sum(xlogy(pi, pi / qi) for (pi, qi) in zip(p, q))
626+
627+
# handle empty collections
628+
if isempty(p)
629+
Base.depwarn(
630+
"support for empty collections of probabilities will be removed",
631+
:kldivergence,
632+
)
633+
# return zero for empty arrays
634+
pzero = zero(eltype(p))
635+
qzero = zero(eltype(q))
636+
return xlogy(pzero, zero(pzero / qzero))
637+
end
638+
639+
# use pairwise summation (https://github.com/JuliaLang/julia/pull/31020)
640+
broadcasted = Broadcast.broadcasted(vec(p), vec(q)) do pi, qi
641+
# handle pi = qi = 0, otherwise `NaN` is returned
642+
piqi = iszero(pi) && iszero(qi) ? zero(pi / qi) : pi / qi
643+
return xlogy(pi, piqi)
644+
end
645+
result = sum(Broadcast.instantiate(broadcasted))
646+
647+
return result
607648
end
608649

609650
kldivergence(p::AbstractArray{<:Real}, q::AbstractArray{<:Real}, b::Real) =

test/scalarstats.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,9 @@ it = (xᵢ for xᵢ in x)
165165
@test @inferred(entropy([1//2, 1//2], 2)) ≈ 1.0
166166
@test @inferred(entropy([0.2, 0.3, 0.5], 2)) ≈ 1.4854752972273344
167167

168+
@test_throws ArgumentError @inferred(entropy(Float64[]))
169+
@test_throws ArgumentError @inferred(entropy(Int[]))
170+
168171
##### Renyi entropies
169172
# Generate a random probability distribution
170173
nindiv = 50
@@ -211,13 +214,24 @@ scale = rand()
211214
@test @inferred(crossentropy([1//5, 3//10, 1//2], [0.3, 0.4, 0.3], 2)) ≈ 1.6124543443825532
212215
@test @inferred(crossentropy([1//5, 3//10, 1//2], [0.3f0, 0.4f0, 0.3f0], 2f0)) isa Float32
213216

217+
# deprecated, should throw an `ArgumentError` at some point
218+
logpattern = (:warn, "support for empty collections of probabilities will be removed")
219+
@test iszero(@test_logs logpattern @inferred(crossentropy(Float64[], Float64[])))
220+
@test iszero(@test_logs logpattern @inferred(crossentropy(Int[], Int[])))
221+
214222
##### KL divergence
215223
@test @inferred(kldivergence([0.2, 0.3, 0.5], [0.3, 0.4, 0.3])) ≈ 0.08801516852582819
216224
@test @inferred(kldivergence([1//5, 3//10, 1//2], [0.3, 0.4, 0.3])) ≈ 0.08801516852582819
217225
@test @inferred(kldivergence([1//5, 3//10, 1//2], [0.3f0, 0.4f0, 0.3f0])) isa Float32
218226
@test @inferred(kldivergence([0.2, 0.3, 0.5], [0.3, 0.4, 0.3], 2)) ≈ 0.12697904715521868
219227
@test @inferred(kldivergence([1//5, 3//10, 1//2], [0.3, 0.4, 0.3], 2)) ≈ 0.12697904715521868
220228
@test @inferred(kldivergence([1//5, 3//10, 1//2], [0.3f0, 0.4f0, 0.3f0], 2f0)) isa Float32
229+
@test iszero(@inferred(kldivergence([0, 1], [0f0, 1f0])))
230+
231+
# deprecated, should throw an `ArgumentError` at some point
232+
logpattern = (:warn, "support for empty collections of probabilities will be removed")
233+
@test iszero(@test_logs logpattern @inferred(kldivergence(Float64[], Float64[])))
234+
@test iszero(@test_logs logpattern @inferred(kldivergence(Int[], Int[])))
221235

222236
##### summarystats
223237

0 commit comments

Comments
 (0)