Skip to content

Commit 397d1a7

Browse files
authored
Move ADTypeCheckContext tests to a separate module (#2383)
1 parent 5426eca commit 397d1a7

File tree

3 files changed

+54
-38
lines changed

3 files changed

+54
-38
lines changed

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ macro timeit_include(path::AbstractString)
3131
end
3232

3333
@testset "Turing" begin
34+
@testset "Test utils" begin
35+
@timeit_include("test_utils/test_utils.jl")
36+
end
37+
3438
@testset "Aqua" begin
3539
@timeit_include("Aqua.jl")
3640
end

test/test_utils/ad_utils.jl

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -229,44 +229,6 @@ function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, sampler, righ
229229
return logp, vi
230230
end
231231

232-
# Check that the ADTypeCheckContext works as expected.
233-
Test.@testset "ADTypeCheckContext" begin
234-
Turing.@model test_model() = x ~ Turing.Normal(0, 1)
235-
tm = test_model()
236-
adtypes = (
237-
Turing.AutoForwardDiff(),
238-
Turing.AutoReverseDiff(),
239-
Turing.AutoZygote(),
240-
# TODO: Mooncake
241-
# Turing.AutoMooncake(config=nothing),
242-
)
243-
for actual_adtype in adtypes
244-
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
245-
for expected_adtype in adtypes
246-
if (
247-
actual_adtype == Turing.AutoForwardDiff() &&
248-
expected_adtype == Turing.AutoZygote()
249-
)
250-
# TODO(mhauru) We are currently unable to check this case.
251-
continue
252-
end
253-
contextualised_tm = DynamicPPL.contextualize(
254-
tm, ADTypeCheckContext(expected_adtype, tm.context)
255-
)
256-
Test.@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
257-
if actual_adtype == expected_adtype
258-
# Check that this does not throw an error.
259-
Turing.sample(contextualised_tm, sampler, 2)
260-
else
261-
Test.@test_throws AbstractWrongADBackendError Turing.sample(
262-
contextualised_tm, sampler, 2
263-
)
264-
end
265-
end
266-
end
267-
end
268-
end
269-
270232
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
271233
# List of AD backends to test.
272234

test/test_utils/test_utils.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Module for testing the test utils themselves."""
2+
module TestUtilsTests
3+
4+
using ..ADUtils: ADTypeCheckContext, AbstractWrongADBackendError
5+
using ForwardDiff: ForwardDiff
6+
using ReverseDiff: ReverseDiff
7+
using Test: @test, @testset, @test_throws
8+
using Turing: Turing
9+
using Turing: DynamicPPL
10+
using Zygote: Zygote
11+
12+
# Check that the ADTypeCheckContext works as expected.
13+
@testset "ADTypeCheckContext" begin
14+
Turing.@model test_model() = x ~ Turing.Normal(0, 1)
15+
tm = test_model()
16+
adtypes = (
17+
Turing.AutoForwardDiff(),
18+
Turing.AutoReverseDiff(),
19+
Turing.AutoZygote(),
20+
# TODO: Mooncake
21+
# Turing.AutoMooncake(config=nothing),
22+
)
23+
for actual_adtype in adtypes
24+
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
25+
for expected_adtype in adtypes
26+
if (
27+
actual_adtype == Turing.AutoForwardDiff() &&
28+
expected_adtype == Turing.AutoZygote()
29+
)
30+
# TODO(mhauru) We are currently unable to check this case.
31+
continue
32+
end
33+
contextualised_tm = DynamicPPL.contextualize(
34+
tm, ADTypeCheckContext(expected_adtype, tm.context)
35+
)
36+
@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
37+
if actual_adtype == expected_adtype
38+
# Check that this does not throw an error.
39+
Turing.sample(contextualised_tm, sampler, 2)
40+
else
41+
@test_throws AbstractWrongADBackendError Turing.sample(
42+
contextualised_tm, sampler, 2
43+
)
44+
end
45+
end
46+
end
47+
end
48+
end
49+
50+
end

0 commit comments

Comments
 (0)