Skip to content

Commit 7369221

Browse files
committed
Remove AD loops from HMC and SGHMC
1 parent df85c6f commit 7369221

File tree

3 files changed

+51
-27
lines changed

3 files changed

+51
-27
lines changed

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Combinatorics = "1"
5252
Distributions = "0.25"
5353
DistributionsAD = "0.6.3"
5454
DynamicHMC = "2.1.6, 3.0"
55-
DynamicPPL = "0.36"
55+
DynamicPPL = "0.36.6"
5656
FiniteDifferences = "0.10.8, 0.11, 0.12"
5757
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
5858
HypothesisTests = "0.11"

test/mcmc/hmc.jl

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import ..ADUtils
77
using Bijectors: Bijectors
88
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
99
using DynamicPPL: DynamicPPL, Sampler
10+
using DynamicPPL.TestUtils.AD: run_ad
11+
using DynamicPPL.TestUtils: DEMO_MODELS
1012
import ForwardDiff
1113
using HypothesisTests: ApproximateTwoSampleKSTest, pvalue
1214
import ReverseDiff
@@ -18,9 +20,41 @@ import Mooncake
1820
using Test: @test, @test_logs, @testset, @test_throws
1921
using Turing
2022

21-
@testset "Testing hmc.jl with $adbackend" for adbackend in ADUtils.adbackends
22-
@info "Starting HMC tests with $adbackend"
23+
@testset "AD / hmc.jl" begin
24+
# AD tests need to be run with SamplingContext because samplers can potentially
25+
# use this to define custom behaviour in the tilde-pipeline and thus change the
26+
# code executed during model evaluation.
27+
@testset "adtype=$adtype" for adtype in ADUtils.adbackends
28+
@testset "alg=$alg" for alg in [
29+
HMC(0.1, 10; adtype=adtype),
30+
HMCDA(0.8, 0.75; adtype=adtype),
31+
NUTS(1000, 0.8; adtype=adtype),
32+
]
33+
@info "Testing AD for $alg"
34+
35+
@testset "model=$(model.f)" for model in DEMO_MODELS
36+
rng = StableRNG(123)
37+
ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg))
38+
@test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any
39+
end
40+
end
41+
42+
@testset "Check ADType" begin
43+
seed = 123
44+
alg = HMC(0.1, 10; adtype=adtype)
45+
m = DynamicPPL.contextualize(
46+
gdemo_default, ADTypeCheckContext(adtype, gdemo_default.context)
47+
)
48+
# These will error if the adbackend being used is not the one set.
49+
sample(StableRNG(seed), m, alg, 10)
50+
end
51+
end
52+
end
53+
54+
@testset "Testing hmc.jl" begin
55+
@info "Starting HMC tests"
2356
seed = 123
57+
adbackend = Turing.DEFAULT_ADTYPE
2458

2559
@testset "constrained bounded" begin
2660
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]
@@ -65,12 +99,6 @@ using Turing
6599
check_numerical(chain, ["ps[1]", "ps[2]"], [5 / 16, 11 / 16]; atol=0.015)
66100
end
67101

68-
@testset "hmc reverse diff" begin
69-
alg = HMC(0.1, 10; adtype=adbackend)
70-
res = sample(StableRNG(seed), gdemo_default, alg, 4_000)
71-
check_gdemo(res; rtol=0.1)
72-
end
73-
74102
# Test the sampling of a matrix-value distribution.
75103
@testset "matrix support" begin
76104
dist = Wishart(7, [1 0.5; 0.5 1])
@@ -211,20 +239,20 @@ using Turing
211239
end
212240

213241
@testset "prior" begin
242+
# NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance
243+
# which means that it's _very_ difficult to find a good tolerance in the test below:)
244+
prior_dist = truncated(Normal(3, 1); lower=0)
245+
214246
@model function demo_hmc_prior()
215-
# NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance
216-
# which means that it's _very_ difficult to find a good tolerance in the test below:)
217-
s ~ truncated(Normal(3, 1); lower=0)
247+
s ~ prior_dist
218248
return m ~ Normal(0, sqrt(s))
219249
end
220250
alg = NUTS(1000, 0.8; adtype=adbackend)
221251
gdemo_default_prior = DynamicPPL.contextualize(
222252
demo_hmc_prior(), DynamicPPL.PriorContext()
223253
)
224254
chain = sample(gdemo_default_prior, alg, 5_000; initial_params=[3.0, 0.0])
225-
check_numerical(
226-
chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0]; atol=0.2
227-
)
255+
check_numerical(chain, [:s, :m], [mean(prior_dist), 0]; atol=0.2)
228256
end
229257

230258
@testset "warning for difficult init params" begin
@@ -292,8 +320,8 @@ using Turing
292320

293321
# Extract the `x` like this because running `generated_quantities` was how
294322
# the issue was discovered, hence we also want to make sure that it works.
295-
results = generated_quantities(model, chain)
296-
results_prior = generated_quantities(model, chain_prior)
323+
results = returned(model, chain)
324+
results_prior = returned(model, chain_prior)
297325

298326
# Make sure none of the samples in the chains resulted in errors.
299327
@test all(!isnothing, results)
@@ -315,15 +343,6 @@ using Turing
315343
@test Turing.Inference.getstepsize(spl, hmc_state) isa Float64
316344
end
317345
end
318-
319-
@testset "Check ADType" begin
320-
alg = HMC(0.1, 10; adtype=adbackend)
321-
m = DynamicPPL.contextualize(
322-
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
323-
)
324-
# These will error if the adbackend being used is not the one set.
325-
sample(StableRNG(seed), m, alg, 10)
326-
end
327346
end
328347

329348
end

test/mcmc/sghmc.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@ import Mooncake
1515
using Test: @test, @testset
1616
using Turing
1717

18-
@testset "AD with SGHMC / SGLD" begin
18+
@testset "AD / sghmc.jl" begin
19+
# AD tests need to be run with SamplingContext because samplers can potentially
20+
# use this to define custom behaviour in the tilde-pipeline and thus change the
21+
# code executed during model evaluation.
1922
@testset "adtype=$adtype" for adtype in ADUtils.adbackends
2023
@testset "alg=$alg" for alg in [
2124
SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype),
2225
SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype),
2326
]
27+
@info "Testing AD for $alg"
28+
2429
@testset "model=$(model.f)" for model in DEMO_MODELS
2530
rng = StableRNG(123)
2631
ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg))

0 commit comments

Comments
 (0)