Skip to content

Commit f92c93f

Browse files
mhauruyebaisunxd3willtebbutt
authored
Test with Tapir (#2289)
* Test with Tapir * Relax Tapir version bound * Relax Tapir version bounds more * Add test/test_utils/ad_utils.jl * Change how Tapir is installed for tests * Typo fix * Turn Tapir's safe mode off * Use standard AutoReverseDiff constructor Co-authored-by: Hong Ge <[email protected]> * Revert back to previous AutoReverseDiff constructor * modify `setvarinfo` * fix test error * fix more error * fix error * fix error * Exclude Tapir from AdvancedHMC tests Co-authored-by: Will Tebbutt <[email protected]> * Update ad_utils.jl (#2313) * Update test/test_utils/ad_utils.jl * Move code around in ad_utils.jl * Add a todo note --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: Will Tebbutt <[email protected]>
1 parent a26ce11 commit f92c93f

File tree

8 files changed

+65
-17
lines changed

8 files changed

+65
-17
lines changed

src/mcmc/abstractmcmc.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,17 @@ getstats(transition::AdvancedHMC.Transition) = transition.stat
4040
getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
4141

4242
getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo
43-
getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper) = getvarinfo(parent(f))
43+
function getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper)
44+
return getvarinfo(LogDensityProblemsAD.parent(f))
45+
end
4446

4547
setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo) = Accessors.@set f.varinfo = varinfo
46-
function setvarinfo(f::LogDensityProblemsAD.ADGradientWrapper, varinfo)
47-
return Accessors.@set f.= setvarinfo(f.ℓ, varinfo)
48+
function setvarinfo(
49+
f::LogDensityProblemsAD.ADGradientWrapper, varinfo, adtype::ADTypes.AbstractADType
50+
)
51+
return LogDensityProblemsAD.ADgradient(
52+
adtype, setvarinfo(LogDensityProblemsAD.parent(f), varinfo)
53+
)
4854
end
4955

