|
237 | 237 | end
|
238 | 238 | end
|
239 | 239 |
|
240 |
| -@testset verbose = true "AD / SamplingContext" begin |
241 |
| - # AD tests for gradient-based samplers need to be run with SamplingContext |
242 |
| - # because samplers can potentially use this to define custom behaviour in |
243 |
| - # the tilde-pipeline and thus change the code executed during model |
244 |
| - # evaluation. |
245 |
| - @testset "adtype=$adtype" for adtype in ADTYPES |
246 |
| - @testset "alg=$alg" for alg in [ |
247 |
| - HMC(0.1, 10; adtype=adtype), |
248 |
| - HMCDA(0.8, 0.75; adtype=adtype), |
249 |
| - NUTS(1000, 0.8; adtype=adtype), |
250 |
| - SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype), |
251 |
| - SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype), |
252 |
| - ] |
253 |
| - @info "Testing AD for $alg" |
254 |
| - |
255 |
| - @testset "model=$(model.f)" for model in DEMO_MODELS |
256 |
| - rng = StableRNG(123) |
257 |
| - spl_model = DynamicPPL.contextualize( |
258 |
| - model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) |
259 |
| - ) |
260 |
| - @test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any |
261 |
| - end |
262 |
| - end |
263 |
| - end |
264 |
| -end |
265 |
| - |
266 | 240 | @testset verbose = true "AD / GibbsContext" begin
|
267 |
| - # Gibbs sampling also needs extra AD testing because the models are |
| 241 | + # Gibbs sampling needs some extra AD testing because the models are |
268 | 242 | # executed with GibbsContext and a subsetted varinfo. (see e.g.
|
269 | 243 | # `gibbs_initialstep_recursive` and `gibbs_step_recursive` in
|
270 | 244 | # src/mcmc/gibbs.jl -- the code here mimics what happens in those
|
|
283 | 257 | model, varnames, deepcopy(global_vi)
|
284 | 258 | )
|
285 | 259 | rng = StableRNG(123)
|
286 |
| - spl_model = DynamicPPL.contextualize( |
287 |
| - model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10))) |
288 |
| - ) |
289 |
| - @test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any |
| 260 | + @test run_ad(model, adtype; test=true, benchmark=false) isa Any |
290 | 261 | end
|
291 | 262 | end
|
292 | 263 | end
|
|
0 commit comments