@@ -130,6 +130,127 @@ end
130
130
)
131
131
end
132
132
133
+ # Test that the samplers are being called in the correct order, on the correct target
134
+ # variables.
135
+ @testset " Sampler call order" begin
136
+ # A wrapper around inference algorithms to allow intercepting the dispatch cascade to
137
+ # collect testing information.
138
+ struct AlgWrapper{Alg<: Inference.InferenceAlgorithm } <: Inference.InferenceAlgorithm
139
+ inner:: Alg
140
+ end
141
+
142
+ unwrap_sampler (sampler:: DynamicPPL.Sampler{<:AlgWrapper} ) =
143
+ DynamicPPL. Sampler (sampler. alg. inner, sampler. selector)
144
+
145
+ # Methods we need to define to be able to use AlgWrapper instead of an actual algorithm.
146
+ # They all just propagate the call to the inner algorithm.
147
+ Inference. isgibbscomponent (wrap:: AlgWrapper ) = Inference. isgibbscomponent (wrap. inner)
148
+ Inference. drop_space (wrap:: AlgWrapper ) = AlgWrapper (Inference. drop_space (wrap. inner))
149
+ function Inference. setparams_varinfo!! (
150
+ model:: DynamicPPL.Model ,
151
+ sampler:: DynamicPPL.Sampler{<:AlgWrapper} ,
152
+ state,
153
+ params:: Turing.AbstractVarInfo ,
154
+ )
155
+ return Inference. setparams_varinfo!! (model, unwrap_sampler (sampler), state, params)
156
+ end
157
+
158
+ function target_vns (:: Inference.GibbsContext{VNs} ) where {VNs}
159
+ return VNs
160
+ end
161
+
162
+ # targets_and_algs will be a list of tuples, where the first element is the target_vns
163
+ # of a component sampler, and the second element is the component sampler itself.
164
+ # It is modified by the capture_targets_and_algs function.
165
+ targets_and_algs = Any[]
166
+
167
+ function capture_targets_and_algs (sampler, context)
168
+ if DynamicPPL. NodeTrait (context) == DynamicPPL. IsLeaf ()
169
+ return nothing
170
+ end
171
+ if context isa Inference. GibbsContext
172
+ push! (targets_and_algs, (target_vns (context), sampler))
173
+ end
174
+ return capture_targets_and_algs (sampler, DynamicPPL. childcontext (context))
175
+ end
176
+
177
+ # The methods that capture testing information for us.
178
+ function Turing. AbstractMCMC. step (
179
+ rng:: Random.AbstractRNG ,
180
+ model:: DynamicPPL.Model ,
181
+ sampler:: DynamicPPL.Sampler{<:AlgWrapper} ,
182
+ args... ;
183
+ kwargs... ,
184
+ )
185
+ capture_targets_and_algs (sampler. alg. inner, model. context)
186
+ return Turing. AbstractMCMC. step (
187
+ rng, model, unwrap_sampler (sampler), args... ; kwargs...
188
+ )
189
+ end
190
+
191
+ function Turing. DynamicPPL. initialstep (
192
+ rng:: Random.AbstractRNG ,
193
+ model:: DynamicPPL.Model ,
194
+ sampler:: DynamicPPL.Sampler{<:AlgWrapper} ,
195
+ args... ;
196
+ kwargs... ,
197
+ )
198
+ capture_targets_and_algs (sampler. alg. inner, model. context)
199
+ return Turing. DynamicPPL. initialstep (
200
+ rng, model, unwrap_sampler (sampler), args... ; kwargs...
201
+ )
202
+ end
203
+
204
+ # A test model that includes several different kinds of tilde syntax.
205
+ @model function test_model (val, :: Type{M} = Vector{Float64}) where {M}
206
+ s ~ Normal (0.1 , 0.2 )
207
+ m ~ Poisson ()
208
+ val ~ Normal (s, 1 )
209
+ 1.0 ~ Normal (s + m, 1 )
210
+
211
+ n := m + 1
212
+ xs = M (undef, n)
213
+ for i in eachindex (xs)
214
+ xs[i] ~ Beta (0.5 , 0.5 )
215
+ end
216
+
217
+ ys = M (undef, 2 )
218
+ ys .~ Beta (1.0 , 1.0 )
219
+ return sum (xs), sum (ys), n
220
+ end
221
+
222
+ mh = MH ()
223
+ pg = PG (10 )
224
+ hmc = HMC (0.01 , 4 )
225
+ nuts = NUTS ()
226
+ # Sample with all sorts of combinations of samplers and targets.
227
+ sampler = Gibbs (
228
+ (@varname (s),) => AlgWrapper (mh),
229
+ (@varname (s), @varname (m)) => AlgWrapper (mh),
230
+ (@varname (m),) => AlgWrapper (pg),
231
+ (@varname (xs),) => AlgWrapper (hmc),
232
+ (@varname (ys),) => AlgWrapper (nuts),
233
+ (@varname (ys),) => AlgWrapper (nuts),
234
+ (@varname (xs), @varname (ys)) => AlgWrapper (hmc),
235
+ (@varname (s),) => AlgWrapper (mh),
236
+ )
237
+ chain = sample (test_model (- 1 ), sampler, 2 )
238
+
239
+ expected_targets_and_algs_per_iteration = [
240
+ ((:s ,), mh),
241
+ ((:s , :m ), mh),
242
+ ((:m ,), pg),
243
+ ((:xs ,), hmc),
244
+ ((:ys ,), nuts),
245
+ ((:ys ,), nuts),
246
+ ((:xs , :ys ), hmc),
247
+ ((:s ,), mh),
248
+ ]
249
+ @test targets_and_algs == vcat (
250
+ expected_targets_and_algs_per_iteration, expected_targets_and_algs_per_iteration
251
+ )
252
+ end
253
+
133
254
@testset " Testing gibbs.jl with $adbackend " for adbackend in ADUtils. adbackends
134
255
@testset " Deprecated Gibbs constructors" begin
135
256
N = 10
0 commit comments