5056
"""
@@ -120,7 +126,7 @@ function AbstractMCMC.step(
120126
varinfo = DynamicPPL.link(varinfo, model)
121127
end
122128
end
123-
f = setvarinfo(f, varinfo)
129+
f = setvarinfo(f, varinfo, alg.adtype)
124130

125131
# Then just call `AdvancedHMC.step` with the right arguments.
126132
if initial_state === nothing

test/mcmc/Inference.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module InferenceTests
22

33
using ..Models: gdemo_d, gdemo_default
44
using ..NumericalTests: check_gdemo, check_numerical
5+
import ..ADUtils
56
using Distributions: Bernoulli, Beta, InverseGamma, Normal
67
using Distributions: sample
78
import DynamicPPL
@@ -14,7 +15,9 @@ import ReverseDiff
1415
using Test: @test, @test_throws, @testset
1516
using Turing
1617

17-
@testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
18+
ADUtils.install_tapir && import Tapir
19+
20+
@testset "Testing inference.jl with $adbackend" for adbackend in ADUtils.adbackends
1821
# Only test threading if 1.3+.
1922
if VERSION > v"1.2"
2023
@testset "threaded sampling" begin

test/mcmc/abstractmcmc.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module AbstractMCMCTests
22

3+
import ..ADUtils
34
using AdvancedMH: AdvancedMH
45
using Distributions: sample
56
using Distributions.FillArrays: Zeros
@@ -15,14 +16,18 @@ using Test: @test, @test_throws, @testset
1516
using Turing
1617
using Turing.Inference: AdvancedHMC
1718

19+
ADUtils.install_tapir && import Tapir
20+
1821
function initialize_nuts(model::Turing.Model)
1922
# Create a log-density function with an implementation of the
2023
# gradient so we ensure that we're using the same AD backend as in Turing.
2124
f = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(model))
2225

2326
# Link the varinfo.
2427
f = Turing.Inference.setvarinfo(
25-
f, DynamicPPL.link!!(Turing.Inference.getvarinfo(f), model)
28+
f,
29+
DynamicPPL.link!!(Turing.Inference.getvarinfo(f), model),
30+
Turing.Inference.getADType(DynamicPPL.getcontext(LogDensityProblemsAD.parent(f))),
2631
)
2732

2833
# Choose parameter dimensionality and initial parameter value
@@ -112,7 +117,9 @@ end
112117

113118
@testset "External samplers" begin
114119
@testset "AdvancedHMC.jl" begin
115-
# Try a few different AD backends.
120+
# TODO(mhauru) The below tests fail with Tapir, see
121+
# https://github.com/TuringLang/Turing.jl/pull/2289.
122+
# Once that is fixed, this should say `for adtype in ADUtils.adbackends`.
116123
@testset "adtype=$adtype" for adtype in [AutoForwardDiff(), AutoReverseDiff()]
117124
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
118125
# Need some functionality to initialize the sampler.

test/mcmc/gibbs.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module GibbsTests
22

33
using ..Models: MoGtest_default, gdemo, gdemo_default
44
using ..NumericalTests: check_MoGtest_default, check_gdemo, check_numerical
5+
import ..ADUtils
56
using Distributions: InverseGamma, Normal
67
using Distributions: sample
78
using ForwardDiff: ForwardDiff
@@ -12,9 +13,9 @@ using Turing
1213
using Turing: Inference
1314
using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess
1415

15-
@testset "Testing gibbs.jl with $adbackend" for adbackend in (
16-
AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)
17-
)
16+
ADUtils.install_tapir && import Tapir
17+
18+
@testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends
1819
@testset "gibbs constructor" begin
1920
N = 500
2021
s1 = Gibbs(HMC(0.1, 5, :s, :m; adtype=adbackend))

test/mcmc/gibbs_conditional.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module GibbsConditionalTests
22

33
using ..Models: gdemo, gdemo_default
44
using ..NumericalTests: check_gdemo, check_numerical
5+
import ..ADUtils
56
using Clustering: Clustering
67
using Distributions: Categorical, InverseGamma, Normal, sample
78
using ForwardDiff: ForwardDiff
@@ -14,9 +15,9 @@ using StatsFuns: StatsFuns
1415
using Test: @test, @testset
1516
using Turing
1617

17-
@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in (
18-
AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)
19-
)
18+
ADUtils.install_tapir && import Tapir
19+
20+
@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in ADUtils.adbackends
2021
Random.seed!(1000)
2122
rng = StableRNG(123)
2223

test/mcmc/hmc.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using ..Models: gdemo_default
44
using ..ADUtils: ADTypeCheckContext
55
#using ..Models: gdemo
66
using ..NumericalTests: check_gdemo, check_numerical
7+
import ..ADUtils
78
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
89
import DynamicPPL
910
using DynamicPPL: Sampler
@@ -17,7 +18,9 @@ using StatsFuns: logistic
1718
using Test: @test, @test_logs, @testset
1819
using Turing
1920

20-
@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
21+
ADUtils.install_tapir && import Tapir
22+
23+
@testset "Testing hmc.jl with $adbackend" for adbackend in ADUtils.adbackends
2124
# Set a seed
2225
rng = StableRNG(123)
2326
@testset "constrained bounded" begin

test/mcmc/sghmc.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module SGHMCTests
22

33
using ..Models: gdemo_default
44
using ..NumericalTests: check_gdemo
5+
import ..ADUtils
56
using Distributions: sample
67
import ForwardDiff
78
using LinearAlgebra: dot
@@ -10,7 +11,9 @@ using StableRNGs: StableRNG
1011
using Test: @test, @testset
1112
using Turing
1213

13-
@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
14+
ADUtils.install_tapir && import Tapir
15+
16+
@testset "Testing sghmc.jl with $adbackend" for adbackend in ADUtils.adbackends
1417
@testset "sghmc constructor" begin
1518
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend)
1619
@test alg isa SGHMC
@@ -36,7 +39,7 @@ using Turing
3639
end
3740
end
3841

39-
@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
42+
@testset "Testing sgld.jl with $adbackend" for adbackend in ADUtils.adbackends
4043
@testset "sgld constructor" begin
4144
alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend)
4245
@test alg isa SGLD

test/test_utils/ad_utils.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module ADUtils
22

33
using ForwardDiff: ForwardDiff
4+
using Pkg: Pkg
45
using Random: Random
56
using ReverseDiff: ReverseDiff
67
using Test: Test
@@ -9,7 +10,10 @@ using Turing: Turing
910
using Turing: DynamicPPL
1011
using Zygote: Zygote
1112

12-
export ADTypeCheckContext
13+
export ADTypeCheckContext, adbackends
14+
15+
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
16+
# Stuff for checking that the right AD backend is being used.
1317

1418
"""Element types that are always valid for a VarInfo regardless of ADType."""
1519
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)
@@ -270,4 +274,24 @@ Test.@testset "ADTypeCheckContext" begin
270274
end
271275
end
272276

277+
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
278+
# List of AD backends to test.
279+
280+
"""
281+
All the ADTypes on which we want to run the tests.
282+
"""
283+
adbackends = [
284+
Turing.AutoForwardDiff(; chunksize=0), Turing.AutoReverseDiff(; compile=false)
285+
]
286+
287+
# Tapir isn't supported for older Julia versions, hence the check.
288+
install_tapir = isdefined(Turing, :AutoTapir)
289+
if install_tapir
290+
# TODO(mhauru) Is there a better way to install optional dependencies like this?
291+
Pkg.add("Tapir")
292+
using Tapir
293+
push!(adbackends, Turing.AutoTapir(false))
294+
push!(eltypes_by_adtype, Turing.AutoTapir => (Tapir.CoDual,))
295+
end
296+
273297
end

0 commit comments

Comments
 (0)