Skip to content

Commit 483f500

Browse files
committed
enhance wrapped distributions
1 parent 9ecf3dc commit 483f500

File tree

1 file changed

+30
-5
lines changed

1 file changed

+30
-5
lines changed

src/distribution_wrappers.jl

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,34 @@ using Distributions: Distributions
22
using Bijectors: Bijectors
33
using Distributions: Univariate, Multivariate, Matrixvariate
44

5+
"""
6+
Base type for distribution wrappers.
7+
"""
8+
abstract type WrappedDistribution{variate,support,Td<:Distribution{variate,support}} <:
9+
Distribution{variate,support}
10+
end
11+
12+
wrapped_dist_type(::Type{<:WrappedDistribution{<:Any,<:Any,Td}}) where Td = Td
13+
wrapped_dist_type(d::WrappedDistribution) = wrapped_dist_type(d)
14+
15+
wrapped_dist(d::WrappedDistribution) = d.dist
16+
17+
Base.length(d::WrappedDistribution{<:Multivariate}) = length(wrapped_dist(d))
18+
Base.size(d::WrappedDistribution{<:Multivariate}) = size(wrapped_dist(d))
19+
Base.eltype(::Type{T}) where T <: WrappedDistribution = eltype(wrapped_dist_type(T))
20+
Base.eltype(d::WrappedDistribution) = eltype(wrapped_dist_type(d))
21+
22+
Distributions.rand(rng::Random.AbstractRNG, d::WrappedDistribution) = rand(rng, wrapped_dist(d))
23+
Distributions.minimum(d::WrappedDistribution) = minimum(wrapped_dist(d))
24+
Distributions.maximum(d::WrappedDistribution) = maximum(wrapped_dist(d))
25+
26+
Bijectors.bijector(d::WrappedDistribution) = bijector(wrapped_dist(d))
27+
528
"""
629
A named distribution that carries the name of the random variable with it.
730
"""
831
struct NamedDist{variate,support,Td<:Distribution{variate,support},Tv<:VarName} <:
9-
Distribution{variate,support}
32+
WrappedDistribution{variate,support,Td}
1033
dist::Td
1134
name::Tv
1235
end
@@ -24,21 +47,23 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real})
2447
return Distributions.loglikelihood(dist.dist, x)
2548
end
2649

50+
"""
51+
Wrapper around distribution `Td` that suppresses `logpdf()` calculation.
52+
53+
Note that *SampleFromPrior* would still sample from `Td`.
54+
"""
2755
struct NoDist{variate,support,Td<:Distribution{variate,support}} <:
28-
Distribution{variate,support}
56+
WrappedDistribution{variate,support,Td}
2957
dist::Td
3058
end
3159
NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name)
3260

33-
Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
3461
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
3562
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
3663
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
3764
return zeros(Int, size(x, 2))
3865
end
3966
Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0
40-
Distributions.minimum(d::NoDist) = minimum(d.dist)
41-
Distributions.maximum(d::NoDist) = maximum(d.dist)
4267

4368
Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real) = 0
4469
Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0

0 commit comments

Comments
 (0)