Skip to content

Commit b924a17

Browse files
sunxd3github-actions[bot]torfjelde
authored
Move the content of ad.jl from Turing.jl to here (#571)
* initialize moving, still need to move tests * Move tests, tests are not fixed yet * Make `ADTypes` a direct dep * Add `ad.jl` for testing * Remove `ADTypes` ext from `require` * Put `ADgradient` code to extensions * Add testing code * Bug fix and adding tests * Update src/simple_varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * renaming a testset * add require for ReverseDiff extension * fix UUID * fix typo * Also use the original transformation * Fix 1.6 compat * Fix typo * Fix typo, again * Update test/ad.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * Fix errors * Refactor the test * Update ad.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * disable Zygote testing * Change testset description * Update test/ad.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Tor Erlend Fjelde <[email protected]> * Apply Tor's comments --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tor Erlend Fjelde <[email protected]>
1 parent c33eeae commit b924a17

10 files changed

+191
-24
lines changed

Project.toml

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.24.6"
3+
version = "0.24.7"
44

55
[deps]
6+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
67
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
78
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
89
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
@@ -14,6 +15,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1415
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1516
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1617
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
18+
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
1719
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1820
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1921
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -22,19 +24,8 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2224
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2325
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2426

25-
[weakdeps]
26-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
27-
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
28-
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
29-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
30-
31-
[extensions]
32-
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
33-
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
34-
DynamicPPLMCMCChainsExt = ["MCMCChains"]
35-
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
36-
3727
[compat]
28+
ADTypes = "0.2"
3829
AbstractMCMC = "5"
3930
AbstractPPL = "0.7"
4031
BangBang = "0.3"
@@ -45,20 +36,37 @@ ConstructionBase = "1.5.4"
4536
Distributions = "0.25"
4637
DocStringExtensions = "0.9"
4738
EnzymeCore = "0.6"
39+
LinearAlgebra = "1.6"
4840
LogDensityProblems = "2"
41+
LogDensityProblemsAD = "1.7.0"
4942
MCMCChains = "6"
5043
MacroTools = "0.5.6"
5144
OrderedCollections = "1"
45+
Random = "1.6"
5246
Requires = "1"
5347
Setfield = "1"
54-
ZygoteRules = "0.2"
55-
LinearAlgebra = "1.6"
56-
Random = "1.6"
5748
Test = "1.6"
49+
ZygoteRules = "0.2"
5850
julia = "1.6"
5951

52+
[extensions]
53+
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
54+
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
55+
DynamicPPLForwardDiffExt = ["ForwardDiff"]
56+
DynamicPPLMCMCChainsExt = ["MCMCChains"]
57+
DynamicPPLReverseDiffExt = ["ReverseDiff"]
58+
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
59+
6060
[extras]
6161
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6262
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
6363
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
6464
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
65+
66+
[weakdeps]
67+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
68+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
69+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
70+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
71+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
72+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

ext/DynamicPPLForwardDiffExt.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
module DynamicPPLForwardDiffExt
2+
3+
if isdefined(Base, :get_extension)
4+
using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
5+
using ForwardDiff
6+
else
7+
using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
8+
using ..ForwardDiff
9+
end
10+
11+
getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk
12+
13+
standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
14+
standardtag(::ADTypes.AutoForwardDiff) = false
15+
16+
function LogDensityProblemsAD.ADgradient(
17+
ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction
18+
)
19+
θ = DynamicPPL.getparams(ℓ)
20+
f = Base.Fix1(LogDensityProblems.logdensity, ℓ)
21+
22+
# Define configuration for ForwardDiff.
23+
tag = if standardtag(ad)
24+
ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ))
25+
else
26+
ForwardDiff.Tag(f, eltype(θ))
27+
end
28+
chunk_size = getchunksize(ad)
29+
chunk = if chunk_size == 0
30+
ForwardDiff.Chunk(θ)
31+
else
32+
ForwardDiff.Chunk(length(θ), chunk_size)
33+
end
34+
35+
return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ)
36+
end
37+
38+
# Allow Turing tag in gradient etc. calls of the log density function
39+
function ForwardDiff.checktag(
40+
::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
41+
::DynamicPPL.LogDensityFunction,
42+
::AbstractArray{W},
43+
) where {V,W}
44+
return true
45+
end
46+
function ForwardDiff.checktag(
47+
::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
48+
::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction},
49+
::AbstractArray{W},
50+
) where {V,W}
51+
return true
52+
end
53+
54+
end # module

