Skip to content

Commit 0289358

Browse files
torfjeldegithub-actions[bot]yebai
authored
link and invlink should correctly work with Selector and thus Gibbs (#542)
* link and invlink should correctly work with Selector etc. * more fixes to link and invlink * formatting * added simple tests for usage of selectors * bumped patch version * fied typos * added missing _getvns_link for UntypedVarInfo * simplify `_getvns_link` for TypedVarInfo * Update src/varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added Compat as dep so we can make use of certain features, e.g. Returns * forgot using Compat * Apply suggestions from code review Co-authored-by: Hong Ge <[email protected]> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <[email protected]>
1 parent d204fcb commit 0289358

File tree

4 files changed

+112
-24
lines changed

4 files changed

+112
-24
lines changed

Project.toml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.23.18"
3+
version = "0.23.19"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
77
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
88
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
99
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
1010
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
11+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1112
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1213
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1314
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -21,13 +22,20 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2122
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2223
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2324

25+
[weakdeps]
26+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
27+
28+
[extensions]
29+
DynamicPPLMCMCChainsExt = ["MCMCChains"]
30+
2431
[compat]
2532
AbstractMCMC = "2, 3.0, 4"
2633
AbstractPPL = "0.6"
2734
BangBang = "0.3"
2835
Bijectors = "0.13"
2936
ChainRulesCore = "0.9.7, 0.10, 1"
3037
ConstructionBase = "1.5.4"
38+
Compat = "4"
3139
Distributions = "0.23.8, 0.24, 0.25"
3240
DocStringExtensions = "0.8, 0.9"
3341
LogDensityProblems = "2"
@@ -39,11 +47,5 @@ Setfield = "0.7.1, 0.8, 1"
3947
ZygoteRules = "0.2"
4048
julia = "1.6"
4149

42-
[extensions]
43-
DynamicPPLMCMCChainsExt = ["MCMCChains"]
44-
4550
[extras]
4651
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
47-
48-
[weakdeps]
49-
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module DynamicPPL
33
using AbstractMCMC: AbstractSampler, AbstractChains
44
using AbstractPPL
55
using Bijectors
6+
using Compat
67
using Distributions
78
using OrderedCollections: OrderedDict
89

src/varinfo.jl

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -902,33 +902,59 @@ function _inner_transform!(vi::VarInfo, vn::VarName, dist, f)
902902
return vi
903903
end
904904

905+
# HACK: We need `SampleFromPrior` to result in ALL values which are in need
906+
# of a transformation to be transformed. `_getvns` will by default return
907+
# an empty iterable for `SampleFromPrior`, so we need to override it here.
908+
# This is quite hacky, but seems safer than changing the behavior of `_getvns`.
909+
_getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl)
910+
_getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing
911+
function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior)
912+
return map(Returns(nothing), varinfo.metadata)
913+
end
914+
905915
function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model)
906-
return _link(varinfo)
916+
return _link(varinfo, spl)
907917
end
908918

909-
function _link(varinfo::UntypedVarInfo)
919+
function _link(varinfo::UntypedVarInfo, spl::AbstractSampler)
910920
varinfo = deepcopy(varinfo)
911921
return VarInfo(
912-
_link_metadata!(varinfo, varinfo.metadata),
922+
_link_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)),
913923
Base.Ref(getlogp(varinfo)),
914924
Ref(get_num_produce(varinfo)),
915925
)
916926
end
917927

918-
function _link(varinfo::TypedVarInfo)
928+
function _link(varinfo::TypedVarInfo, spl::AbstractSampler)
919929
varinfo = deepcopy(varinfo)
920-
md = map(Base.Fix1(_link_metadata!, varinfo), varinfo.metadata)
921-
# TODO: Update logp, etc.
930+
md = _link_metadata_namedtuple!(
931+
varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl))
932+
)
922933
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
923934
end
924935

925-
function _link_metadata!(varinfo::VarInfo, metadata::Metadata)
936+
@generated function _link_metadata_namedtuple!(
937+
varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space}
938+
) where {names,space}
939+
vals = Expr(:tuple)
940+
for f in names
941+
if inspace(f, space) || length(space) == 0
942+
push!(vals.args, :(_link_metadata!(varinfo, metadata.$f, vns.$f)))
943+
else
944+
push!(vals.args, :(metadata.$f))
945+
end
946+
end
947+
948+
return :(NamedTuple{$names}($vals))
949+
end
950+
function _link_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns)
926951
vns = metadata.vns
927952

