Skip to content

Commit 4a234d6

Browse files
committed
Internal _make_ad_gradient
1 parent 46bbf06 commit 4a234d6

File tree

4 files changed

+16
-23
lines changed

4 files changed

+16
-23
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3030
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3131
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3232
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
33+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
3334
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3435
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3536

@@ -38,6 +39,7 @@ DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
3839
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
3940
DynamicPPLForwardDiffExt = ["ForwardDiff"]
4041
DynamicPPLMCMCChainsExt = ["MCMCChains"]
42+
DynamicPPLMooncakeExt = ["Mooncake"]
4143
DynamicPPLReverseDiffExt = ["ReverseDiff"]
4244
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
4345

@@ -60,10 +62,11 @@ LogDensityProblems = "2"
6062
LogDensityProblemsAD = "1.7.0"
6163
MCMCChains = "6"
6264
MacroTools = "0.5.6"
65+
Mooncake = "0.4.52"
6366
OrderedCollections = "1"
6467
Random = "1.6"
65-
Requires = "1"
6668
ReverseDiff = "1"
69+
Requires = "1"
6770
Test = "1.6"
6871
ZygoteRules = "0.2"
6972
julia = "1.10"

ext/DynamicPPLMooncakeExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module DynamicPPLMooncakeExt
2+
3+
import LogDensityProblemsAD: ADgradient
4+
using DynamicPPL: ADTypes, _make_ad_gradient, LogDensityFunction
5+
6+
ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction) = _make_ad_gradient(ad, f)
7+
8+
end # module

ext/DynamicPPLReverseDiffExt.jl

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,8 @@
11
module DynamicPPLReverseDiffExt
22

3-
if isdefined(Base, :get_extension)
4-
using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
5-
using ReverseDiff
6-
else
7-
using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
8-
using ..ReverseDiff
9-
end
3+
import LogDensityProblemsAD: ADgradient
4+
using DynamicPPL: ADTypes, _make_ad_gradient, LogDensityFunction
105

11-
function LogDensityProblemsAD.ADgradient(
12-
ad::ADTypes.AutoReverseDiff{Tcompile}, ℓ::DynamicPPL.LogDensityFunction
13-
) where {Tcompile}
14-
return LogDensityProblemsAD.ADgradient(
15-
Val(:ReverseDiff),
16-
ℓ;
17-
compile=Val(Tcompile),
18-
# `getparams` can return `Vector{Real}`, in which case, `ReverseDiff` will initialize the gradients to Integer 0
19-
# because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473
20-
# `zero(D)` will return 0 when D is Real.
21-
# here we use `identity` to possibly concretize the type to `Vector{Float64}` in the case of `Vector{Real}`.
22-
x=map(identity, DynamicPPL.getparams(ℓ)),
23-
)
24-
end
6+
ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction) = _make_ad_gradient(ad, f)
257

268
end # module

src/logdensityfunction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
149149
# parameters, or DifferentiationInterface will not have sufficient information to e.g.
150150
# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate
151151
# a tape when using ReverseDiff.jl.
152-
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction)
152+
function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction)
153153
x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params
154154
return LogDensityProblemsAD.ADgradient(ad, ℓ; x)
155155
end

0 commit comments

Comments
 (0)