Skip to content

Commit 0f29efb

Browse files
authored
Merge pull request #35 from devmotion/params
Readd accidentally removed methods
2 parents afec90c + 7745edd commit 0f29efb

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

src/multivariate.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ struct TuringDiagMvNormal{Tm<:AbstractVector, Tσ<:AbstractVector} <: Continuous
116116
σ::Tσ
117117
end
118118

119+
Distributions.params(d::TuringDiagMvNormal) = (d.m, d.σ)
119120
Base.length(d::TuringDiagMvNormal) = length(d.m)
120121
Base.size(d::TuringDiagMvNormal) = (length(d),)
121122
Distributions.rand(d::TuringDiagMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...)
@@ -128,6 +129,7 @@ struct TuringScalMvNormal{Tm<:AbstractVector, Tσ<:Real} <: ContinuousMultivaria
128129
σ::Tσ
129130
end
130131

132+
Distributions.params(d::TuringScalMvNormal) = (d.m, d.σ)
131133
Base.length(d::TuringScalMvNormal) = length(d.m)
132134
Base.size(d::TuringScalMvNormal) = (length(d),)
133135
Distributions.rand(d::TuringScalMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...)

test/others.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,22 @@ end
125125
B = copy(A)
126126
@test DistributionsAD.zygote_ldiv(A, B) == A \ B
127127
end
128+
129+
@testset "Entropy" begin
130+
sigmas = exp.(randn(10))
131+
d1 = TuringDiagMvNormal(zeros(10), sigmas)
132+
d2 = MvNormal(zeros(10), sigmas)
133+
134+
@test entropy(d1) == entropy(d2)
135+
end
136+
137+
@testset "Params" begin
138+
m = rand(10)
139+
sigmas = randexp(10)
140+
141+
d = TuringDiagMvNormal(m, sigmas)
142+
@test params(d) == (m, sigmas)
143+
144+
d = TuringScalMvNormal(m, sigmas[1])
145+
@test params(d) == (m, sigmas[1])
146+
end

0 commit comments

Comments
 (0)