Skip to content

Commit 8af36e1

Browse files
committed
No need to test AD for SamplingContext{<:HMC} (#2645)
1 parent d92fd56 commit 8af36e1

File tree

1 file changed

+2
-31
lines changed

1 file changed

+2
-31
lines changed

test/ad.jl

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -237,34 +237,8 @@ end
237237
end
238238
end
239239

240-
@testset verbose = true "AD / SamplingContext" begin
241-
# AD tests for gradient-based samplers need to be run with SamplingContext
242-
# because samplers can potentially use this to define custom behaviour in
243-
# the tilde-pipeline and thus change the code executed during model
244-
# evaluation.
245-
@testset "adtype=$adtype" for adtype in ADTYPES
246-
@testset "alg=$alg" for alg in [
247-
HMC(0.1, 10; adtype=adtype),
248-
HMCDA(0.8, 0.75; adtype=adtype),
249-
NUTS(1000, 0.8; adtype=adtype),
250-
SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype),
251-
SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype),
252-
]
253-
@info "Testing AD for $alg"
254-
255-
@testset "model=$(model.f)" for model in DEMO_MODELS
256-
rng = StableRNG(123)
257-
spl_model = DynamicPPL.contextualize(
258-
model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg))
259-
)
260-
@test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any
261-
end
262-
end
263-
end
264-
end
265-
266240
@testset verbose = true "AD / GibbsContext" begin
267-
# Gibbs sampling also needs extra AD testing because the models are
241+
# Gibbs sampling needs some extra AD testing because the models are
268242
# executed with GibbsContext and a subsetted varinfo. (see e.g.
269243
# `gibbs_initialstep_recursive` and `gibbs_step_recursive` in
270244
# src/mcmc/gibbs.jl -- the code here mimics what happens in those
@@ -283,10 +257,7 @@ end
283257
model, varnames, deepcopy(global_vi)
284258
)
285259
rng = StableRNG(123)
286-
spl_model = DynamicPPL.contextualize(
287-
model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10)))
288-
)
289-
@test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any
260+
@test run_ad(model, adtype; test=true, benchmark=false) isa Any
290261
end
291262
end
292263
end

0 commit comments

Comments
 (0)