Skip to content

Commit 4a17e82

Browse files
Tor FjeldeTor Fjelde
authored andcommitted
Merge branch 'master' into torfjelde/determine-varinfo
2 parents 325c5f9 + 0548ddf commit 4a17e82

12 files changed

+186
-50
lines changed

Project.toml

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.31.1"
3+
version = "0.31.4"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -31,7 +31,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3131
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3232
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
3333
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
34-
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
34+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
3535
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3636

3737
[extensions]
@@ -40,7 +40,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
4040
DynamicPPLForwardDiffExt = ["ForwardDiff"]
4141
DynamicPPLJETExt = ["JET"]
4242
DynamicPPLMCMCChainsExt = ["MCMCChains"]
43-
DynamicPPLReverseDiffExt = ["ReverseDiff"]
43+
DynamicPPLMooncakeExt = ["Mooncake"]
4444
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
4545

4646
[compat]
@@ -63,19 +63,10 @@ LogDensityProblems = "2"
6363
LogDensityProblemsAD = "1.7.0"
6464
MCMCChains = "6"
6565
MacroTools = "0.5.6"
66+
Mooncake = "0.4.59"
6667
OrderedCollections = "1"
6768
Random = "1.6"
6869
Requires = "1"
69-
ReverseDiff = "1"
7070
Test = "1.6"
7171
ZygoteRules = "0.2"
7272
julia = "1.10"
73-
74-
[extras]
75-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
76-
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
77-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
78-
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
79-
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
80-
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
81-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

ext/DynamicPPLMooncakeExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module DynamicPPLMooncakeExt
2+
3+
using DynamicPPL: DynamicPPL, istrans
4+
using Mooncake: Mooncake
5+
6+
# This is purely an optimisation.
7+
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}
8+
9+
end # module

ext/DynamicPPLReverseDiffExt.jl

Lines changed: 0 additions & 26 deletions
This file was deleted.

src/logdensityfunction.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,19 @@ function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
144144
end
145145
# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
146146
LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
147+
148+
# This is important for performance -- one needs to provide `ADGradient` with a vector of
149+
# parameters, or DifferentiationInterface will not have sufficient information to e.g.
150+
# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate
151+
# a tape when using ReverseDiff.jl.
152+
function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction)
153+
x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params
154+
return LogDensityProblemsAD.ADgradient(ad, ℓ; x)
155+
end
156+
157+
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction)
158+
return _make_ad_gradient(ad, f)
159+
end
160+
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction)
161+
return _make_ad_gradient(ad, f)
162+
end

src/threadsafe.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,12 @@ function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<:
178178
return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns)
179179
end
180180

181+
vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo)
182+
vector_getrange(vi::ThreadSafeVarInfo, vn::VarName) = vector_getrange(vi.varinfo, vn)
183+
function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName})
184+
return vector_getranges(vi.varinfo, vns)
185+
end
186+
181187
function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler)
182188
return set_retained_vns_del_by_spl!(vi.varinfo, spl)
183189
end

src/varinfo.jl

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,15 @@ function VarInfo(
208208
end
209209
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)
210210

211+
"""
212+
vector_length(varinfo::VarInfo)
213+
214+
Return the length of the vector representation of `varinfo`.
215+
"""
216+
vector_length(varinfo::VarInfo) = length(varinfo.metadata)
217+
vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata)
218+
vector_length(md::Metadata) = sum(length, md.ranges)
219+
211220
unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x)
212221

