Skip to content

Commit c4432ab

Browse files
authored
Merge pull request #714 from JuliaStats/dw/entropy
Fix type instability of `entropy` and generalize `crossentropy` and `kldivergence`
2 parents 0a17953 + e1d1d10 commit c4432ab

File tree

5 files changed

+84
-35
lines changed

5 files changed

+84
-35
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
name = "StatsBase"
22
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
33
authors = ["JuliaStats"]
4-
version = "0.33.10"
4+
version = "0.33.11"
55

66
[deps]
77
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
88
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1011
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
1112
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1213
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -18,6 +19,7 @@ StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
1819
[compat]
1920
DataAPI = "1"
2021
DataStructures = "0.10, 0.11, 0.12, 0.13, 0.14, 0.17, 0.18"
22+
LogExpFunctions = "0.3"
2123
Missings = "0.3, 0.4, 1.0"
2224
SortingAlgorithms = "0.3, 1.0"
2325
StatsAPI = "1"

src/StatsBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import DataAPI: describe
88
import DataStructures: heapify!, heappop!, percolate_down!
99
using SortingAlgorithms
1010
using Missings
11+
using LogExpFunctions: xlogx, xlogy
1112

1213
using Statistics
1314
using LinearAlgebra

src/scalarstats.jl

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,13 @@ Compute the entropy of a collection of probabilities `p`,
532532
optionally specifying a real number `b` such that the entropy is scaled by `1/log(b)`.
533533
Elements with probability 0 or 1 add 0 to the entropy.
534534
"""
535-
entropy(p) = -sum(pᵢ -> iszero(pᵢ) ? zero(pᵢ) : pᵢ * log(pᵢ), p)
535+
function entropy(p)
536+
if isempty(p)
537+
throw(ArgumentError("empty collections are not supported since they do not " *
538+
"represent proper probability distributions"))
539+
end
540+
return -sum(xlogx, p)
541+
end
536542

537543
entropy(p, b::Real) = entropy(p) / log(b)
538544

@@ -584,21 +590,26 @@ end
584590
Compute the cross entropy between `p` and `q`, optionally specifying a real
585591
number `b` such that the result is scaled by `1/log(b)`.
586592
"""
587-
function crossentropy(p::AbstractArray{T}, q::AbstractArray{T}) where T<:Real
593+
function crossentropy(p::AbstractArray{<:Real}, q::AbstractArray{<:Real})
588594
length(p) == length(q) || throw(DimensionMismatch("Inconsistent array length."))
589-
s = 0.
590-
z = zero(T)
591-
for i = 1:length(p)
592-
@inbounds pi = p[i]
593-
@inbounds qi = q[i]
594-
if pi > z
595-
s += pi * log(qi)
596-
end
595+
596+
# handle empty collections
597+
if isempty(p)
598+
Base.depwarn(
599+
"support for empty collections will be removed since they do not " *
600+
"represent proper probability distributions",
601+
:crossentropy,
602+
)
603+
# return zero for empty arrays
604+
return xlogy(zero(eltype(p)), zero(eltype(q)))
597605
end
598-
return -s
606+
607+
# use pairwise summation (https://github.com/JuliaLang/julia/pull/31020)
608+
broadcasted = Broadcast.broadcasted(xlogy, vec(p), vec(q))
609+
return - sum(Broadcast.instantiate(broadcasted))
599610
end
600611

601-
crossentropy(p::AbstractArray{T}, q::AbstractArray{T}, b::Real) where {T<:Real} =
612+
crossentropy(p::AbstractArray{<:Real}, q::AbstractArray{<:Real}, b::Real) =
602613
crossentropy(p,q) / log(b)
603614

604615

