Skip to content

Commit d8f6dfa

Browse files
committed
Remove AD backend loops
1 parent f184d3f commit d8f6dfa

File tree

2 files changed

+36
-36
lines changed

2 files changed

+36
-36
lines changed

test/mcmc/abstractmcmc.jl

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -112,36 +112,36 @@ end
112112

113113
@testset "External samplers" begin
114114
@testset "AdvancedHMC.jl" begin
115-
@testset "adtype=$adtype" for adtype in ADUtils.adbackends
116-
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
117-
# Need some functionality to initialize the sampler.
118-
# TODO: Remove this once the constructors in the respective packages become "lazy".
119-
sampler = initialize_nuts(model)
120-
sampler_ext = DynamicPPL.Sampler(
121-
externalsampler(sampler; adtype, unconstrained=true)
115+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
116+
adtype = Turing.DEFAULT_ADTYPE
117+
118+
# Need some functionality to initialize the sampler.
119+
# TODO: Remove this once the constructors in the respective packages become "lazy".
120+
sampler = initialize_nuts(model)
121+
sampler_ext = DynamicPPL.Sampler(
122+
externalsampler(sampler; adtype, unconstrained=true)
123+
)
124+
# FIXME: Once https://github.com/TuringLang/AdvancedHMC.jl/pull/366 goes through, uncomment.
125+
# @testset "initial_params" begin
126+
# test_initial_params(model, sampler_ext; n_adapts=0)
127+
# end
128+
129+
sample_kwargs = (
130+
n_adapts=1_000,
131+
discard_initial=1_000,
132+
# FIXME: Remove this once we can run `test_initial_params` above.
133+
initial_params=DynamicPPL.VarInfo(model)[:],
134+
)
135+
136+
@testset "inference" begin
137+
DynamicPPL.TestUtils.test_sampler(
138+
[model],
139+
sampler_ext,
140+
2_000;
141+
rtol=0.2,
142+
sampler_name="AdvancedHMC",
143+
sample_kwargs...,
122144
)
123-
# FIXME: Once https://github.com/TuringLang/AdvancedHMC.jl/pull/366 goes through, uncomment.
124-
# @testset "initial_params" begin
125-
# test_initial_params(model, sampler_ext; n_adapts=0)
126-
# end
127-
128-
sample_kwargs = (
129-
n_adapts=1_000,
130-
discard_initial=1_000,
131-
# FIXME: Remove this once we can run `test_initial_params` above.
132-
initial_params=DynamicPPL.VarInfo(model)[:],
133-
)
134-
135-
@testset "inference" begin
136-
DynamicPPL.TestUtils.test_sampler(
137-
[model],
138-
sampler_ext,
139-
2_000;
140-
rtol=0.2,
141-
sampler_name="AdvancedHMC",
142-
sample_kwargs...,
143-
)
144-
end
145145
end
146146
end
147147
end

test/mcmc/sghmc.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,35 @@ import Mooncake
1212
using Test: @test, @testset
1313
using Turing
1414

15-
@testset "Testing sghmc.jl with $adbackend" for adbackend in ADUtils.adbackends
15+
@testset "Testing sghmc.jl" begin
1616
@testset "sghmc constructor" begin
17-
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend)
17+
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=Turing.DEFAULT_ADTYPE)
1818
@test alg isa SGHMC
1919
sampler = Turing.Sampler(alg)
2020
@test sampler isa Turing.Sampler{<:SGHMC}
2121

22-
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend)
22+
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=Turing.DEFAULT_ADTYPE)
2323
@test alg isa SGHMC
2424
sampler = Turing.Sampler(alg)
2525
@test sampler isa Turing.Sampler{<:SGHMC}
2626
end
2727
@testset "sghmc inference" begin
2828
rng = StableRNG(123)
2929

30-
alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adbackend)
30+
alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=Turing.DEFAULT_ADTYPE)
3131
chain = sample(rng, gdemo_default, alg, 10_000)
3232
check_gdemo(chain; atol=0.1)
3333
end
3434
end
3535

36-
@testset "Testing sgld.jl with $adbackend" for adbackend in ADUtils.adbackends
36+
@testset "Testing sgld.jl" begin
3737
@testset "sgld constructor" begin
38-
alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend)
38+
alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=Turing.DEFAULT_ADTYPE)
3939
@test alg isa SGLD
4040
sampler = Turing.Sampler(alg)
4141
@test sampler isa Turing.Sampler{<:SGLD}
4242

43-
alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend)
43+
alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=Turing.DEFAULT_ADTYPE)
4444
@test alg isa SGLD
4545
sampler = Turing.Sampler(alg)
4646
@test sampler isa Turing.Sampler{<:SGLD}

0 commit comments

Comments
 (0)