Skip to content

Commit a0db647

Browse files
authored
Use new style kwarg constructor for AutoReverseDiff (#2273)
1 parent 7b2869f commit a0db647

File tree

7 files changed

+14
-12
lines changed

7 files changed

+14
-12
lines changed

benchmarks/benchmarks_suite.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,5 @@ BenchmarkSuite["mnormal"]["forwarddiff"] = @benchmarkable sample(
8484

8585
# ReverseDiff
8686
BenchmarkSuite["mnormal"]["reversediff"] = @benchmarkable sample(
87-
$(mdemo(d, 1)), $(HMC(0.1, 5; adtype=AutoReverseDiff(false))), 5000
87+
$(mdemo(d, 1)), $(HMC(0.1, 5; adtype=AutoReverseDiff(; compile=false))), 5000
8888
)

test/essential/ad.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,22 +154,22 @@ end
154154
return theta ~ Dirichlet(1 ./ fill(4, 4))
155155
end
156156
sample(dir(), HMC(0.01, 1; adtype=AutoZygote()), 1000)
157-
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(false)), 1000)
158-
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(true)), 1000)
157+
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
158+
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=true)), 1000)
159159
end
160160
@testset "PDMatDistribution AD" begin
161161
@model function wishart()
162162
return theta ~ Wishart(4, Matrix{Float64}(I, 4, 4))
163163
end
164164

165-
sample(wishart(), HMC(0.01, 1; adtype=AutoReverseDiff(false)), 1000)
165+
sample(wishart(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
166166
sample(wishart(), HMC(0.01, 1; adtype=AutoZygote()), 1000)
167167

168168
@model function invwishart()
169169
return theta ~ InverseWishart(4, Matrix{Float64}(I, 4, 4))
170170
end
171171

172-
sample(invwishart(), HMC(0.01, 1; adtype=AutoReverseDiff(false)), 1000)
172+
sample(invwishart(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
173173
sample(invwishart(), HMC(0.01, 1; adtype=AutoZygote()), 1000)
174174
end
175175
@testset "Hessian test" begin
@@ -231,7 +231,9 @@ end
231231
for i in 1:5
232232
d = Normal(0.0, i)
233233
data = rand(d, N)
234-
chn = sample(demo(data), NUTS(0.65; adtype=AutoReverseDiff(true)), 1000)
234+
chn = sample(
235+
demo(data), NUTS(0.65; adtype=AutoReverseDiff(; compile=true)), 1000
236+
)
235237
@test mean(Array(chn[:sigma])) std(data) atol = 0.5
236238
end
237239
end

test/mcmc/Inference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import ReverseDiff
1414
using Test: @test, @test_throws, @testset
1515
using Turing
1616

17-
@testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false))
17+
@testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
1818
# Only test threading if 1.3+.
1919
if VERSION > v"1.2"
2020
@testset "threaded sampling" begin

test/mcmc/gibbs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using Turing: Inference
1313
using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess
1414

1515
@testset "Testing gibbs.jl with $adbackend" for adbackend in (
16-
AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)
16+
AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)
1717
)
1818
@testset "gibbs constructor" begin
1919
N = 500

test/mcmc/gibbs_conditional.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using Test: @test, @testset
1515
using Turing
1616

1717
@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in (
18-
AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)
18+
AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)
1919
)
2020
Random.seed!(1000)
2121
rng = StableRNG(123)

test/mcmc/hmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using StatsFuns: logistic
1616
using Test: @test, @test_logs, @testset
1717
using Turing
1818

19-
@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false))
19+
@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
2020
# Set a seed
2121
rng = StableRNG(123)
2222
@testset "constrained bounded" begin

test/mcmc/sghmc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using StableRNGs: StableRNG
1010
using Test: @test, @testset
1111
using Turing
1212

13-
@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false))
13+
@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
1414
@testset "sghmc constructor" begin
1515
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend)
1616
@test alg isa SGHMC
@@ -36,7 +36,7 @@ using Turing
3636
end
3737
end
3838

39-
@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false))
39+
@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
4040
@testset "sgld constructor" begin
4141
alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend)
4242
@test alg isa SGLD

0 commit comments

Comments
 (0)