@@ -2,20 +2,41 @@ using Distributions: Distributions
2
2
using Bijectors: Bijectors
3
3
using Distributions: Univariate, Multivariate, Matrixvariate
4
4
5
+ """
6
+ Base type for distribution wrappers.
7
+ """
8
+ abstract type WrappedDistribution{variate,support,Td<: Distribution{variate,support} } < :
9
+ Distribution{variate,support} end
10
+
11
+ wrapped_dist_type (:: Type{<:WrappedDistribution{<:Any,<:Any,Td}} ) where {Td} = Td
12
+ wrapped_dist_type (d:: WrappedDistribution ) = wrapped_dist_type (d)
13
+
14
+ wrapped_dist (d:: WrappedDistribution ) = d. dist
15
+
16
+ Base. length (d:: WrappedDistribution{<:Multivariate} ) = length (wrapped_dist (d))
17
+ Base. size (d:: WrappedDistribution{<:Multivariate} ) = size (wrapped_dist (d))
18
+ Base. eltype (:: Type{T} ) where {T<: WrappedDistribution } = eltype (wrapped_dist_type (T))
19
+ Base. eltype (d:: WrappedDistribution ) = eltype (wrapped_dist_type (d))
20
+
21
+ function Distributions. rand (rng:: Random.AbstractRNG , d:: WrappedDistribution )
22
+ return rand (rng, wrapped_dist (d))
23
+ end
24
+ Distributions. minimum (d:: WrappedDistribution ) = minimum (wrapped_dist (d))
25
+ Distributions. maximum (d:: WrappedDistribution ) = maximum (wrapped_dist (d))
26
+
27
+ Bijectors. bijector (d:: WrappedDistribution ) = bijector (wrapped_dist (d))
28
+
5
29
"""
6
30
A named distribution that carries the name of the random variable with it.
7
31
"""
8
32
struct NamedDist{variate,support,Td<: Distribution{variate,support} ,Tv<: VarName } < :
9
- Distribution {variate,support}
33
+ WrappedDistribution {variate,support,Td }
10
34
dist:: Td
11
35
name:: Tv
12
36
end
13
37
14
38
NamedDist (dist:: Distribution , name:: Symbol ) = NamedDist (dist, VarName {name} ())
15
39
16
- Base. length (dist:: NamedDist ) = Base. length (dist. dist)
17
- Base. size (dist:: NamedDist ) = Base. size (dist. dist)
18
-
19
40
Distributions. logpdf (dist:: NamedDist , x:: Real ) = Distributions. logpdf (dist. dist, x)
20
41
function Distributions. logpdf (dist:: NamedDist , x:: AbstractArray{<:Real} )
21
42
return Distributions. logpdf (dist. dist, x)
@@ -27,29 +48,27 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real})
27
48
return Distributions. loglikelihood (dist. dist, x)
28
49
end
29
50
30
- Bijectors. bijector (d:: NamedDist ) = Bijectors. bijector (d. dist)
51
+ """
52
+ Wrapper around distribution `Td` that suppresses `logpdf()` calculation.
31
53
54
+ Note that *SampleFromPrior* would still sample from `Td`.
55
+ """
32
56
struct NoDist{variate,support,Td<: Distribution{variate,support} } < :
33
- Distribution {variate,support}
57
+ WrappedDistribution {variate,support,Td }
34
58
dist:: Td
35
59
end
36
60
NoDist (dist:: NamedDist ) = NamedDist (NoDist (dist. dist), dist. name)
37
61
38
62
nodist (dist:: Distribution ) = NoDist (dist)
39
63
nodist (dists:: AbstractArray ) = nodist .(dists)
40
64
41
- Base. length (dist:: NoDist ) = Base. length (dist. dist)
42
- Base. size (dist:: NoDist ) = Base. size (dist. dist)
43
-
44
65
Distributions. rand (rng:: Random.AbstractRNG , d:: NoDist ) = rand (rng, d. dist)
45
66
Distributions. logpdf (d:: NoDist{<:Univariate} , :: Real ) = 0
46
67
Distributions. logpdf (d:: NoDist{<:Multivariate} , :: AbstractVector{<:Real} ) = 0
47
68
function Distributions. logpdf (d:: NoDist{<:Multivariate} , x:: AbstractMatrix{<:Real} )
48
69
return zeros (Int, size (x, 2 ))
49
70
end
50
71
Distributions. logpdf (d:: NoDist{<:Matrixvariate} , :: AbstractMatrix{<:Real} ) = 0
51
- Distributions. minimum (d:: NoDist ) = minimum (d. dist)
52
- Distributions. maximum (d:: NoDist ) = maximum (d. dist)
53
72
54
73
Bijectors. logpdf_with_trans (d:: NoDist{<:Univariate} , :: Real , :: Bool ) = 0
55
74
function Bijectors. logpdf_with_trans (
@@ -67,5 +86,3 @@ function Bijectors.logpdf_with_trans(
67
86
)
68
87
return 0
69
88
end
70
-
71
- Bijectors. bijector (d:: NoDist ) = Bijectors. bijector (d. dist)
0 commit comments