ext/DynamicPPLReverseDiffExt.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
module DynamicPPLReverseDiffExt
2+
3+
if isdefined(Base, :get_extension)
4+
using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
5+
using ReverseDiff
6+
else
7+
using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
8+
using ..ReverseDiff
9+
end
10+
11+
function LogDensityProblemsAD.ADgradient(
12+
ad::ADTypes.AutoReverseDiff, ℓ::DynamicPPL.LogDensityFunction
13+
)
14+
return LogDensityProblemsAD.ADgradient(
15+
Val(:ReverseDiff),
16+
ℓ;
17+
compile=Val(ad.compile),
18+
# `getparams` can return `Vector{Real}`, in which case, `ReverseDiff` will initialize the gradients to Integer 0
19+
# because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473
20+
# `zero(D)` will return 0 when D is Real.
21+
# here we use `identity` to possibly concretize the type to `Vector{Float64}` in the case of `Vector{Real}`.
22+
x=map(identity, DynamicPPL.getparams(ℓ)),
23+
)
24+
end
25+
26+
end # module

src/DynamicPPL.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@ using Distributions
88
using OrderedCollections: OrderedDict
99

1010
using AbstractMCMC: AbstractMCMC
11+
using ADTypes: ADTypes
1112
using BangBang: BangBang, push!!, empty!!, setindex!!
1213
using MacroTools: MacroTools
1314
using ConstructionBase: ConstructionBase
1415
using Setfield: Setfield
1516
using LogDensityProblems: LogDensityProblems
17+
using LogDensityProblemsAD: LogDensityProblemsAD
1618

1719
using LinearAlgebra: LinearAlgebra, Cholesky
1820

@@ -189,13 +191,23 @@ end
189191
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
190192
"../ext/DynamicPPLEnzymeCoreExt.jl"
191193
)
194+
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
195+
"../ext/DynamicPPLForwardDiffExt.jl"
196+
)
192197
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
193198
"../ext/DynamicPPLMCMCChainsExt.jl"
194199
)
200+
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
201+
"../ext/DynamicPPLReverseDiffExt.jl"
202+
)
195203
@require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include(
196204
"../ext/DynamicPPLZygoteRulesExt.jl"
197205
)
198206
end
199207
end
200208

209+
# Standard tag: Improves stacktraces
210+
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
211+
struct DynamicPPLTag end
212+
201213
end # module

src/simple_varinfo.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,12 @@ end
250250

251251
unflatten(svi::SimpleVarInfo, spl::AbstractSampler, x::AbstractVector) = unflatten(svi, x)
252252
function unflatten(svi::SimpleVarInfo, x::AbstractVector)
253-
return Setfield.@set svi.values = unflatten(svi.values, x)
253+
logp = getlogp(svi)
254+
vals = unflatten(svi.values, x)
255+
T = eltype(x)
256+
return SimpleVarInfo{typeof(vals),T,typeof(svi.transformation)}(
257+
vals, T(logp), svi.transformation
258+
)
254259
end
255260

256261
function BangBang.empty!!(vi::SimpleVarInfo)

src/varinfo.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,7 @@ const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{
112112
# multiple times.
113113
transformation(vi::VarInfo) = DynamicTransformation()
114114

115-
function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector)
116-
new_vi = deepcopy(old_vi)
117-
new_vi[spl] = x
118-
return new_vi
119-
end
120-
121-
function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector)
115+
function VarInfo(old_vi::VarInfo, spl, x::AbstractVector)
122116
md = newmetadata(old_vi.metadata, Val(getspace(spl)), x)
123117
return VarInfo(
124118
md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi))
@@ -147,6 +141,20 @@ function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext
147141
return VarInfo(rng, model, SampleFromPrior(), context)
148142
end
149143

