diff --git a/Project.toml b/Project.toml index 909be870f..21d501e9f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.31.3" +version = "0.31.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -30,6 +30,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] @@ -37,6 +38,7 @@ DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLMCMCChainsExt = ["MCMCChains"] +DynamicPPLMooncakeExt = ["Mooncake"] DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] @@ -58,6 +60,7 @@ LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" +Mooncake = "0.4.59" OrderedCollections = "1" Random = "1.6" Requires = "1" diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl new file mode 100644 index 000000000..b86d807bc --- /dev/null +++ b/ext/DynamicPPLMooncakeExt.jl @@ -0,0 +1,9 @@ +module DynamicPPLMooncakeExt + +using DynamicPPL: DynamicPPL, istrans +using Mooncake: Mooncake + +# This is purely an optimisation. +Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg} + +end # module diff --git a/test/Project.toml b/test/Project.toml index 0d247c3ec..4f12f2015 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -45,7 +45,7 @@ LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6.0.4" MacroTools = "0.5.6" -Mooncake = "0.4.50" +Mooncake = "0.4.59" ReverseDiff = "1" StableRNGs = "1" Tracker = "0.2.23" diff --git a/test/ext/DynamicPPLMooncakeExt.jl b/test/ext/DynamicPPLMooncakeExt.jl new file mode 100644 index 000000000..986057da0 --- /dev/null +++ b/test/ext/DynamicPPLMooncakeExt.jl @@ -0,0 +1,5 @@ +@testset "DynamicPPLMooncakeExt" begin + Mooncake.TestUtils.test_rule( + StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true + ) +end diff --git a/test/runtests.jl b/test/runtests.jl index dbfa319b0..a4fdabf22 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ using LogDensityProblems, LogDensityProblemsAD using MacroTools using MCMCChains using Mooncake: Mooncake +using StableRNGs using Tracker using ReverseDiff using Zygote @@ -77,6 +78,7 @@ include("test_util.jl") @testset "ad" begin include("ext/DynamicPPLForwardDiffExt.jl") + include("ext/DynamicPPLMooncakeExt.jl") include("ad.jl") end