diff --git a/Project.toml b/Project.toml index f7be0257d..77ce00c26 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.29" +version = "0.29.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -21,6 +21,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" @@ -30,6 +31,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] @@ -38,6 +40,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLReverseDiffExt = ["ReverseDiff"] +DynamicPPLTapirExt = ["Tapir"] DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] @@ -64,13 +67,14 @@ Random = "1.6" Requires = "1" ReverseDiff = "1" Test = "1.6" +Tapir = "0.2.40" ZygoteRules = "0.2" julia = "1.6" [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" diff --git a/ext/DynamicPPLTapirExt.jl b/ext/DynamicPPLTapirExt.jl new file mode 100644 index 000000000..8546d5163 --- /dev/null +++ b/ext/DynamicPPLTapirExt.jl @@ -0,0 +1,35 @@ +module DynamicPPLTapirExt + +if isdefined(Base, :get_extension) + using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD + using Tapir +else + using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD + using ..Tapir +end + +function DynamicPPL.setmodel( + f::LogDensityProblemsAD.ADGradientWrapper, + model::DynamicPPL.Model, + adtype::ADTypes.AutoTapir, +) + if !hasfield(typeof(f), :rule) + @warn "ADGradientWrapper does not have a `rule` field. Please check Tapir version. It is also possible that `adtype` mismatch `ADGradientWrapper` type." + @warn "Using default rule." + return LogDensityProblemsAD.ADgradient( + Val(:Tapir), + DynamicPPL.setmodel(LogDensityProblemsAD.parent(f), model); + safety_on=adtype.safe_mode, + rule=nothing, + ) + else + return LogDensityProblemsAD.ADgradient( + Val(:Tapir), + DynamicPPL.setmodel(LogDensityProblemsAD.parent(f), model); + safety_on=adtype.safe_mode, + rule=f.rule, + ) + end +end + +end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index eb027b45b..e103b9e03 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -214,6 +214,9 @@ end @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include( "../ext/DynamicPPLReverseDiffExt.jl" ) + @require Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" include( + "../ext/DynamicPPLTapirExt.jl" + ) @require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include( "../ext/DynamicPPLZygoteRulesExt.jl" )