144+
# TODO: Remove `space` argument when no longer needed. Ref: https://github.com/TuringLang/DynamicPPL.jl/issues/573
145+
function newmetadata(metadata::Metadata, space, x)
146+
return Metadata(
147+
metadata.idcs,
148+
metadata.vns,
149+
metadata.ranges,
150+
x,
151+
metadata.dists,
152+
metadata.gids,
153+
metadata.orders,
154+
metadata.flags,
155+
)
156+
end
157+
150158
@generated function newmetadata(
151159
metadata::NamedTuple{names}, ::Val{space}, x
152160
) where {names,space}

test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
23
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
34
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
45
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
@@ -11,10 +12,12 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1112
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1314
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
15+
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
1416
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1517
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1618
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1719
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
20+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1821
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
1922
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2023
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -23,6 +26,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2326
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2427

2528
[compat]
29+
ADTypes = "0.2"
2630
AbstractMCMC = "5"
2731
AbstractPPL = "0.7"
2832
Bijectors = "0.13"

test/ad.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
@testset "AD: ForwardDiff and ReverseDiff" begin
2+
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
3+
f = DynamicPPL.LogDensityFunction(m)
4+
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
5+
vns = DynamicPPL.TestUtils.varnames(m)
6+
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
7+
8+
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
9+
f = DynamicPPL.LogDensityFunction(m, varinfo)
10+
11+
# use ForwardDiff result as reference
12+
ad_forwarddiff_f = LogDensityProblemsAD.ADgradient(
13+
ADTypes.AutoForwardDiff(; chunksize=0), f
14+
)
15+
# convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0
16+
# reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489
17+
θ = convert(Vector{Float64}, varinfo[:])
18+
logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
19+
20+
@testset "ReverseDiff with compile=$compile" for compile in (false, true)
21+
adtype = ADTypes.AutoReverseDiff(; compile=compile)
22+
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
23+
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
24+
@test grad ref_grad
25+
end
26+
end
27+
end
28+
end

test/ext/DynamicPPLForwardDiffExt.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
@testset "tag" begin
2+
for chunksize in (0, 1, 10)
3+
ad = ADTypes.AutoForwardDiff(; chunksize=chunksize)
4+
standardtag = if !isdefined(Base, :get_extension)
5+
DynamicPPL.DynamicPPLForwardDiffExt.standardtag
6+
else
7+
Base.get_extension(DynamicPPL, :DynamicPPLForwardDiffExt).standardtag
8+
end
9+
@test standardtag(ad)
10+
for tag in (false, 0, 1)
11+
@test !standardtag(AutoForwardDiff(; chunksize=chunksize, tag=tag))
12+
end
13+
end
14+
end

test/runtests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using ADTypes
12
using DynamicPPL
23
using AbstractMCMC
34
using AbstractPPL
@@ -6,9 +7,11 @@ using Distributions
67
using DistributionsAD
78
using Documenter
89
using ForwardDiff
10+
using LogDensityProblems, LogDensityProblemsAD
911
using MacroTools
1012
using MCMCChains
1113
using Tracker
14+
using ReverseDiff
1215
using Zygote
1316
using Setfield
1417
using Compat
@@ -64,6 +67,11 @@ include("test_util.jl")
6467
include("ext/DynamicPPLMCMCChainsExt.jl")
6568
end
6669

70+
@testset "ad" begin
71+
include("ext/DynamicPPLForwardDiffExt.jl")
72+
include("ad.jl")
73+
end
74+
6775
@testset "doctests" begin
6876
DocMeta.setdocmeta!(
6977
DynamicPPL,

0 commit comments

Comments
 (0)