Skip to content

Commit 0dd4ada

Browse files
alystAlexey Stukalov
authored andcommitted
enhance wrapped distributions
1 parent 715526f commit 0dd4ada

File tree

1 file changed

+29
-13
lines changed

1 file changed

+29
-13
lines changed

src/distribution_wrappers.jl

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,40 @@ using Distributions: Distributions
22
using Bijectors: Bijectors
33
using Distributions: Univariate, Multivariate, Matrixvariate
44

5+
"""
6+
Base type for distribution wrappers.
7+
"""
8+
abstract type WrappedDistribution{variate,support,Td<:Distribution{variate,support}} <:
9+
Distribution{variate,support}
10+
end
11+
12+
wrapped_dist_type(::Type{<:WrappedDistribution{<:Any,<:Any,Td}}) where Td = Td
13+
wrapped_dist_type(d::WrappedDistribution) = wrapped_dist_type(d)
14+
15+
wrapped_dist(d::WrappedDistribution) = d.dist
16+
17+
Base.length(d::WrappedDistribution{<:Multivariate}) = length(wrapped_dist(d))
18+
Base.size(d::WrappedDistribution{<:Multivariate}) = size(wrapped_dist(d))
19+
Base.eltype(::Type{T}) where T <: WrappedDistribution = eltype(wrapped_dist_type(T))
20+
Base.eltype(d::WrappedDistribution) = eltype(wrapped_dist_type(d))
21+
22+
Distributions.rand(rng::Random.AbstractRNG, d::WrappedDistribution) = rand(rng, wrapped_dist(d))
23+
Distributions.minimum(d::WrappedDistribution) = minimum(wrapped_dist(d))
24+
Distributions.maximum(d::WrappedDistribution) = maximum(wrapped_dist(d))
25+
26+
Bijectors.bijector(d::WrappedDistribution) = bijector(wrapped_dist(d))
27+
528
"""
629
A named distribution that carries the name of the random variable with it.
730
"""
831
struct NamedDist{variate,support,Td<:Distribution{variate,support},Tv<:VarName} <:
9-
Distribution{variate,support}
32+
WrappedDistribution{variate,support,Td}
1033
dist::Td
1134
name::Tv
1235
end
1336

1437
NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName{name}())
1538

16-
Base.length(dist::NamedDist) = Base.length(dist.dist)
17-
Base.size(dist::NamedDist) = Base.size(dist.dist)
18-
1939
Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x)
2040
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real})
2141
return Distributions.logpdf(dist.dist, x)
@@ -27,29 +47,27 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real})
2747
return Distributions.loglikelihood(dist.dist, x)
2848
end
2949

30-
Bijectors.bijector(d::NamedDist) = Bijectors.bijector(d.dist)
50+
"""
51+
Wrapper around distribution `Td` that suppresses `logpdf()` calculation.
3152
53+
Note that *SampleFromPrior* would still sample from `Td`.
54+
"""
3255
struct NoDist{variate,support,Td<:Distribution{variate,support}} <:
33-
Distribution{variate,support}
56+
WrappedDistribution{variate,support,Td}
3457
dist::Td
3558
end
3659
NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name)
3760

3861
nodist(dist::Distribution) = NoDist(dist)
3962
nodist(dists::AbstractArray) = nodist.(dists)
4063

41-
Base.length(dist::NoDist) = Base.length(dist.dist)
42-
Base.size(dist::NoDist) = Base.size(dist.dist)
43-
4464
Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
4565
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
4666
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
4767
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
4868
return zeros(Int, size(x, 2))
4969
end
5070
Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0
51-
Distributions.minimum(d::NoDist) = minimum(d.dist)
52-
Distributions.maximum(d::NoDist) = maximum(d.dist)
5371

5472
Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0
5573
function Bijectors.logpdf_with_trans(
@@ -67,5 +85,3 @@ function Bijectors.logpdf_with_trans(
6785
)
6886
return 0
6987
end
70-
71-
Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist)

0 commit comments

Comments
 (0)