Skip to content

Commit 754b190

Browse files
authored
Re-enable Mooncake in tests (#1135)
* Re-enable Mooncake * fixes * fix import
1 parent 08fffa2 commit 754b190

File tree

3 files changed

+10
-21
lines changed

3 files changed

+10
-21
lines changed

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1919
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2020
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2121
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
22+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2223
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2324
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2425
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

test/ad.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,11 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
55
# Used as the ground truth that others are compared against.
66
ref_adtype = AutoForwardDiff()
77

8-
test_adtypes = if MOONCAKE_SUPPORTED
9-
[
10-
AutoReverseDiff(; compile=false),
11-
AutoReverseDiff(; compile=true),
12-
AutoMooncake(; config=nothing),
13-
]
14-
else
15-
[AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)]
16-
end
8+
test_adtypes = [
9+
AutoReverseDiff(; compile=false),
10+
AutoReverseDiff(; compile=true),
11+
AutoMooncake(; config=nothing),
12+
]
1713

1814
@testset "Unsupported backends" begin
1915
@model demo() = x ~ Normal()
@@ -43,13 +39,13 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
4339
# Put predicates here to avoid long lines
4440
is_mooncake = adtype isa AutoMooncake
4541
is_1_10 = v"1.10" <= VERSION < v"1.11"
46-
is_1_11 = v"1.11" <= VERSION < v"1.12"
42+
is_1_11_or_1_12 = v"1.11" <= VERSION < v"1.13"
4743
is_svi_vnv =
4844
linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
4945
is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict}
5046

5147
# Mooncake doesn't work with several combinations of SimpleVarInfo.
52-
if is_mooncake && is_1_11 && is_svi_vnv
48+
if is_mooncake && is_1_11_or_1_12 && is_svi_vnv
5349
# https://github.com/compintell/Mooncake.jl/issues/470
5450
@test_throws ArgumentError DynamicPPL.LogDensityFunction(
5551
m, getlogjoint_internal, linked_varinfo; adtype=adtype

test/runtests.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using MacroTools
1515
using MCMCChains
1616
using StableRNGs
1717
using ReverseDiff
18+
using Mooncake
1819
using Zygote
1920

2021
using Distributed
@@ -37,13 +38,6 @@ using DynamicPPL: getargs_dottilde, getargs_tilde
3738
const GROUP = get(ENV, "GROUP", "All")
3839
const AQUA = get(ENV, "AQUA", "true") == "true"
3940

40-
# Skip Mooncake if it doesn't work
41-
const MOONCAKE_SUPPORTED = VERSION < v"1.12.0"
42-
if MOONCAKE_SUPPORTED
43-
Pkg.add("Mooncake")
44-
using Mooncake: Mooncake
45-
end
46-
4741
Random.seed!(100)
4842
include("test_util.jl")
4943

@@ -85,9 +79,7 @@ include("test_util.jl")
8579
end
8680
@testset "ad" begin
8781
include("ext/DynamicPPLForwardDiffExt.jl")
88-
if MOONCAKE_SUPPORTED
89-
include("ext/DynamicPPLMooncakeExt.jl")
90-
end
82+
include("ext/DynamicPPLMooncakeExt.jl")
9183
include("ad.jl")
9284
end
9385
@testset "prob and logprob macro" begin

0 commit comments

Comments
 (0)