Skip to content

Commit 38c75aa

Browse files
committed
addes test for entropy
1 parent 64a07d2 commit 38c75aa

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1111
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
12-
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1312
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
13+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1414
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1515
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1616
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

src/multivariate.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ struct TuringDiagMvNormal{Tm<:AbstractVector, Tσ<:AbstractVector} <: Continuous
3030
σ::Tσ
3131
end
3232

33-
Distributions.params(d::TuringMvDiagNormal) = (d.m, d.σ)
34-
Distributions.dim(d::TuringMvDiagNormal) = length(d.m)
35-
Base.length(d::TuringMvDiagNormal) = length(d.m)
33+
Distributions.params(d::TuringDiagMvNormal) = (d.m, d.σ)
34+
Distributions.dim(d::TuringDiagMvNormal) = length(d.m)
35+
Base.length(d::TuringDiagMvNormal) = length(d.m)
3636
Base.size(d::TuringDiagMvNormal) = (length(d), length(d))
3737
function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagMvNormal, n::Int)
3838
return d.m .+ d.σ .* randn(rng, length(d), n)
@@ -79,7 +79,7 @@ function _logpdf(d::TuringDenseMvNormal, x::AbstractMatrix)
7979
end
8080

8181
import StatsBase: entropy
82-
function entropy(d::TuringMvDiagNormal)
82+
function entropy(d::TuringDiagMvNormal)
8383
T = eltype(d.σ)
8484
return (length(d) * (T(log2π) + one(T)) / 2 + sum(log.(d.σ)))
8585
end

test/others.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using StatsBase: entropy
2+
13
@testset "Others" begin
24
@test fill(param(1.0), 3) isa TrackedArray
35
x = rand(3)
@@ -11,3 +13,11 @@
1113
B = copy(A)
1214
@test DistributionsAD.zygote_ldiv(A, B) == A \ B
1315
end
16+
17+
@testset "Extras from StatsBase.jl" begin
18+
sigmas = exp.(randn(10))
19+
d1 = TuringDiagMvNormal(zeros(10), sigmas)
20+
d2 = MvNormal(zeros(10), sigmas)
21+
22+
@test entropy(d1) == entropy(d2)
23+
end

0 commit comments

Comments
 (0)