Skip to content

Commit 310bee9

Browse files
committed
Add Gibbs component call order test
1 parent cc9510c commit 310bee9

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

test/mcmc/gibbs.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,127 @@ end
130130
)
131131
end
132132

133+
# Test that the samplers are being called in the correct order, on the correct target
134+
# variables.
135+
@testset "Sampler call order" begin
136+
# A wrapper around inference algorithms to allow intercepting the dispatch cascade to
137+
# collect testing information.
138+
struct AlgWrapper{Alg<:Inference.InferenceAlgorithm} <: Inference.InferenceAlgorithm
139+
inner::Alg
140+
end
141+
142+
unwrap_sampler(sampler::DynamicPPL.Sampler{<:AlgWrapper}) =
143+
DynamicPPL.Sampler(sampler.alg.inner, sampler.selector)
144+
145+
# Methods we need to define to be able to use AlgWrapper instead of an actual algorithm.
146+
# They all just propagate the call to the inner algorithm.
147+
Inference.isgibbscomponent(wrap::AlgWrapper) = Inference.isgibbscomponent(wrap.inner)
148+
Inference.drop_space(wrap::AlgWrapper) = AlgWrapper(Inference.drop_space(wrap.inner))
149+
function Inference.setparams_varinfo!!(
150+
model::DynamicPPL.Model,
151+
sampler::DynamicPPL.Sampler{<:AlgWrapper},
152+
state,
153+
params::Turing.AbstractVarInfo,
154+
)
155+
return Inference.setparams_varinfo!!(model, unwrap_sampler(sampler), state, params)
156+
end
157+
158+
function target_vns(::Inference.GibbsContext{VNs}) where {VNs}
159+
return VNs
160+
end
161+
162+
# targets_and_algs will be a list of tuples, where the first element is the target_vns
163+
# of a component sampler, and the second element is the component sampler itself.
164+
# It is modified by the capture_targets_and_algs function.
165+
targets_and_algs = Any[]
166+
167+
function capture_targets_and_algs(sampler, context)
168+
if DynamicPPL.NodeTrait(context) == DynamicPPL.IsLeaf()
169+
return nothing
170+
end
171+
if context isa Inference.GibbsContext
172+
push!(targets_and_algs, (target_vns(context), sampler))
173+
end
174+
return capture_targets_and_algs(sampler, DynamicPPL.childcontext(context))
175+
end
176+
177+
# The methods that capture testing information for us.
178+
function Turing.AbstractMCMC.step(
179+
rng::Random.AbstractRNG,
180+
model::DynamicPPL.Model,
181+
sampler::DynamicPPL.Sampler{<:AlgWrapper},
182+
args...;
183+
kwargs...,
184+
)
185+
capture_targets_and_algs(sampler.alg.inner, model.context)
186+
return Turing.AbstractMCMC.step(
187+
rng, model, unwrap_sampler(sampler), args...; kwargs...
188+
)
189+
end
190+
191+
function Turing.DynamicPPL.initialstep(
192+
rng::Random.AbstractRNG,
193+
model::DynamicPPL.Model,
194+
sampler::DynamicPPL.Sampler{<:AlgWrapper},
195+
args...;
196+
kwargs...,
197+
)
198+
capture_targets_and_algs(sampler.alg.inner, model.context)
199+
return Turing.DynamicPPL.initialstep(
200+
rng, model, unwrap_sampler(sampler), args...; kwargs...
201+
)
202+
end
203+
204+
# A test model that includes several different kinds of tilde syntax.
205+
@model function test_model(val, ::Type{M}=Vector{Float64}) where {M}
206+
s ~ Normal(0.1, 0.2)
207+
m ~ Poisson()
208+
val ~ Normal(s, 1)
209+
1.0 ~ Normal(s + m, 1)
210+
211+
n := m + 1
212+
xs = M(undef, n)
213+
for i in eachindex(xs)
214+
xs[i] ~ Beta(0.5, 0.5)
215+
end
216+
217+
ys = M(undef, 2)
218+
ys .~ Beta(1.0, 1.0)
219+
return sum(xs), sum(ys), n
220+
end
221+
222+
mh = MH()
223+
pg = PG(10)
224+
hmc = HMC(0.01, 4)
225+
nuts = NUTS()
226+
# Sample with all sorts of combinations of samplers and targets.
227+
sampler = Gibbs(
228+
(@varname(s),) => AlgWrapper(mh),
229+
(@varname(s), @varname(m)) => AlgWrapper(mh),
230+
(@varname(m),) => AlgWrapper(pg),
231+
(@varname(xs),) => AlgWrapper(hmc),
232+
(@varname(ys),) => AlgWrapper(nuts),
233+
(@varname(ys),) => AlgWrapper(nuts),
234+
(@varname(xs), @varname(ys)) => AlgWrapper(hmc),
235+
(@varname(s),) => AlgWrapper(mh),
236+
)
237+
chain = sample(test_model(-1), sampler, 2)
238+
239+
expected_targets_and_algs_per_iteration = [
240+
((:s,), mh),
241+
((:s, :m), mh),
242+
((:m,), pg),
243+
((:xs,), hmc),
244+
((:ys,), nuts),
245+
((:ys,), nuts),
246+
((:xs, :ys), hmc),
247+
((:s,), mh),
248+
]
249+
@test targets_and_algs == vcat(
250+
expected_targets_and_algs_per_iteration, expected_targets_and_algs_per_iteration
251+
)
252+
end
253+
133254
@testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends
134255
@testset "Deprecated Gibbs constructors" begin
135256
N = 10

0 commit comments

Comments
 (0)