Skip to content

Commit 17bffbd

Browse files
authored
Try #414:
2 parents 8c8cfc6 + 078731f commit 17bffbd

File tree

5 files changed

+99
-22
lines changed

5 files changed

+99
-22
lines changed

docs/src/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ pointwise_loglikelihoods
9595
```
9696

9797
```@docs
98+
WrappedDistribution
9899
NamedDist
100+
NoDist
99101
```
100102

101103
## Testing Utilities

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ export AbstractVarInfo,
105105
dot_tilde_assume,
106106
dot_tilde_observe,
107107
# Pseudo distributions
108+
WrappedDistribution,
108109
NamedDist,
109110
NoDist,
110111
# Prob macros

src/distribution_wrappers.jl

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,41 @@ using Distributions: Distributions
22
using Bijectors: Bijectors
33
using Distributions: Univariate, Multivariate, Matrixvariate
44

5+
"""
6+
Base type for distribution wrappers.
7+
"""
8+
abstract type WrappedDistribution{variate,support,Td<:Distribution{variate,support}} <:
9+
Distribution{variate,support} end
10+
11+
wrapped_dist_type(::Type{<:WrappedDistribution{<:Any,<:Any,Td}}) where {Td} = Td
12+
wrapped_dist_type(d::WrappedDistribution) = wrapped_dist_type(d)
13+
14+
wrapped_dist(d::WrappedDistribution) = d.dist
15+
16+
Base.length(d::WrappedDistribution{<:Multivariate}) = length(wrapped_dist(d))
17+
Base.size(d::WrappedDistribution{<:Multivariate}) = size(wrapped_dist(d))
18+
Base.eltype(::Type{T}) where {T<:WrappedDistribution} = eltype(wrapped_dist_type(T))
19+
Base.eltype(d::WrappedDistribution) = eltype(wrapped_dist_type(d))
20+
21+
function Distributions.rand(rng::Random.AbstractRNG, d::WrappedDistribution)
22+
return rand(rng, wrapped_dist(d))
23+
end
24+
Distributions.minimum(d::WrappedDistribution) = minimum(wrapped_dist(d))
25+
Distributions.maximum(d::WrappedDistribution) = maximum(wrapped_dist(d))
26+
27+
Bijectors.bijector(d::WrappedDistribution) = bijector(wrapped_dist(d))
28+
529
"""
630
A named distribution that carries the name of the random variable with it.
731
"""
832
struct NamedDist{variate,support,Td<:Distribution{variate,support},Tv<:VarName} <:
9-
Distribution{variate,support}
33+
WrappedDistribution{variate,support,Td}
1034
dist::Td
1135
name::Tv
1236
end
1337

1438
NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName{name}())
1539

16-
Base.length(dist::NamedDist) = Base.length(dist.dist)
17-
Base.size(dist::NamedDist) = Base.size(dist.dist)
18-
1940
Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x)
2041
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real})
2142
return Distributions.logpdf(dist.dist, x)
@@ -27,29 +48,27 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real})
2748
return Distributions.loglikelihood(dist.dist, x)
2849
end
2950

30-
Bijectors.bijector(d::NamedDist) = Bijectors.bijector(d.dist)
51+
"""
52+
Wrapper around distribution `Td` that suppresses `logpdf()` calculation.
3153
54+
Note that *SampleFromPrior* would still sample from `Td`.
55+
"""
3256
struct NoDist{variate,support,Td<:Distribution{variate,support}} <:
33-
Distribution{variate,support}
57+
WrappedDistribution{variate,support,Td}
3458
dist::Td
3559
end
3660
NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name)
3761

3862
nodist(dist::Distribution) = NoDist(dist)
3963
nodist(dists::AbstractArray) = nodist.(dists)
4064

41-
Base.length(dist::NoDist) = Base.length(dist.dist)
42-
Base.size(dist::NoDist) = Base.size(dist.dist)
43-
4465
Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
4566
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
4667
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
4768
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
4869
return zeros(Int, size(x, 2))
4970
end
5071
Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0
51-
Distributions.minimum(d::NoDist) = minimum(d.dist)
52-
Distributions.maximum(d::NoDist) = maximum(d.dist)
5372

5473
Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0
5574
function Bijectors.logpdf_with_trans(
@@ -67,5 +86,3 @@ function Bijectors.logpdf_with_trans(
6786
)
6887
return 0
6988
end
70-
71-
Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist)

test/context_implementations.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,5 +68,42 @@
6868
end
6969
end
7070
end
71+
72+
@testset "multivariate NoDist" begin
73+
@model function genmodel()
74+
x ~ NoDist(Product(fill(Uniform(-20, 20), 5)))
75+
for i in eachindex(x)
76+
x[i] ~ Normal(0, 1)
77+
end
78+
end
79+
gen_model = genmodel()
80+
vi_gen = VarInfo(gen_model)
81+
@test isfinite(logjoint(gen_model, vi_gen))
82+
# test for bijector
83+
link!(vi_gen, DynamicPPL.SampleFromPrior())
84+
invlink!(vi_gen, DynamicPPL.SampleFromPrior())
85+
86+
# explicit model specification
87+
expl_model = DynamicPPL.Model(NamedTuple()) do model, varinfo, context
88+
DynamicPPL.tilde_assume!!(
89+
context,
90+
NoDist(Product(fill(Uniform(-20, 20), 5))),
91+
@varname(x),
92+
varinfo,
93+
)
94+
x = varinfo[@varname(x)]
95+
@test x isa Vector{<:Real}
96+
@test length(x) == 5
97+
return (
98+
nothing,
99+
DynamicPPL.acclogp!!(varinfo, sum(logpdf.(Ref(Normal(0, 1)), x))),
100+
)
101+
end
102+
vi_expl = VarInfo(expl_model)
103+
@test isfinite(logjoint(expl_model, vi_expl))
104+
# test for bijector
105+
link!(vi_expl, DynamicPPL.SampleFromPrior())
106+
invlink!(vi_expl, DynamicPPL.SampleFromPrior())
107+
end
71108
end
72109
end

test/distribution_wrappers.jl

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,33 @@
11
@testset "distribution_wrappers.jl" begin
2-
d = Normal()
3-
nd = DynamicPPL.NoDist(d)
2+
@testset "univariate" begin
3+
d = Normal()
4+
nd = DynamicPPL.NoDist(d)
45

5-
# Smoke test
6-
rand(nd)
6+
# Smoke test
7+
rand(nd)
78

8-
# Actual tests
9-
@test minimum(nd) == -Inf
10-
@test maximum(nd) == Inf
11-
@test logpdf(nd, 15.0) == 0
12-
@test Bijectors.logpdf_with_trans(nd, 30, true) == 0
9+
# Actual tests
10+
@test minimum(nd) == -Inf
11+
@test maximum(nd) == Inf
12+
@test logpdf(nd, 15.0) == 0
13+
@test Bijectors.logpdf_with_trans(nd, 30, true) == 0
14+
@test Bijectors.bijector(nd) == Bijectors.bijector(d)
15+
end
16+
17+
@testset "multivariate" begin
18+
d = Product([Normal(), Uniform()])
19+
nd = DynamicPPL.NoDist(d)
20+
21+
# Smoke test
22+
@test length(rand(nd)) == 2
23+
24+
# Actual tests
25+
@test length(nd) == 2
26+
@test size(nd) == (2,)
27+
@test minimum(nd) == [-Inf, 0.0]
28+
@test maximum(nd) == [Inf, 1.0]
29+
@test logpdf(nd, [15.0, 0.5]) == 0
30+
@test Bijectors.logpdf_with_trans(nd, [0, 1]) == 0
31+
@test Bijectors.bijector(nd) == Bijectors.bijector(d)
32+
end
1333
end

0 commit comments

Comments
 (0)