@@ -610,21 +621,32 @@ also called the relative entropy of `p` with respect to `q`,
610621
that is the sum `pᵢ * log(pᵢ / qᵢ)`. Optionally a real number `b`
611622
can be specified such that the divergence is scaled by `1/log(b)`.
612623
"""
613-
function kldivergence(p::AbstractArray{T}, q::AbstractArray{T}) where T<:Real
624+
function kldivergence(p::AbstractArray{<:Real}, q::AbstractArray{<:Real})
614625
length(p) == length(q) || throw(DimensionMismatch("Inconsistent array length."))
615-
s = 0.
616-
z = zero(T)
617-
for i = 1:length(p)
618-
@inbounds pi = p[i]
619-
@inbounds qi = q[i]
620-
if pi > z
621-
s += pi * log(pi / qi)
622-
end
626+
627+
# handle empty collections
628+
if isempty(p)
629+
Base.depwarn(
630+
"support for empty collections will be removed since they do not "*
631+
"represent proper probability distributions",
632+
:kldivergence,
633+
)
634+
# return zero for empty arrays
635+
pzero = zero(eltype(p))
636+
qzero = zero(eltype(q))
637+
return xlogy(pzero, zero(pzero / qzero))
623638
end
624-
return s
639+
640+
# use pairwise summation (https://github.com/JuliaLang/julia/pull/31020)
641+
broadcasted = Broadcast.broadcasted(vec(p), vec(q)) do pi, qi
642+
# handle pi = qi = 0, otherwise `NaN` is returned
643+
piqi = iszero(pi) && iszero(qi) ? zero(pi / qi) : pi / qi
644+
return xlogy(pi, piqi)
645+
end
646+
return sum(Broadcast.instantiate(broadcasted))
625647
end
626648

627-
kldivergence(p::AbstractArray{T}, q::AbstractArray{T}, b::Real) where {T<:Real} =
649+
kldivergence(p::AbstractArray{<:Real}, q::AbstractArray{<:Real}, b::Real) =
628650
kldivergence(p,q) / log(b)
629651

630652
#############################

test/REQUIRE

Lines changed: 0 additions & 2 deletions
This file was deleted.

test/scalarstats.jl

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,19 @@ it = (xᵢ for xᵢ in x)
154154

155155
##### entropy
156156

157-
@test entropy([0.5, 0.5]) ≈ 0.6931471805599453
158-
@test entropy([0.2, 0.3, 0.5]) ≈ 1.0296530140645737
157+
@test @inferred(entropy([0.5, 0.5])) ≈ 0.6931471805599453
158+
@test @inferred(entropy([1//2, 1//2])) ≈ 0.6931471805599453
159+
@test @inferred(entropy([0.5f0, 0.5f0])) isa Float32
160+
@test @inferred(entropy([0.2, 0.3, 0.5])) ≈ 1.0296530140645737
161+
@test iszero(@inferred(entropy([0, 1])))
162+
@test iszero(@inferred(entropy([0.0, 1.0])))
159163

160-
@test entropy([0.5, 0.5],2) ≈ 1.0
161-
@test entropy([0.2, 0.3, 0.5], 2) ≈ 1.4854752972273344
162-
@test entropy([1.0, 0.0]) ≈ 0.0
164+
@test @inferred(entropy([0.5, 0.5], 2)) ≈ 1.0
165+
@test @inferred(entropy([1//2, 1//2], 2)) ≈ 1.0
166+
@test @inferred(entropy([0.2, 0.3, 0.5], 2)) ≈ 1.4854752972273344
167+
168+
@test_throws ArgumentError @inferred(entropy(Float64[]))
169+
@test_throws ArgumentError @inferred(entropy(Int[]))
163170

164171
##### Renyi entropies
165172
# Generate a random probability distribution
@@ -200,12 +207,31 @@ scale = rand()
200207
@test renyientropy(udist * scale, order) ≈ renyientropy(udist, order) - log(scale)
201208

202209
##### Cross entropy
203-
@test crossentropy([0.2, 0.3, 0.5], [0.3, 0.4, 0.3]) ≈ 1.1176681825904018
204-
@test crossentropy([0.2, 0.3, 0.5], [0.3, 0.4, 0.3], 2) ≈ 1.6124543443825532
210+
@test @inferred(crossentropy([0.2, 0.3, 0.5], [0.3, 0.4, 0.3])) ≈ 1.1176681825904018
211+
@test @inferred(crossentropy([1//5, 3//10, 1//2], [0.3, 0.4, 0.3])) ≈ 1.1176681825904018
212+
@test @inferred(crossentropy([1//5, 3//10, 1//2], [0.3f0, 0.4f0, 0.3f0])) isa Float32
213+
@test @inferred(crossentropy([0.2, 0.3, 0.5], [0.3, 0.4, 0.3], 2)) ≈ 1.6124543443825532
214+
@test @inferred(crossentropy([1//5, 3//10, 1//2], [0.3, 0.4, 0.3], 2)) ≈ 1.6124543443825532
215+
@test @inferred(crossentropy([1//5, 3//10, 1//2], [0.3f0, 0.4f0, 0.3f0], 2f0)) isa Float32
216+
217+
# deprecated, should throw an `ArgumentError` at some point
218+
logpattern = (:warn, "support for empty collections will be removed since they do not represent proper probability distributions")
219+
@test iszero(@test_logs logpattern @inferred(crossentropy(Float64[], Float64[])))
220+
@test iszero(@test_logs logpattern @inferred(crossentropy(Int[], Int[])))
205221

206222
##### KL divergence
207-
@test kldivergence([0.2, 0.3, 0.5], [0.3, 0.4, 0.3]) ≈ 0.08801516852582819
208-
@test kldivergence([0.2, 0.3, 0.5], [0.3, 0.4, 0.3], 2) ≈ 0.12697904715521868
223+
@test @inferred(kldivergence([0.2, 0.3, 0.5], [0.3, 0.4, 0.3])) ≈ 0.08801516852582819
224+
@test @inferred(kldivergence([1//5, 3//10, 1//2], [0.3, 0.4, 0.3])) ≈ 0.08801516852582819
225+
@test @inferred(kldivergence([1//5, 3//10, 1//2], [0.3f0, 0.4f0, 0.3f0])) isa Float32
226+
@test @inferred(kldivergence([0.2, 0.3, 0.5], [0.3, 0.4, 0.3], 2)) ≈ 0.12697904715521868
227+
@test @inferred(kldivergence([1//5, 3//10, 1//2], [0.3, 0.4, 0.3], 2)) ≈ 0.12697904715521868
228+
@test @inferred(kldivergence([1//5, 3//10, 1//2], [0.3f0, 0.4f0, 0.3f0], 2f0)) isa Float32
229+
@test iszero(@inferred(kldivergence([0, 1], [0f0, 1f0])))
230+
231+
# deprecated, should throw an `ArgumentError` at some point
232+
logpattern = (:warn, "support for empty collections will be removed since they do not represent proper probability distributions")
233+
@test iszero(@test_logs logpattern @inferred(kldivergence(Float64[], Float64[])))
234+
@test iszero(@test_logs logpattern @inferred(kldivergence(Int[], Int[])))
209235

210236
##### summarystats
211237

0 commit comments

Comments
 (0)