Skip to content

Commit a2c2439

Browse files
committed
added length, params, and entropy for TuringDiagNormal
1 parent a13a865 commit a2c2439

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
13+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1314
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1415
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1516
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

src/multivariate.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ struct TuringDiagNormal{Tm<:AbstractVector, Tσ<:AbstractVector} <: ContinuousMu
2727
σ::Tσ
2828
end
2929

30+
Distributions.params(d::TuringDiagNormal) = (d.m, d.σ)
31+
Distributions.length(d::TuringDiagNormal) = length(d.m)
3032
Distributions.dim(d::TuringDiagNormal) = length(d.m)
3133
function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagNormal)
3234
return d.m .+ d.σ .* randn(rng, dim(d))
@@ -55,6 +57,12 @@ function _logpdf(d::MvNormal, x::Union{Tracker.TrackedVector, Tracker.TrackedMat
5557
_logpdf(TuringMvNormal(d.μ, getchol(d.Σ)), x)
5658
end
5759

60+
import StatsBase: entropy
61+
function entropy(d::TuringDiagNormal)
62+
T = eltype(d.σ)
63+
return (length(d) * (T(log2π) + one(T)) / 2 + sum(log.(d.σ)))
64+
end
65+
5866
# zero mean, dense covariance
5967
MvNormal(A::TrackedMatrix) = MvNormal(zeros(size(A, 1)), A)
6068

0 commit comments

Comments
 (0)