928953
# Construct the new transformed values, and keep track of their lengths.
929954
vals_new = map(vns) do vn
930955
# Return early if we're already in unconstrained space.
931-
if istrans(varinfo, vn)
956+
# HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
957+
if istrans(varinfo, vn) || (target_vns !== nothing && vn target_vns)
932958
return metadata.vals[getrange(metadata, vn)]
933959
end
934960

@@ -972,32 +998,49 @@ end
972998
function invlink(
973999
::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model
9741000
)
975-
return _invlink(varinfo)
1001+
return _invlink(varinfo, spl)
9761002
end
9771003

978-
function _invlink(varinfo::UntypedVarInfo)
1004+
function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler)
9791005
varinfo = deepcopy(varinfo)
9801006
return VarInfo(
981-
_invlink_metadata!(varinfo, varinfo.metadata),
1007+
_invlink_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)),
9821008
Base.Ref(getlogp(varinfo)),
9831009
Ref(get_num_produce(varinfo)),
9841010
)
9851011
end
9861012

987-
function _invlink(varinfo::TypedVarInfo)
1013+
function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler)
9881014
varinfo = deepcopy(varinfo)
989-
md = map(Base.Fix1(_invlink_metadata!, varinfo), varinfo.metadata)
990-
# TODO: Update logp, etc.
1015+
md = _invlink_metadata_namedtuple!(
1016+
varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl))
1017+
)
9911018
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
9921019
end
9931020

994-
function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata)
1021+
@generated function _invlink_metadata_namedtuple!(
1022+
varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space}
1023+
) where {names,space}
1024+
vals = Expr(:tuple)
1025+
for f in names
1026+
if inspace(f, space) || length(space) == 0
1027+
push!(vals.args, :(_invlink_metadata!(varinfo, metadata.$f, vns.$f)))
1028+
else
1029+
push!(vals.args, :(metadata.$f))
1030+
end
1031+
end
1032+
1033+
return :(NamedTuple{$names}($vals))
1034+
end
1035+
function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns)
9951036
vns = metadata.vns
9961037

9971038
# Construct the new transformed values, and keep track of their lengths.
9981039
vals_new = map(vns) do vn
999-
# Return early if we're already in constrained space.
1000-
if !istrans(varinfo, vn)
1040+
# Return early if we're already in constrained space OR if we're not
1041+
# supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler.
1042+
# HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
1043+
if !istrans(varinfo, vn) || (target_vns !== nothing && vn target_vns)
10011044
return metadata.vals[getrange(metadata, vn)]
10021045
end
10031046

test/varinfo.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# A simple "algorithm" which only has `s` variables in its space.
2+
struct MySAlg end
3+
DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
4+
15
@testset "varinfo.jl" begin
26
@testset "TypedVarInfo" begin
37
@model gdemo(x, y) = begin
@@ -421,4 +425,42 @@
421425
end
422426
end
423427
end
428+
429+
@testset "VarInfo with selectors" begin
430+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
431+
varinfo = VarInfo(model)
432+
selector = DynamicPPL.Selector()
433+
spl = Sampler(MySAlg(), model, selector)
434+
435+
vns = DynamicPPL.TestUtils.varnames(model)
436+
vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns)
437+
vns_m = filter(vn -> DynamicPPL.getsym(vn) === :m, vns)
438+
for vn in vns_s
439+
DynamicPPL.updategid!(varinfo, vn, spl)
440+
end
441+
442+
# Should only get the variables subsumed by `@varname(s)`.
443+
@test varinfo[spl] ==
444+
mapreduce(Base.Fix1(DynamicPPL.getval, varinfo), vcat, vns_s)
445+
446+
# `link`
447+
varinfo_linked = DynamicPPL.link(varinfo, spl, model)
448+
# `s` variables should be linked
449+
@test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s)
450+
# `m` variables should NOT be linked
451+
@test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m)
452+
# And `varinfo` should be unchanged
453+
@test all(!Base.Fix1(DynamicPPL.istrans, varinfo), vns)
454+
455+
# `invlink`
456+
varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, spl, model)
457+
# `s` variables should no longer be linked
458+
@test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_s)
459+
# `m` variables should still not be linked
460+
@test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_m)
461+
# And `varinfo_linked` should be unchanged
462+
@test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s)
463+
@test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m)
464+
end
465+
end
424466
end

0 commit comments

Comments
 (0)