@@ -2,20 +2,40 @@ using Distributions: Distributions
2
2
using Bijectors: Bijectors
3
3
using Distributions: Univariate, Multivariate, Matrixvariate
4
4
5
+ """
6
+ Base type for distribution wrappers.
7
+ """
8
+ abstract type WrappedDistribution{variate,support,Td<: Distribution{variate,support} } < :
9
+ Distribution{variate,support}
10
+ end
11
+
12
+ wrapped_dist_type (:: Type{<:WrappedDistribution{<:Any,<:Any,Td}} ) where Td = Td
13
+ wrapped_dist_type (d:: WrappedDistribution ) = wrapped_dist_type (d)
14
+
15
+ wrapped_dist (d:: WrappedDistribution ) = d. dist
16
+
17
+ Base. length (d:: WrappedDistribution{<:Multivariate} ) = length (wrapped_dist (d))
18
+ Base. size (d:: WrappedDistribution{<:Multivariate} ) = size (wrapped_dist (d))
19
+ Base. eltype (:: Type{T} ) where T <: WrappedDistribution = eltype (wrapped_dist_type (T))
20
+ Base. eltype (d:: WrappedDistribution ) = eltype (wrapped_dist_type (d))
21
+
22
+ Distributions. rand (rng:: Random.AbstractRNG , d:: WrappedDistribution ) = rand (rng, wrapped_dist (d))
23
+ Distributions. minimum (d:: WrappedDistribution ) = minimum (wrapped_dist (d))
24
+ Distributions. maximum (d:: WrappedDistribution ) = maximum (wrapped_dist (d))
25
+
26
+ Bijectors. bijector (d:: WrappedDistribution ) = bijector (wrapped_dist (d))
27
+
5
28
"""
6
29
A named distribution that carries the name of the random variable with it.
7
30
"""
8
31
struct NamedDist{variate,support,Td<: Distribution{variate,support} ,Tv<: VarName } < :
9
- Distribution {variate,support}
32
+ WrappedDistribution {variate,support,Td }
10
33
dist:: Td
11
34
name:: Tv
12
35
end
13
36
14
37
NamedDist (dist:: Distribution , name:: Symbol ) = NamedDist (dist, VarName {name} ())
15
38
16
- Base. length (dist:: NamedDist ) = Base. length (dist. dist)
17
- Base. size (dist:: NamedDist ) = Base. size (dist. dist)
18
-
19
39
Distributions. logpdf (dist:: NamedDist , x:: Real ) = Distributions. logpdf (dist. dist, x)
20
40
function Distributions. logpdf (dist:: NamedDist , x:: AbstractArray{<:Real} )
21
41
return Distributions. logpdf (dist. dist, x)
@@ -27,29 +47,27 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real})
27
47
return Distributions. loglikelihood (dist. dist, x)
28
48
end
29
49
30
- Bijectors. bijector (d:: NamedDist ) = Bijectors. bijector (d. dist)
50
+ """
51
+ Wrapper around distribution `Td` that suppresses `logpdf()` calculation.
31
52
53
+ Note that *SampleFromPrior* would still sample from `Td`.
54
+ """
32
55
struct NoDist{variate,support,Td<: Distribution{variate,support} } < :
33
- Distribution {variate,support}
56
+ WrappedDistribution {variate,support,Td }
34
57
dist:: Td
35
58
end
36
59
NoDist (dist:: NamedDist ) = NamedDist (NoDist (dist. dist), dist. name)
37
60
38
61
nodist (dist:: Distribution ) = NoDist (dist)
39
62
nodist (dists:: AbstractArray ) = nodist .(dists)
40
63
41
- Base. length (dist:: NoDist ) = Base. length (dist. dist)
42
- Base. size (dist:: NoDist ) = Base. size (dist. dist)
43
-
44
64
Distributions. rand (rng:: Random.AbstractRNG , d:: NoDist ) = rand (rng, d. dist)
45
65
Distributions. logpdf (d:: NoDist{<:Univariate} , :: Real ) = 0
46
66
Distributions. logpdf (d:: NoDist{<:Multivariate} , :: AbstractVector{<:Real} ) = 0
47
67
function Distributions. logpdf (d:: NoDist{<:Multivariate} , x:: AbstractMatrix{<:Real} )
48
68
return zeros (Int, size (x, 2 ))
49
69
end
50
70
Distributions. logpdf (d:: NoDist{<:Matrixvariate} , :: AbstractMatrix{<:Real} ) = 0
51
- Distributions. minimum (d:: NoDist ) = minimum (d. dist)
52
- Distributions. maximum (d:: NoDist ) = maximum (d. dist)
53
71
54
72
Bijectors. logpdf_with_trans (d:: NoDist{<:Univariate} , :: Real , :: Bool ) = 0
55
73
function Bijectors. logpdf_with_trans (
@@ -67,5 +85,3 @@ function Bijectors.logpdf_with_trans(
67
85
)
68
86
return 0
69
87
end
70
-
71
- Bijectors. bijector (d:: NoDist ) = Bijectors. bijector (d. dist)
0 commit comments