213222
# TODO: deprecate.
@@ -632,7 +641,72 @@ setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range
632641
Return the indices of `vns` in the metadata of `vi` corresponding to `vn`.
633642
"""
634643
function getranges(vi::VarInfo, vns::Vector{<:VarName})
635-
return mapreduce(vn -> getrange(vi, vn), vcat, vns; init=Int[])
644+
return map(Base.Fix1(getrange, vi), vns)
645+
end
646+
647+
"""
648+
vector_getrange(varinfo::VarInfo, varname::VarName)
649+
650+
Return the range corresponding to `varname` in the vector representation of `varinfo`.
651+
"""
652+
vector_getrange(vi::VarInfo, vn::VarName) = getrange(vi.metadata, vn)
653+
function vector_getrange(vi::TypedVarInfo, vn::VarName)
654+
offset = 0
655+
for md in values(vi.metadata)
656+
# First, we need to check if `vn` is in `md`.
657+
# In this case, we can just return the corresponding range + offset.
658+
haskey(md, vn) && return getrange(md, vn) .+ offset
659+
# Otherwise, we need to get the cumulative length of the ranges in `md`
660+
# and add it to the offset.
661+
offset += sum(length, md.ranges)
662+
end
663+
# If we reach this point, `vn` is not in `vi.metadata`.
664+
throw(KeyError(vn))
665+
end
666+
667+
"""
668+
vector_getranges(varinfo::VarInfo, varnames::Vector{<:VarName})
669+
670+
Return the range corresponding to `varname` in the vector representation of `varinfo`.
671+
"""
672+
function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName})
673+
return map(Base.Fix1(vector_getrange, varinfo), varname)
674+
end
675+
# Specialized version for `TypedVarInfo`.
676+
function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName})
677+
# TODO: Does it help if we _don't_ convert to a vector here?
678+
metadatas = collect(values(varinfo.metadata))
679+
# Extract the offsets.
680+
offsets = cumsum(map(vector_length, metadatas))
681+
# Extract the ranges from each metadata.
682+
ranges = Vector{UnitRange{Int}}(undef, length(vns))
683+
# Need to keep track of which ones we've seen.
684+
not_seen = fill(true, length(vns))
685+
for (i, metadata) in enumerate(metadatas)
686+
vns_metadata = filter(Base.Fix1(haskey, metadata), vns)
687+
# If none of the variables exist in the metadata, we return an empty array.
688+
isempty(vns_metadata) && continue
689+
# Otherwise, we extract the ranges.
690+
offset = i == 1 ? 0 : offsets[i - 1]
691+
for vn in vns_metadata
692+
r_vn = getrange(metadata, vn)
693+
# Get the index, so we return in the same order as `vns`.
694+
# NOTE: There might be duplicates in `vns`, so we need to handle that.
695+
indices = findall(==(vn), vns)
696+
for idx in indices
697+
not_seen[idx] = false
698+
ranges[idx] = r_vn .+ offset
699+
end
700+
end
701+
end
702+
# Raise key error if any of the variables were not found.
703+
if any(not_seen)
704+
inds = findall(not_seen)
705+
# Just use a `convert` to get the same type as the input; don't want to confuse by overly
706+
# specilizing the types in the error message.
707+
throw(KeyError(convert(typeof(vns), vns[inds])))
708+
end
709+
return ranges
636710
end
637711

638712
"""
@@ -1320,13 +1394,13 @@ end
13201394

13211395
function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f)
13221396
# TODO: Use inplace versions to avoid allocations
1323-
yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn))
1397+
yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(md, vn))
13241398
# Determine the new range.
1325-
start = first(getrange(vi, vn))
1399+
start = first(getrange(md, vn))
13261400
# NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`.
1327-
setrange!(vi, vn, start:(start + length(yvec) - 1))
1401+
setrange!(md, vn, start:(start + length(yvec) - 1))
13281402
# Set the new value.
1329-
setval!(vi, yvec, vn)
1403+
setval!(md, yvec, vn)
13301404
acclogp!!(vi, -logjac)
13311405
return vi
13321406
end

src/varnamedvector.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,8 @@ function replace_raw_storage(vnv::VarNamedVector, ::Val{space}, vals) where {spa
10361036
return replace_raw_storage(vnv, vals)
10371037
end
10381038

1039+
vector_length(vnv::VarNamedVector) = length(vnv.vals) - num_inactive(vnv)
1040+
10391041
"""
10401042
unflatten(vnv::VarNamedVector, vals::AbstractVector)
10411043

test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
66
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
77
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
88
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
9+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
910
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1011
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1112
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
@@ -18,6 +19,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1819
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
1920
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2021
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
22+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2123
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2224
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2325
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
@@ -35,6 +37,7 @@ Accessors = "0.1"
3537
Bijectors = "0.15.1"
3638
Combinatorics = "1"
3739
Compat = "4.3.0"
40+
DifferentiationInterface = "0.6"
3841
Distributions = "0.25"
3942
DistributionsAD = "0.6.3"
4043
Documenter = "1"
@@ -44,6 +47,7 @@ LogDensityProblems = "2"
4447
LogDensityProblemsAD = "1.7.0"
4548
MCMCChains = "6.0.4"
4649
MacroTools = "0.5.6"
50+
Mooncake = "0.4.59"
4751
ReverseDiff = "1"
4852
StableRNGs = "1"
4953
Tracker = "0.2.23"

test/ad.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "AD: ForwardDiff and ReverseDiff" begin
1+
@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin
22
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
33
f = DynamicPPL.LogDensityFunction(m)
44
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
@@ -17,11 +17,20 @@
1717
θ = convert(Vector{Float64}, varinfo[:])
1818
logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
1919

20-
@testset "ReverseDiff with compile=$compile" for compile in (false, true)
21-
adtype = ADTypes.AutoReverseDiff(; compile=compile)
22-
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
23-
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
24-
@test grad ref_grad
20+
@testset "$adtype" for adtype in [
21+
ADTypes.AutoReverseDiff(; compile=false),
22+
ADTypes.AutoReverseDiff(; compile=true),
23+
ADTypes.AutoMooncake(; config=nothing),
24+
]
25+
# Mooncake can't currently handle something that is going on in
26+
# SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now.
27+
if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo
28+
@test_broken 1 == 0
29+
else
30+
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
31+
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
32+
@test grad ref_grad
33+
end
2534
end
2635
end
2736
end

test/ext/DynamicPPLMooncakeExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
@testset "DynamicPPLMooncakeExt" begin
2+
Mooncake.TestUtils.test_rule(
3+
StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true
4+
)
5+
end

0 commit comments

Comments
 (0)