From 4dc2a7237954c2185fffe823a15e69ea300d57bb Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 16 Jan 2025 15:24:57 +0000 Subject: [PATCH 01/40] Remove selector stuff from varinfo tests --- test/varinfo.jl | 259 ++++++++---------------------------------------- 1 file changed, 44 insertions(+), 215 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 9a55cffb9..c6fa78658 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,8 +1,3 @@ -# Dummy algorithm for testing -# Invoke with: DynamicPPL.Sampler(MyAlg{(:x, :y)}(), ...) -struct MyAlg{space} end -DynamicPPL.getspace(::DynamicPPL.Sampler{MyAlg{space}}) where {space} = space - function check_varinfo_keys(varinfo, vns) if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, @@ -19,16 +14,13 @@ function check_varinfo_keys(varinfo, vns) end end -function randr( - vi::DynamicPPL.VarInfo, - vn::VarName, - dist::Distribution, - spl::DynamicPPL.Sampler, - count::Bool=false, -) +""" +Return the value of `vn` in `vi`. If one doesn't exist, sample and set it. +""" +function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) if !haskey(vi, vn) r = rand(dist) - push!!(vi, vn, r, dist, spl) + push!!(vi, vn, r, dist) r elseif DynamicPPL.is_flagged(vi, vn, "del") DynamicPPL.unset_flag!(vi, vn, "del") @@ -37,8 +29,6 @@ function randr( DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) r else - count && checkindex(vn, vi, spl) - DynamicPPL.updategid!(vi, vn, spl) vi[vn] end end @@ -66,7 +56,6 @@ end tind = fmeta.idcs[vn] @test meta.dists[ind] == fmeta.dists[tind] @test meta.orders[ind] == fmeta.orders[tind] - @test meta.gids[ind] == fmeta.gids[tind] for flag in keys(meta.flags) @test meta.flags[flag][ind] == fmeta.flags[flag][tind] end @@ -89,22 +78,6 @@ end vn2 = @varname x[1][2] @test vn2 == vn1 @test hash(vn2) == hash(vn1) - @test inspace(vn1, (:x,)) - - # Tests for `inspace` - space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) - - @test inspace(@varname(x), space) - @test inspace(@varname(y), space) - @test inspace(@varname(x[1]), space) - @test inspace(@varname(z[1][1]), space) - @test inspace(@varname(z[1][:]), space) - @test inspace(@varname(z[1][2:3:10]), space) - @test inspace(@varname(M[[2, 3], 1]), space) - @test_throws ErrorException inspace(@varname(M[:, 1:4]), space) - @test inspace(@varname(M[1, [2, 4, 6]]), space) - @test !inspace(@varname(z[2]), space) - @test !inspace(@varname(z), space) function test_base!!(vi_original) vi = empty!!(vi_original) @@ -114,38 +87,31 @@ end vn = @varname x dist = Normal(0, 1) r = rand(dist) - gid = DynamicPPL.Selector() @test isempty(vi) @test ~haskey(vi, vn) @test !(vn in keys(vi)) - vi = push!!(vi, vn, r, dist, gid) + vi = push!!(vi, vn, r, dist) @test ~isempty(vi) @test haskey(vi, vn) @test vn in keys(vi) @test length(vi[vn]) == 1 - @test length(vi[SampleFromPrior()]) == 1 - @test vi[vn] == r - @test vi[SampleFromPrior()][1] == r vi = DynamicPPL.setindex!!(vi, 2 * r, vn) @test vi[vn] == 2 * r - @test vi[SampleFromPrior()][1] == 2 * r - vi = DynamicPPL.setindex!!(vi, [3 * r], SampleFromPrior()) - @test vi[vn] == 3 * r - @test vi[SampleFromPrior()][1] == 3 * r # TODO(mhauru) Implement these functions for other VarInfo types too. if vi isa DynamicPPL.VectorVarInfo delete!(vi, vn) @test isempty(vi) - vi = push!!(vi, vn, r, dist, gid) + vi = push!!(vi, vn, r, dist) end vi = empty!!(vi) @test isempty(vi) - return push!!(vi, vn, r, dist, gid) + vi = push!!(vi, vn, r, dist) + @test ~isempty(vi) end vi = VarInfo() @@ -182,9 +148,8 @@ end vn_x = @varname x dist = Normal(0, 1) r = rand(dist) - gid = Selector() - push!!(vi, vn_x, r, dist, gid) + push!!(vi, vn_x, r, dist) # del is set by default @test !is_flagged(vi, vn_x, "del") @@ -204,35 +169,13 @@ end vn_x = @varname x vn_y = @varname y untyped_vi = VarInfo() - untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector()) + untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1)) typed_vi = TypedVarInfo(untyped_vi) - typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector()) + typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1)) @test typed_vi[vn_x] == 1.0 @test typed_vi[vn_y] == 2.0 end - @testset "setgid!" begin - vi = VarInfo(DynamicPPL.Metadata()) - meta = vi.metadata - vn = @varname x - dist = Normal(0, 1) - r = rand(dist) - gid1 = Selector() - gid2 = Selector(2, :HMC) - - push!!(vi, vn, r, dist, gid1) - @test meta.gids[meta.idcs[vn]] == Set([gid1]) - setgid!(vi, gid2, vn) - @test meta.gids[meta.idcs[vn]] == Set([gid1, gid2]) - - vi = empty!!(TypedVarInfo(vi)) - meta = vi.metadata - push!!(vi, vn, r, dist, gid1) - @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1]) - setgid!(vi, gid2, vn) - @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2]) - end - @testset "setval! & setval_and_resample!" begin @model function testmodel(x) n = length(x) @@ -397,10 +340,9 @@ end """ function test_setval!(model, chain; sample_idx=1, chain_idx=1) var_info = VarInfo(model) - spl = SampleFromPrior() - θ_old = var_info[spl] + θ_old = var_info[:] DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[spl] + θ_new = var_info[:] @test θ_old != θ_new vals = DynamicPPL.values_as(var_info, OrderedDict) iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) @@ -448,9 +390,9 @@ end # Check that linking and invlinking set the `trans` flag accordingly v = copy(meta.vals) - link!!(vi, model) + vi = link!!(vi, model) @test all(x -> istrans(vi, x), meta.vns) - invlink!!(vi, model) + vi = invlink!!(vi, model) @test all(x -> !istrans(vi, x), meta.vns) @test meta.vals ≈ v atol = 1e-10 @@ -461,21 +403,20 @@ end @test all(x -> !istrans(vi, x), meta.m.vns) v_s = copy(meta.s.vals) v_m = copy(meta.m.vals) - link!!(vi, model) + vi = link!!(vi, model) @test all(x -> istrans(vi, x), meta.s.vns) @test all(x -> istrans(vi, x), meta.m.vns) - invlink!!(vi, model) + vi = invlink!!(vi, model) @test all(x -> !istrans(vi, x), meta.s.vns) @test all(x -> !istrans(vi, x), meta.m.vns) @test meta.s.vals ≈ v_s atol = 1e-10 @test meta.m.vals ≈ v_m atol = 1e-10 # Transform only one variable (`s`) but not the others (`m`) - spl = DynamicPPL.Sampler(MyAlg{(:s,)}(), model) - link!!(vi, spl, model) + vi = link!!(vi, @varname(s), model) @test all(x -> istrans(vi, x), meta.s.vns) @test all(x -> !istrans(vi, x), meta.m.vns) - invlink!!(vi, spl, model) + vi = invlink!!(vi, @varname(s), model) @test all(x -> !istrans(vi, x), meta.s.vns) @test all(x -> !istrans(vi, x), meta.m.vns) @test meta.s.vals ≈ v_s atol = 1e-10 @@ -856,62 +797,6 @@ end @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] @test DynamicPPL.istrans(varinfo_merged, @varname(x)) end - - # The below used to error, testing to avoid regression. - @testset "merge gids" begin - gidset_left = Set([Selector(1)]) - vi_left = VarInfo() - vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left) - gidset_right = Set([Selector(2)]) - vi_right = VarInfo() - vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right) - varinfo_merged = merge(vi_left, vi_right) - @test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left - @test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right - end - end - - @testset "VarInfo with selectors" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - varinfo = VarInfo( - model, - DynamicPPL.SampleFromPrior(), - DynamicPPL.DefaultContext(), - DynamicPPL.Metadata(), - ) - selector = DynamicPPL.Selector() - spl = Sampler(MyAlg{(:s,)}(), model, selector) - - vns = DynamicPPL.TestUtils.varnames(model) - vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns) - vns_m = filter(vn -> DynamicPPL.getsym(vn) === :m, vns) - for vn in vns_s - DynamicPPL.updategid!(varinfo, vn, spl) - end - - # Should only get the variables subsumed by `@varname(s)`. - @test varinfo[spl] == - mapreduce(Base.Fix1(DynamicPPL.getindex_internal, varinfo), vcat, vns_s) - - # `link` - varinfo_linked = DynamicPPL.link(varinfo, spl, model) - # `s` variables should be linked - @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) - # `m` variables should NOT be linked - @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) - # And `varinfo` should be unchanged - @test all(!Base.Fix1(DynamicPPL.istrans, varinfo), vns) - - # `invlink` - varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, spl, model) - # `s` variables should no longer be linked - @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_s) - # `m` variables should still not be linked - @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_m) - # And `varinfo_linked` should be unchanged - @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) - @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) - end end @testset "sampling from linked varinfo" begin @@ -1014,25 +899,22 @@ end vi = DynamicPPL.VarInfo() dists = [Categorical([0.7, 0.3]), Normal()] - spl1 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) - spl2 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) - # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_b, dists[2], spl2) - randr(vi, vn_z2, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_b, dists[2]) + randr(vi, vn_z2, dists[1]) + randr(vi, vn_a2, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_z3, dists[1]) @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] @test DynamicPPL.get_num_produce(vi) == 3 DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @test DynamicPPL.is_flagged(vi, vn_z2, "del") @@ -1040,13 +922,13 @@ end @test DynamicPPL.is_flagged(vi, vn_z3, "del") DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_z2, dists[1]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_z3, dists[1]) + randr(vi, vn_a2, dists[2]) @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] @test DynamicPPL.get_num_produce(vi) == 3 @@ -1054,21 +936,21 @@ end # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_b, dists[2], spl2) - randr(vi, vn_z2, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_b, dists[2]) + randr(vi, vn_z2, dists[1]) + randr(vi, vn_a2, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_z3, dists[1]) @test vi.metadata.z.orders == [1, 2, 3] @test vi.metadata.a.orders == [1, 2] @test vi.metadata.b.orders == [2] @test DynamicPPL.get_num_produce(vi) == 3 DynamicPPL.reset_num_produce!(vi) - DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @test DynamicPPL.is_flagged(vi, vn_z2, "del") @@ -1076,69 +958,16 @@ end @test DynamicPPL.is_flagged(vi, vn_z3, "del") DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) + randr(vi, vn_z1, dists[1]) + randr(vi, vn_a1, dists[2]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_z2, dists[1]) DynamicPPL.increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) + randr(vi, vn_z3, dists[1]) + randr(vi, vn_a2, dists[2]) @test vi.metadata.z.orders == [1, 2, 3] @test vi.metadata.a.orders == [1, 3] @test vi.metadata.b.orders == [2] @test DynamicPPL.get_num_produce(vi) == 3 end - - @testset "varinfo ranges" begin - @model empty_model() = x = 1 - dists = [Normal(0, 1), MvNormal(zeros(2), I), Wishart(7, [1 0.5; 0.5 1])] - - function test_varinfo!(vi) - spl2 = DynamicPPL.Sampler(MyAlg{(:w, :u)}(), empty_model()) - vn_w = @varname w - randr(vi, vn_w, dists[1], spl2, true) - - vn_x = @varname x - vn_y = @varname y - vn_z = @varname z - vns = [vn_x, vn_y, vn_z] - - spl1 = DynamicPPL.Sampler(MyAlg{(:x, :y, :z)}(), empty_model()) - for i in 1:3 - r = randr(vi, vns[i], dists[i], spl1, false) - val = vi[vns[i]] - @test sum(val - r) <= 1e-9 - end - - idcs = DynamicPPL._getidcs(vi, spl1) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 3 - else - @test length(idcs) == 3 - end - @test length(vi[spl1]) == 7 - - idcs = DynamicPPL._getidcs(vi, spl2) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 1 - else - @test length(idcs) == 1 - end - @test length(vi[spl2]) == 1 - - vn_u = @varname u - randr(vi, vn_u, dists[1], spl2, true) - - idcs = DynamicPPL._getidcs(vi, spl2) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 2 - else - @test length(idcs) == 2 - end - @test length(vi[spl2]) == 2 - end - vi = DynamicPPL.VarInfo() - test_varinfo!(vi) - test_varinfo!(empty!!(DynamicPPL.TypedVarInfo(vi))) - end end From 9b492a33b7d6b007b446a4ee2e8e83f7c17485cb Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 16 Jan 2025 15:49:03 +0000 Subject: [PATCH 02/40] Implement link and invlink for varnames rather than samplers --- src/abstract_varinfo.jl | 48 ++++++++++++++--- src/threadsafe.jl | 56 ++++++++++++++----- src/transforming.jl | 22 +++++--- src/varinfo.jl | 116 +++++++++++++++++++++++++++++++++++----- 4 files changed, 202 insertions(+), 40 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 3f513d71d..a3a3c9c78 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -537,8 +537,17 @@ If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variabl """ function settrans!! end +# TODO(mhauru) The fact that we need to to define this type is a sign that the link/invlink +# API is hard to understand. To be fixed by removing samplers from it. +SamplerOrVarName = Union{ + AbstractSampler,VarName,NTuple{N,VarName} where N,AbstractVector{<:VarName} +} + """ link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their linked space, using the transformation `t`, @@ -552,13 +561,19 @@ link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model) function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return link!!(t, vi, SampleFromPrior(), model) end -function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) +function link!!(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) # Use `default_transformation` to decide which transformation to use if none is specified. - return link!!(default_transformation(model, vi), vi, spl, model) + return link!!(default_transformation(model, vi), vi, spl_or_vn, model) +end +function link!!(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) + return link!!(t, deepcopy(vi), (vn,), model) end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. @@ -571,13 +586,19 @@ link(vi::AbstractVarInfo, model::Model) = link(vi, SampleFromPrior(), model) function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return link(t, deepcopy(vi), SampleFromPrior(), model) end -function link(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) +function link(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) # Use `default_transformation` to decide which transformation to use if none is specified. - return link(default_transformation(model, vi), deepcopy(vi), spl, model) + return link(default_transformation(model, vi), deepcopy(vi), spl_or_vn, model) +end +function link(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) + return link(t, deepcopy(vi), (vn,), model) end """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their constrained space, using the (inverse of) @@ -591,9 +612,14 @@ invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return invlink!!(t, vi, SampleFromPrior(), model) end -function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) +function invlink!!(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) # Here we extract the `transformation` from `vi` rather than using the default one. - return invlink!!(transformation(vi), vi, spl, model) + return invlink!!(transformation(vi), vi, spl_or_vn, model) +end +function invlink!!( + t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model +) + return invlink!!(t, vi, (vn,), model) end # Vector-based ones. @@ -629,6 +655,9 @@ end """ invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of) @@ -642,8 +671,11 @@ invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), mode function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return invlink(t, vi, SampleFromPrior(), model) end -function invlink(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) - return invlink(transformation(vi), vi, spl, model) +function invlink(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) + return invlink(transformation(vi), vi, spl_or_vn, model) +end +function invlink(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) + return invlink(t, vi, (vn,), model) end """ diff --git a/src/threadsafe.jl b/src/threadsafe.jl index cedb0efad..bb60f7bcf 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -81,28 +81,44 @@ haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) +SamplerOrVarNameIterator = Union{ + AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} +} + function link!!( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + t::AbstractTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, spl_or_vn, model) end function invlink!!( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + t::AbstractTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, spl_or_vn, model) end function link( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + t::AbstractTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = link(t, vi.varinfo, spl_or_vn, model) end function invlink( - t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + t::AbstractTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, spl_or_vn, model) end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. @@ -110,13 +126,19 @@ end # consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates # to define `getlogp(vi)`. function link!!( - t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end function invlink!!( - ::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + ::DynamicTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), @@ -125,15 +147,21 @@ function invlink!!( end function link( - t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return link!!(t, deepcopy(vi), spl, model) + return link!!(t, deepcopy(vi), spl_or_vn, model) end function invlink( - t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, + vi::ThreadSafeVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return invlink!!(t, deepcopy(vi), spl, model) + return invlink!!(t, deepcopy(vi), spl_or_vn, model) end function maybe_invlink_before_eval!!( diff --git a/src/transforming.jl b/src/transforming.jl index 1f6c55e24..d0f3774c5 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -91,14 +91,18 @@ function dot_tilde_assume( return r, lp, vi end +SamplerOrVarNameIterator = Union{ + AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} +} + function link!!( - t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, vi::AbstractVarInfo, ::SamplerOrVarNameIterator, model::Model ) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end function invlink!!( - ::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model + ::DynamicTransformation, vi::AbstractVarInfo, ::SamplerOrVarNameIterator, model::Model ) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), @@ -107,13 +111,19 @@ function invlink!!( end function link( - t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, + vi::AbstractVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return link!!(t, deepcopy(vi), spl, model) + return link!!(t, deepcopy(vi), spl_or_vn, model) end function invlink( - t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, + vi::AbstractVarInfo, + spl_or_vn::SamplerOrVarNameIterator, + model::Model, ) - return invlink!!(t, deepcopy(vi), spl, model) + return invlink!!(t, deepcopy(vi), spl_or_vn, model) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 3ebb505e0..d9c1247fc 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1201,29 +1201,40 @@ _isempty(vnv::VarNamedVector) = isempty(vnv) return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...) end +SamplerOrVarNameIterator = Union{ + AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} +} + # X -> R for all variables associated with given sampler -function link!!(t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) +function link!!( + t::DynamicTransformation, vi::VarInfo, spl_or_vn::SamplerOrVarNameIterator, model::Model +) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return link(t, vi, spl, model) + has_varnamedvector(vi) && return link(t, vi, spl_or_vn, model) # Call `_link!` instead of `link!` to avoid deprecation warning. - _link!(vi, spl) + _link!(vi, spl_or_vn) return vi end function link!!( t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + spl_or_vn::SamplerOrVarNameIterator, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl_or_vn, model) end function _link!(vi::UntypedVarInfo, spl::AbstractSampler) + return _link!(vi, _getvns(vi, spl)) +end + +function _link!( + vi::UntypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} +) # TODO: Change to a lazy iterator over `vns` - vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) for vn in vns f = internal_to_linked_internal_transform(vi, vn) @@ -1234,6 +1245,7 @@ function _link!(vi::UntypedVarInfo, spl::AbstractSampler) @warn("[DynamicPPL] attempt to link a linked vi") end end + function _link!(vi::TypedVarInfo, spl::AbstractSampler) return _link!(vi, spl, Val(getspace(spl))) end @@ -1268,26 +1280,70 @@ end return expr end +""" + filter_subsumed(vns1, vns2) + +Return the subset of `vns2` that are subsumed by any variable in `vns1`. +""" +function filter_subsumed(vns1, vns2) + return filter(x -> any(subsumes(y, x) for y in vns1), vns2) +end + +function _link!( + vi::TypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} +) + return _link!(vi.metadata, vi, vns) +end +@generated function _link!( + metadata::NamedTuple{names}, + vi, + vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}}, +) where {names,space} + expr = Expr(:block) + for f in names + push!( + expr.args, + quote + f_vns = vi.metadata.$f.vns + f_vns = filter_subsumed(vns, f_vns) + if !isempty(f_vns) + if !istrans(vi, f_vns[1]) + # Iterate over all `f_vns` and transform + for vn in f_vns + f = internal_to_linked_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) + settrans!!(vi, true, vn) + end + else + @warn("[DynamicPPL] attempt to link a linked vi") + end + end + end, + ) + end + return expr +end + # R -> X for all variables associated with given sampler function invlink!!( - t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model + t::DynamicTransformation, vi::VarInfo, spl_or_vn::SamplerOrVarNameIterator, model::Model ) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return invlink(t, vi, spl, model) + has_varnamedvector(vi) && return invlink(t, vi, spl_or_vn, model) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. - _invlink!(vi, spl) + _invlink!(vi, spl_or_vn) return vi end function invlink!!( ::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + spl_or_vn::SamplerOrVarNameIterator, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, spl_or_vn, model) end function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model) @@ -1299,7 +1355,11 @@ function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, mode end function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) - vns = _getvns(vi, spl) + return _invlink!(vi, _getvns(vi, spl)) +end +function _invlink!( + vi::UntypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} +) if istrans(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) @@ -1310,6 +1370,7 @@ function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end + function _invlink!(vi::TypedVarInfo, spl::AbstractSampler) return _invlink!(vi, spl, Val(getspace(spl))) end @@ -1344,6 +1405,37 @@ end return expr end +function _invlink!( + vi::TypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} +) + return _invlink!(vi.metadata, vi, vns) +end +@generated function _invlink!(metadata::NamedTuple{names}, vi, vns) where {names} + expr = Expr(:block) + for f in names + push!( + expr.args, + quote + f_vns = vi.metadata.$f.vns + f_vns = filter_subsumed(vns, f_vns) + if !isempty(f_vns) + if istrans(vi, f_vns[1]) + # Iterate over all `f_vns` and transform + for vn in f_vns + f = linked_internal_to_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) + settrans!!(vi, false, vn) + end + else + @warn("[DynamicPPL] attempt to invlink an invlinked vi") + end + end + end, + ) + end + return expr +end + function _inner_transform!(vi::VarInfo, vn::VarName, f) return _inner_transform!(getmetadata(vi, vn), vi, vn, f) end From b508f08a6faef408d409d52859dd55efb4ce80f2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 16 Jan 2025 15:49:49 +0000 Subject: [PATCH 03/40] Replace set_retained_vns_del_by_spl! with set_retained_vns_del! --- src/varinfo.jl | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index d9c1247fc..8d68f2b86 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -842,6 +842,9 @@ Returns a tuple of the unique symbols of random variables sampled in `vi`. syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols syms(vi::TypedVarInfo) = keys(vi.metadata) +_getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs) +_getidcs(vi::TypedVarInfo) = _getidcs(vi.metadata) + # Get all indices of variables belonging to SampleFromPrior: # if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to # the SampleFromPrior sampler @@ -2109,37 +2112,36 @@ function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String, ignorable::Bo end """ - set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler) + set_retained_vns_del!(vi::VarInfo) Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `true`. """ -function set_retained_vns_del_by_spl!(vi::UntypedVarInfo, spl::Sampler) - # Get the indices of `vns` that belong to `spl` as a vector - gidcs = _getidcs(vi, spl) +function set_retained_vns_del!(vi::UntypedVarInfo) + idcs = _getidcs(vi) if get_num_produce(vi) == 0 - for i in length(gidcs):-1:1 - vi.metadata.flags["del"][gidcs[i]] = true + for i in length(idcs):-1:1 + vi.metadata.flags["del"][idcs[i]] = true end else for i in 1:length(vi.orders) - if i in gidcs && vi.orders[i] > get_num_produce(vi) + if i in idcs && vi.orders[i] > get_num_produce(vi) vi.metadata.flags["del"][i] = true end end end return nothing end -function set_retained_vns_del_by_spl!(vi::TypedVarInfo, spl::Sampler) +function set_retained_vns_del!(vi::TypedVarInfo) # Get the indices of `vns` that belong to `spl` as a NamedTuple, one entry for each symbol - gidcs = _getidcs(vi, spl) - return _set_retained_vns_del_by_spl!(vi.metadata, gidcs, get_num_produce(vi)) + idcs = _getidcs(vi) + return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi)) end -@generated function _set_retained_vns_del_by_spl!( - metadata, gidcs::NamedTuple{names}, num_produce +@generated function _set_retained_vns_del!( + metadata, idcs::NamedTuple{names}, num_produce ) where {names} expr = Expr(:block) for f in names - f_gidcs = :(gidcs.$f) + f_idcs = :(idcs.$f) f_orders = :(metadata.$f.orders) f_flags = :(metadata.$f.flags) push!( @@ -2147,12 +2149,12 @@ end quote # Set the flag for variables with symbol `f` if num_produce == 0 - for i in length($f_gidcs):-1:1 - $f_flags["del"][$f_gidcs[i]] = true + for i in length($f_idcs):-1:1 + $f_flags["del"][$f_idcs[i]] = true end else for i in 1:length($f_orders) - if i in $f_gidcs && $f_orders[i] > num_produce + if i in $f_idcs && $f_orders[i] > num_produce $f_flags["del"][i] = true end end From b8880d1d2169ca4eaf4612635d4b26cfa0bb08fc Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 16 Jan 2025 17:38:02 +0000 Subject: [PATCH 04/40] Make linking tests more extensive --- test/varinfo.jl | 56 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index c6fa78658..fd1c9a2e9 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -374,13 +374,21 @@ end end @testset "link!! and invlink!!" begin - @model gdemo(x, y) = begin + @model gdemo(a, b, ::Type{T}=Float64) where {T} = begin s ~ InverseGamma(2, 3) m ~ Uniform(0, 2) - x ~ Normal(m, sqrt(s)) - y ~ Normal(m, sqrt(s)) + x = Vector{T}(undef, length(a)) + x .~ Normal(m, sqrt(s)) + y = Vector{T}(undef, length(a)) + for i in eachindex(y) + y[i] ~ Normal(m, sqrt(s)) + end + a .~ Normal(m, sqrt(s)) + for i in eachindex(b) + b[i] ~ Normal(x[i] * y[i], sqrt(s)) + end end - model = gdemo(1.0, 2.0) + model = gdemo([1.0, 1.5], [2.0, 2.5]) # Check that instantiating the model does not perform linking vi = VarInfo() @@ -399,10 +407,13 @@ end # Check that linking and invlinking preserves the values vi = TypedVarInfo(vi) meta = vi.metadata - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) v_s = copy(meta.s.vals) v_m = copy(meta.m.vals) + v_x = copy(meta.x.vals) + v_y = copy(meta.y.vals) + + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) vi = link!!(vi, model) @test all(x -> istrans(vi, x), meta.s.vns) @test all(x -> istrans(vi, x), meta.m.vns) @@ -412,15 +423,30 @@ end @test meta.s.vals ≈ v_s atol = 1e-10 @test meta.m.vals ≈ v_m atol = 1e-10 - # Transform only one variable (`s`) but not the others (`m`) - vi = link!!(vi, @varname(s), model) - @test all(x -> istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) - vi = invlink!!(vi, @varname(s), model) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) - @test meta.s.vals ≈ v_s atol = 1e-10 - @test meta.m.vals ≈ v_m atol = 1e-10 + # Transform only one variable + all_vns = vcat(meta.s.vns, meta.m.vns, meta.x.vns, meta.y.vns) + for vn in [ + @varname(s), + @varname(m), + @varname(x), + @varname(y), + @varname(x[2]), + @varname(y[2]) + ] + target_vns = filter(x -> subsumes(vn, x), all_vns) + other_vns = filter(x -> !subsumes(vn, x), all_vns) + @test !isempty(target_vns) + @test !isempty(other_vns) + vi = link!!(vi, vn, model) + @test all(x -> istrans(vi, x), target_vns) + @test all(x -> !istrans(vi, x), other_vns) + vi = invlink!!(vi, vn, model) + @test all(x -> !istrans(vi, x), all_vns) + @test meta.s.vals ≈ v_s atol = 1e-10 + @test meta.m.vals ≈ v_m atol = 1e-10 + @test meta.x.vals ≈ v_x atol = 1e-10 + @test meta.y.vals ≈ v_y atol = 1e-10 + end end @testset "istrans" begin From 99a8490631b10e9696f501d0555e38770b18128c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 22 Jan 2025 14:51:18 +0000 Subject: [PATCH 05/40] Remove sampler indexing from link methods (but not invlink) --- src/abstract_varinfo.jl | 48 ++++++---- src/simple_varinfo.jl | 4 +- src/threadsafe.jl | 27 ++---- src/transforming.jl | 10 +- src/varinfo.jl | 198 +++++++++++++++++++++++++--------------- 5 files changed, 170 insertions(+), 117 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index a3a3c9c78..26238c12e 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -548,7 +548,6 @@ SamplerOrVarName = Union{ link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their linked space, using the transformation `t`, mutating `vi` if possible. @@ -557,16 +556,25 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`invlink!!`](@ref). """ -link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model) -function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return link!!(t, vi, SampleFromPrior(), model) +# Use `default_transformation` to decide which transformation to use if none is specified. +function link!!(vi::AbstractVarInfo, model::Model) + return link!!(default_transformation(model, vi), vi, model) +end +function link!!(vi::AbstractVarInfo, vns, model::Model) + return link!!(default_transformation(model, vi), vi, vns, model) end -function link!!(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) - # Use `default_transformation` to decide which transformation to use if none is specified. - return link!!(default_transformation(model, vi), vi, spl_or_vn, model) +# If no variable names are provided, link all variables. +function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + vns = collect(keys(vi)) + # In case e.g. vns = Any[]. + if !(eltype(vns) <: VarName) + vns = collect(VarName, vns) + end + return link!!(t, vi, vns, model) end +# Wrap a single VarName in a singleton tuple. function link!!(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) - return link!!(t, deepcopy(vi), (vn,), model) + return link!!(t, vi, (vn,), model) end """ @@ -574,7 +582,6 @@ end link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. @@ -582,16 +589,25 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`invlink`](@ref). """ -link(vi::AbstractVarInfo, model::Model) = link(vi, SampleFromPrior(), model) -function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return link(t, deepcopy(vi), SampleFromPrior(), model) +# Use `default_transformation` to decide which transformation to use if none is specified. +function link(vi::AbstractVarInfo, model::Model) + return link(default_transformation(model, vi), vi, model) +end +function link(vi::AbstractVarInfo, vns, model::Model) + return link(default_transformation(model, vi), vi, vns, model) end -function link(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) - # Use `default_transformation` to decide which transformation to use if none is specified. - return link(default_transformation(model, vi), deepcopy(vi), spl_or_vn, model) +# If no variable names are provided, link all variables. +function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + vns = collect(keys(vi)) + # In case e.g. vns = Any[]. + if !(eltype(vns) <: VarName) + vns = collect(VarName, vns) + end + return link(t, vi, vns, model) end +# Wrap a single VarName in a singleton tuple. function link(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) - return link(t, deepcopy(vi), (vn,), model) + return link(t, vi, (vn,), model) end """ diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b6a84238e..6bb723b29 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -680,7 +680,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn function link!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - spl::AbstractSampler, + ::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}}, model::Model, ) # TODO: Make sure that `spl` is respected. @@ -695,7 +695,7 @@ end function invlink!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - spl::AbstractSampler, + ::AbstractSampler, model::Model, ) # TODO: Make sure that `spl` is respected. diff --git a/src/threadsafe.jl b/src/threadsafe.jl index bb60f7bcf..c5a77c3ef 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -84,14 +84,12 @@ islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl SamplerOrVarNameIterator = Union{ AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} } +VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},NamedTuple} function link!!( - t::AbstractTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model ) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, spl_or_vn, model) + return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, vns, model) end function invlink!!( @@ -104,12 +102,9 @@ function invlink!!( end function link( - t::AbstractTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model ) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, spl_or_vn, model) + return Accessors.@set vi.varinfo = link(t, vi.varinfo, vns, model) end function invlink( @@ -126,10 +121,7 @@ end # consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates # to define `getlogp(vi)`. function link!!( - t::DynamicTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::DynamicTransformation, vi::ThreadSafeVarInfo, ::VarNameCollection, model::Model ) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end @@ -147,12 +139,9 @@ function invlink!!( end function link( - t::DynamicTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::DynamicTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model ) - return link!!(t, deepcopy(vi), spl_or_vn, model) + return link!!(t, deepcopy(vi), vns, model) end function invlink( diff --git a/src/transforming.jl b/src/transforming.jl index d0f3774c5..6acaf787c 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -94,9 +94,10 @@ end SamplerOrVarNameIterator = Union{ AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} } +VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},NamedTuple} function link!!( - t::DynamicTransformation, vi::AbstractVarInfo, ::SamplerOrVarNameIterator, model::Model + t::DynamicTransformation, vi::AbstractVarInfo, ::VarNameCollection, model::Model ) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end @@ -111,12 +112,9 @@ function invlink!!( end function link( - t::DynamicTransformation, - vi::AbstractVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::DynamicTransformation, vi::AbstractVarInfo, vns::VarNameCollection, model::Model ) - return link!!(t, deepcopy(vi), spl_or_vn, model) + return link!!(t, deepcopy(vi), vns, model) end function invlink( diff --git a/src/varinfo.jl b/src/varinfo.jl index 8d68f2b86..4c4125ad8 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1207,35 +1207,37 @@ end SamplerOrVarNameIterator = Union{ AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} } +VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},NamedTuple} + +# Specialise link!! without varnames provided for TypedVarInfo. The usual version gets +# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, that helps +# keep the downstread calls to link!! type stable. +function link!!(t::AbstractTransformation, vi::TypedVarInfo, model::Model) + return link!!(t, vi, all_varnames_namedtuple(vi), model) +end # X -> R for all variables associated with given sampler -function link!!( - t::DynamicTransformation, vi::VarInfo, spl_or_vn::SamplerOrVarNameIterator, model::Model -) +function link!!(t::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return link(t, vi, spl_or_vn, model) + has_varnamedvector(vi) && return link(t, vi, vns, model) # Call `_link!` instead of `link!` to avoid deprecation warning. - _link!(vi, spl_or_vn) + _link!(vi, vns) return vi end function link!!( t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - spl_or_vn::SamplerOrVarNameIterator, + vns::VarNameCollection, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl_or_vn, model) -end - -function _link!(vi::UntypedVarInfo, spl::AbstractSampler) - return _link!(vi, _getvns(vi, spl)) + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end function _link!( - vi::UntypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} + vi::UntypedVarInfo, vns::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} ) # TODO: Change to a lazy iterator over `vns` if ~istrans(vi, vns[1]) @@ -1249,24 +1251,30 @@ function _link!( end end -function _link!(vi::TypedVarInfo, spl::AbstractSampler) - return _link!(vi, spl, Val(getspace(spl))) +""" + filter_subsumed(vns1, vns2) + +Return the subset of `vns2` that are subsumed by any variable in `vns1`. +""" +function filter_subsumed(vns1, vns2) + return filter(x -> any(subsumes(y, x) for y in vns1), vns2) end -function _link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) - vns = _getvns(vi, spl) - return _link!(vi.metadata, vi, vns, spaceval) + +function _link!(vi::TypedVarInfo, vns::VarNameCollection) + return _link!(vi.metadata, vi, vns) end @generated function _link!( - metadata::NamedTuple{names}, vi, vns, ::Val{space} -) where {names,space} + ::NamedTuple{names}, vi, vns::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} +) where {names} expr = Expr(:block) for f in names - if inspace(f, space) || length(space) == 0 - push!( - expr.args, - quote - f_vns = vi.metadata.$f.vns - if ~istrans(vi, f_vns[1]) + push!( + expr.args, + quote + f_vns = vi.metadata.$f.vns + f_vns = filter_subsumed(vns, f_vns) + if !isempty(f_vns) + if !istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns f = internal_to_linked_internal_transform(vi, vn) @@ -1276,39 +1284,26 @@ end else @warn("[DynamicPPL] attempt to link a linked vi") end - end, - ) - end + end + end, + ) end return expr end -""" - filter_subsumed(vns1, vns2) - -Return the subset of `vns2` that are subsumed by any variable in `vns1`. -""" -function filter_subsumed(vns1, vns2) - return filter(x -> any(subsumes(y, x) for y in vns1), vns2) -end - -function _link!( - vi::TypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} -) - return _link!(vi.metadata, vi, vns) -end @generated function _link!( - metadata::NamedTuple{names}, - vi, - vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}}, -) where {names,space} + ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} +) where {metadata_names,vns_names} expr = Expr(:block) - for f in names + for f in metadata_names + if !(f in vns_names) + continue + end push!( expr.args, quote f_vns = vi.metadata.$f.vns - f_vns = filter_subsumed(vns, f_vns) + f_vns = filter_subsumed(vns.$f, f_vns) if !isempty(f_vns) if !istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform @@ -1361,7 +1356,7 @@ function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) return _invlink!(vi, _getvns(vi, spl)) end function _invlink!( - vi::UntypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} + vi::UntypedVarInfo, vns::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} ) if istrans(vi, vns[1]) for vn in vns @@ -1382,7 +1377,7 @@ function _invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) return _invlink!(vi.metadata, vi, vns, spaceval) end @generated function _invlink!( - metadata::NamedTuple{names}, vi, vns, ::Val{space} + ::NamedTuple{names}, vi, vns, ::Val{space} ) where {names,space} expr = Expr(:block) for f in names @@ -1408,12 +1403,10 @@ end return expr end -function _invlink!( - vi::TypedVarInfo, vns::Union{AbstractVector{<:VarName},NTuple{N,VarName} where {N}} -) +function _invlink!(vi::TypedVarInfo, vns::VarNameCollection) return _invlink!(vi.metadata, vi, vns) end -@generated function _invlink!(metadata::NamedTuple{names}, vi, vns) where {names} +@generated function _invlink!(::NamedTuple{names}, vi, vns) where {names} expr = Expr(:block) for f in names push!( @@ -1466,59 +1459,116 @@ function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) return map(Returns(nothing), varinfo.metadata) end -function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) - return _link(model, varinfo, spl) +function link( + ::DynamicTransformation, varinfo::VarInfo, vns::VarNameCollection, model::Model +) + return _link(model, varinfo, vns) end function link( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + vns::VarNameCollection, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, spl, model) + return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model) end function _link( - model::Model, varinfo::Union{UntypedVarInfo,VectorVarInfo}, spl::AbstractSampler + model::Model, varinfo::Union{UntypedVarInfo,VectorVarInfo}, vns::VarNameCollection ) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), + _link_metadata!!(model, varinfo, varinfo.metadata, vns), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _link(model::Model, varinfo::TypedVarInfo, spl::AbstractSampler) +""" + unique_syms(vns::T) where {T<:NTuple{N,VarName}} + +Return the unique symbols of the variables in `vns`. +""" +@generated function unique_syms(vns::T) where {T<:NTuple{N,VarName}} where {N} + retval = Expr(:tuple) + syms = [first(vn.parameters) for vn in T.parameters] + for sym in unique(syms) + push!(retval.args, QuoteNode(sym)) + end + return retval +end + +""" + varname_namedtuple(vns::NTuple{N,VarName}) where {N} + varname_namedtuple(vns::AbstractVector{<:VarName}) + +Return a `NamedTuple` of the variables in `vns` grouped by symbol. +""" +function varname_namedtuple(vns::NTuple{N,VarName} where {N}) + syms = unique_syms(vns) + elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) + return NamedTuple{syms}(elements) +end + +# This method is type unstable, but that can't be helped: The problem is inherently type +# unstable if there are VarNames with multiple symbols in a Vector. +function varname_namedtuple(vns::AbstractVector{<:VarName}) + syms = tuple(unique(map(getsym, vns))...) + elements = tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...) + return NamedTuple{syms}(elements) +end + +# A simpler, type stable implementation when all the VarNames in a Vector have the same +# symbol. +function varname_namedtuple(vns::AbstractVector{<:VarName{T}}) where {T} + return NamedTuple{(T,)}((vns,)) +end + +varname_namedtuple(vns::NamedTuple) = vns + +""" + all_varnames_namedtuple(vi::AbstractVarInfo) + +Return a `NamedTuple` of the variables in `vi` grouped by symbol. +""" +all_varnames_namedtuple(vi::TypedVarInfo) = all_varnames_namedtuple(vi.metadata) + +@generated function all_varnames_namedtuple(md::NamedTuple{names}) where {names} + expr = Expr(:tuple) + for f in names + push!(expr.args, :($f = keys(md.$f))) + end + return expr +end + +function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) varinfo = deepcopy(varinfo) - md = _link_metadata_namedtuple!( - model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) - ) + vns_namedtuple = varname_namedtuple(vns) + md = _link_metadata!(model, varinfo, varinfo.metadata, vns_namedtuple) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -@generated function _link_metadata_namedtuple!( +@generated function _link_metadata!( model::Model, varinfo::VarInfo, - metadata::NamedTuple{names}, - vns::NamedTuple, - ::Val{space}, -) where {names,space} + metadata::NamedTuple{metadata_names}, + vns::NamedTuple{vns_names}, +) where {metadata_names,vns_names} vals = Expr(:tuple) - for f in names - if inspace(f, space) || length(space) == 0 + for f in metadata_names + if f in vns_names push!(vals.args, :(_link_metadata!!(model, varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end end - return :(NamedTuple{$names}($vals)) + return :(NamedTuple{$metadata_names}($vals)) end -function _link_metadata!!(model::Model, varinfo::VarInfo, metadata::Metadata, target_vns) +function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. @@ -1691,7 +1741,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ end function _invlink_metadata!!( - model::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns + ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns ) vns = target_vns === nothing ? keys(metadata) : target_vns for vn in vns From 4a79b1f66e267a8e7a4951bd81599d979e66b899 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 22 Jan 2025 17:17:07 +0000 Subject: [PATCH 06/40] Remove indexing by samplers from invlink --- src/abstract_varinfo.jl | 49 ++++++++++------- src/simple_varinfo.jl | 2 +- src/threadsafe.jl | 26 +++------ src/transforming.jl | 12 +--- src/varinfo.jl | 118 ++++++++++++++++------------------------ 5 files changed, 87 insertions(+), 120 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 26238c12e..f28755c9f 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -537,12 +537,6 @@ If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variabl """ function settrans!! end -# TODO(mhauru) The fact that we need to to define this type is a sign that the link/invlink -# API is hard to understand. To be fixed by removing samplers from it. -SamplerOrVarName = Union{ - AbstractSampler,VarName,NTuple{N,VarName} where N,AbstractVector{<:VarName} -} - """ link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) @@ -615,7 +609,6 @@ end invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their constrained space, using the (inverse of) transformation `t`, mutating `vi` if possible. @@ -624,14 +617,23 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`link!!`](@ref). """ -invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model) -function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return invlink!!(t, vi, SampleFromPrior(), model) +# Use `default_transformation` to decide which transformation to use if none is specified. +function invlink!!(vi::AbstractVarInfo, model::Model) + return invlink!!(default_transformation(model, vi), vi, model) end -function invlink!!(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) - # Here we extract the `transformation` from `vi` rather than using the default one. - return invlink!!(transformation(vi), vi, spl_or_vn, model) +function invlink!!(vi::AbstractVarInfo, vns, model::Model) + return invlink!!(default_transformation(model, vi), vi, vns, model) end +# If no variable names are provided, invlink!! all variables. +function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + vns = collect(keys(vi)) + # In case e.g. vns = Any[]. + if !(eltype(vns) <: VarName) + vns = collect(VarName, vns) + end + return invlink!!(t, vi, vns, model) +end +# Wrap a single VarName in a singleton tuple. function invlink!!( t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model ) @@ -674,7 +676,6 @@ end invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of) transformation `t`. @@ -683,13 +684,23 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`link`](@ref). """ -invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model) -function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return invlink(t, vi, SampleFromPrior(), model) +# Use `default_transformation` to decide which transformation to use if none is specified. +function invlink(vi::AbstractVarInfo, model::Model) + return invlink(default_transformation(model, vi), vi, model) +end +function invlink(vi::AbstractVarInfo, vns, model::Model) + return invlink(default_transformation(model, vi), vi, vns, model) end -function invlink(vi::AbstractVarInfo, spl_or_vn::SamplerOrVarName, model::Model) - return invlink(transformation(vi), vi, spl_or_vn, model) +# If no variable names are provided, invlink all variables. +function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + vns = collect(keys(vi)) + # In case e.g. vns = Any[]. + if !(eltype(vns) <: VarName) + vns = collect(VarName, vns) + end + return invlink(t, vi, vns, model) end +# Wrap a single VarName in a singleton tuple. function invlink(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) return invlink(t, vi, (vn,), model) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 6bb723b29..b4e836371 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -695,7 +695,7 @@ end function invlink!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - ::AbstractSampler, + ::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}}, model::Model, ) # TODO: Make sure that `spl` is respected. diff --git a/src/threadsafe.jl b/src/threadsafe.jl index c5a77c3ef..c75ec2291 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -93,12 +93,9 @@ function link!!( end function invlink!!( - t::AbstractTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model ) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, spl_or_vn, model) + return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, vns, model) end function link( @@ -108,12 +105,9 @@ function link( end function invlink( - t::AbstractTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model ) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, spl_or_vn, model) + return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, vns, model) end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. @@ -127,10 +121,7 @@ function link!!( end function invlink!!( - ::DynamicTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + ::DynamicTransformation, vi::ThreadSafeVarInfo, ::VarNameCollection, model::Model ) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), @@ -145,12 +136,9 @@ function link( end function invlink( - t::DynamicTransformation, - vi::ThreadSafeVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::DynamicTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model ) - return invlink!!(t, deepcopy(vi), spl_or_vn, model) + return invlink!!(t, deepcopy(vi), vns, model) end function maybe_invlink_before_eval!!( diff --git a/src/transforming.jl b/src/transforming.jl index 6acaf787c..46b42d8ed 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -91,9 +91,6 @@ function dot_tilde_assume( return r, lp, vi end -SamplerOrVarNameIterator = Union{ - AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} -} VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},NamedTuple} function link!!( @@ -103,7 +100,7 @@ function link!!( end function invlink!!( - ::DynamicTransformation, vi::AbstractVarInfo, ::SamplerOrVarNameIterator, model::Model + ::DynamicTransformation, vi::AbstractVarInfo, ::VarNameCollection, model::Model ) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), @@ -118,10 +115,7 @@ function link( end function invlink( - t::DynamicTransformation, - vi::AbstractVarInfo, - spl_or_vn::SamplerOrVarNameIterator, - model::Model, + t::DynamicTransformation, vi::AbstractVarInfo, vns::VarNameCollection, model::Model ) - return invlink!!(t, deepcopy(vi), spl_or_vn, model) + return invlink!!(t, deepcopy(vi), vns, model) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 4c4125ad8..c05a42aba 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1322,26 +1322,33 @@ end return expr end +# Specialise invlink!! without varnames provided for TypedVarInfo. The usual version gets +# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, that helps +# keep the downstread calls to link!! type stable. +function invlink!!(t::AbstractTransformation, vi::TypedVarInfo, model::Model) + return invlink!!(t, vi, all_varnames_namedtuple(vi), model) +end + # R -> X for all variables associated with given sampler function invlink!!( - t::DynamicTransformation, vi::VarInfo, spl_or_vn::SamplerOrVarNameIterator, model::Model + t::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model ) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return invlink(t, vi, spl_or_vn, model) + has_varnamedvector(vi) && return invlink(t, vi, vns, model) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. - _invlink!(vi, spl_or_vn) + _invlink!(vi, vns) return vi end function invlink!!( ::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - spl_or_vn::SamplerOrVarNameIterator, + vns::VarNameCollection, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, spl_or_vn, model) + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model) end function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model) @@ -1352,9 +1359,6 @@ function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, mode return maybe_invlink_before_eval!!(t, vi, context, model) end -function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) - return _invlink!(vi, _getvns(vi, spl)) -end function _invlink!( vi::UntypedVarInfo, vns::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} ) @@ -1369,62 +1373,33 @@ function _invlink!( end end -function _invlink!(vi::TypedVarInfo, spl::AbstractSampler) - return _invlink!(vi, spl, Val(getspace(spl))) -end -function _invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) - vns = _getvns(vi, spl) - return _invlink!(vi.metadata, vi, vns, spaceval) +function _invlink!(vi::TypedVarInfo, vns::VarNameCollection) + vns_namedtuple = varname_namedtuple(vns) + return _invlink!(vi.metadata, vi, vns_namedtuple) end @generated function _invlink!( - ::NamedTuple{names}, vi, vns, ::Val{space} -) where {names,space} + ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} +) where {metadata_names,vns_names} expr = Expr(:block) - for f in names - if inspace(f, space) || length(space) == 0 - push!( - expr.args, - quote - f_vns = vi.metadata.$f.vns - if istrans(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - f = linked_internal_to_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, false, vn) - end - else - @warn("[DynamicPPL] attempt to invlink an invlinked vi") - end - end, - ) + for f in metadata_names + if !(f in vns_names) + continue end - end - return expr -end -function _invlink!(vi::TypedVarInfo, vns::VarNameCollection) - return _invlink!(vi.metadata, vi, vns) -end -@generated function _invlink!(::NamedTuple{names}, vi, vns) where {names} - expr = Expr(:block) - for f in names push!( expr.args, quote f_vns = vi.metadata.$f.vns - f_vns = filter_subsumed(vns, f_vns) - if !isempty(f_vns) - if istrans(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - f = linked_internal_to_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, false, vn) - end - else - @warn("[DynamicPPL] attempt to invlink an invlinked vi") + f_vns = filter_subsumed(vns.$f, f_vns) + if istrans(vi, f_vns[1]) + # Iterate over all `f_vns` and transform + for vn in f_vns + f = linked_internal_to_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) + settrans!!(vi, false, vn) end + else + @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end, ) @@ -1641,56 +1616,55 @@ function _link_metadata!!( end function invlink( - ::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model + ::DynamicTransformation, varinfo::VarInfo, vns::VarNameCollection, model::Model ) - return _invlink(model, varinfo, spl) + return _invlink(model, varinfo, vns) end function invlink( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, - spl::AbstractSampler, + vns::VarNameCollection, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, spl, model) + return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model) end -function _invlink(model::Model, varinfo::VarInfo, spl::AbstractSampler) +function _invlink(model::Model, varinfo::VarInfo, vns::VarNameCollection) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), + _invlink_metadata!!(model, varinfo, varinfo.metadata, vns), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _invlink(model::Model, varinfo::TypedVarInfo, spl::AbstractSampler) +function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) varinfo = deepcopy(varinfo) - md = _invlink_metadata_namedtuple!( - model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) - ) + vns_namedtuple = varname_namedtuple(vns) + md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns_namedtuple) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -@generated function _invlink_metadata_namedtuple!( +@generated function _invlink_metadata!( model::Model, varinfo::VarInfo, - metadata::NamedTuple{names}, - vns::NamedTuple, - ::Val{space}, -) where {names,space} + metadata::NamedTuple{metadata_names}, + vns::NamedTuple{vns_names}, +) where {metadata_names,vns_names} vals = Expr(:tuple) - for f in names - if inspace(f, space) || length(space) == 0 + for f in metadata_names + if (f in vns_names) push!(vals.args, :(_invlink_metadata!!(model, varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end end - return :(NamedTuple{$names}($vals)) + return :(NamedTuple{$metadata_names}($vals)) end + function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns From 090608bc66e4c5d3317f44ccfb4cad582e903a43 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 22 Jan 2025 18:05:38 +0000 Subject: [PATCH 07/40] Work towards removing sampler indexing with StaticTransformation --- src/abstract_varinfo.jl | 41 +++++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index f28755c9f..6b7e412a8 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -644,30 +644,46 @@ end function link!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, - spl::AbstractSampler, - model::Model, + vns::Union{NTuple{N,VarName} where {N},AbstractVector{<:VarName}}, + ::Model, ) + # TODO(mhauru) The behavior of this before the removal of indexing with samplers was a + # bit mixed. For TypedVarInfo you could transform only a subset of the variables, but + # for UntypedVarInfo and SimpleVarInfo it was silently assumed that all variables were + # being set. Unsure if we should support this or not, but at least it now errors + # loudly. + all_vns = Set(keys(vi)) + if Set(vns) != all_vns + msg = "StaticTransforming only a subset of variables is not supported." + throw(ArgumentError(msg)) + end b = inverse(t.bijector) - x = vi[spl] + x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(unflatten(vi, spl, y), lp_new) + vi_new = setlogp!!(unflatten(vi, y), lp_new) return settrans!!(vi_new, t) end function invlink!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, - spl::AbstractSampler, - model::Model, + vns::Union{NTuple{N,VarName} where {N},AbstractVector{<:VarName}}, + ::Model, ) + # TODO(mhauru) See comment in link!! above. + all_vns = Set(keys(vi)) + if Set(vns) != all_vns + msg = "StaticTransforming only a subset of variables is not supported." + throw(ArgumentError(msg)) + end b = t.bijector - y = vi[spl] + y = vi[:] x, logjac = with_logabsdet_jacobian(b, y) lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(unflatten(vi, spl, x), lp_new) + vi_new = setlogp!!(unflatten(vi, x), lp_new) return settrans!!(vi_new, NoTransformation()) end @@ -774,9 +790,14 @@ function maybe_invlink_before_eval!!( return vi end function maybe_invlink_before_eval!!( - t::StaticTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model + t::StaticTransformation, vi::AbstractVarInfo, ::AbstractContext, model::Model ) - return invlink!!(t, vi, _default_sampler(context), model) + # TODO(mhauru) Why does this function need the context argument? + vns = collect(keys(vi)) + if !(eltype(vns) <: VarName) + vns = collect(VarName, vns) + end + return invlink!!(t, vi, vns, model) end function _default_sampler(context::AbstractContext) From 474985376a2c6788b2f36cadabc0e9b4db4cbebb Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 12:09:48 +0000 Subject: [PATCH 08/40] Fix invlink/link for TypedVarInfo and StaticTransformation --- src/varinfo.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 527ac2dc1..e29152b8c 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1160,8 +1160,8 @@ VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},Na # Specialise link!! without varnames provided for TypedVarInfo. The usual version gets # the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, that helps -# keep the downstread calls to link!! type stable. -function link!!(t::AbstractTransformation, vi::TypedVarInfo, model::Model) +# keep the downstream calls to link!! type stable. +function link!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) return link!!(t, vi, all_varnames_namedtuple(vi), model) end @@ -1273,8 +1273,8 @@ end # Specialise invlink!! without varnames provided for TypedVarInfo. The usual version gets # the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, that helps -# keep the downstread calls to link!! type stable. -function invlink!!(t::AbstractTransformation, vi::TypedVarInfo, model::Model) +# keep the downstream calls to link!! type stable. +function invlink!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) return invlink!!(t, vi, all_varnames_namedtuple(vi), model) end From e960679a1d7e97dec87dc67fafdfb8ddde242a7f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 12:09:59 +0000 Subject: [PATCH 09/40] Fix a test in models.jl --- test/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/model.jl b/test/model.jl index 45c770cc4..a9d0b160f 100644 --- a/test/model.jl +++ b/test/model.jl @@ -226,7 +226,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() model = DynamicPPL.TestUtils.demo_dynamic_constraint() spl = SampleFromPrior() vi = VarInfo(model, spl, DefaultContext(), DynamicPPL.Metadata()) - link!!(vi, spl, model) + vi = link!!(vi, model) for i in 1:10 # Sample with large variations. From d507a535521bd43c2c032982ab3049275c3ff119 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 12:37:55 +0000 Subject: [PATCH 10/40] Move some functions to utils.jl, add tests and docstrings --- src/utils.jl | 61 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/varinfo.jl | 42 ---------------------------------- test/utils.jl | 29 ++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 42 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 5fedd3039..b64ae46cc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1268,3 +1268,64 @@ _merge(left::NamedTuple, right::NamedTuple) = merge(left, right) _merge(left::AbstractDict, right::AbstractDict) = merge(left, right) _merge(left::AbstractDict, right::NamedTuple{()}) = left _merge(left::NamedTuple{()}, right::AbstractDict) = right + +""" + unique_syms(vns::T) where {T<:NTuple{N,VarName}} + +Return the unique symbols of the variables in `vns`. + +Note that `unique_syms` is only defined for `Tuple`s of `VarName`s. For a `Vector` you can +just use `Base.unique`. The point of `unique_syms` is that it supports constant propagating +the result, which is possible with a `Tuple` but `Base.unique` won't allow it. +""" +@generated function unique_syms(::T) where {T<:NTuple{N,VarName}} where {N} + retval = Expr(:tuple) + syms = [first(vn.parameters) for vn in T.parameters] + for sym in unique(syms) + push!(retval.args, QuoteNode(sym)) + end + return retval +end + +""" + varname_namedtuple(vns::NTuple{N,VarName}) where {N} + varname_namedtuple(vns::AbstractVector{<:VarName}) + varname_namedtuple(vns::NamedTuple) + +Return a `NamedTuple` of the variables in `vns` grouped by symbol. + +`varname_namedtuple` is type table for inputs that are `Tuple`s, and for vectors when all +`VarName`s in the vector have the same symbol. For a `NamedTuple` it's a no-op. + +Example: +```julia +julia> vns_tuple = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2])) +(x, y[1], x.a, z[15], y[2]) + +julia> vns_nt = (; x=[@varname(x), @varname(x.a)], y=[@varname(y[1]), @varname(y[2])], z=[@varname(z[15])]) +(x = VarName{:x}[x, x.a], y = VarName{:y, IndexLens{Tuple{Int64}}}[y[1], y[2]], z = VarName{:z, IndexLens{Tuple{Int64}}}[z[15]]) + +julia> varname_namedtuple(vns_tuple) == vns_nt +``` +""" +function varname_namedtuple(vns::NTuple{N,VarName} where {N}) + syms = unique_syms(vns) + elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) + return NamedTuple{syms}(elements) +end + +# This method is type unstable, but that can't be helped: The problem is inherently type +# unstable if there are VarNames with multiple symbols in a Vector. +function varname_namedtuple(vns::AbstractVector{<:VarName}) + syms = tuple(unique(map(getsym, vns))...) + elements = tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...) + return NamedTuple{syms}(elements) +end + +# A simpler, type stable implementation when all the VarNames in a Vector have the same +# symbol. +function varname_namedtuple(vns::AbstractVector{<:VarName{T}}) where {T} + return NamedTuple{(T,)}((vns,)) +end + +varname_namedtuple(vns::NamedTuple) = vns diff --git a/src/varinfo.jl b/src/varinfo.jl index e29152b8c..74a6e3b8d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1411,48 +1411,6 @@ function _link( ) end -""" - unique_syms(vns::T) where {T<:NTuple{N,VarName}} - -Return the unique symbols of the variables in `vns`. -""" -@generated function unique_syms(vns::T) where {T<:NTuple{N,VarName}} where {N} - retval = Expr(:tuple) - syms = [first(vn.parameters) for vn in T.parameters] - for sym in unique(syms) - push!(retval.args, QuoteNode(sym)) - end - return retval -end - -""" - varname_namedtuple(vns::NTuple{N,VarName}) where {N} - varname_namedtuple(vns::AbstractVector{<:VarName}) - -Return a `NamedTuple` of the variables in `vns` grouped by symbol. -""" -function varname_namedtuple(vns::NTuple{N,VarName} where {N}) - syms = unique_syms(vns) - elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) - return NamedTuple{syms}(elements) -end - -# This method is type unstable, but that can't be helped: The problem is inherently type -# unstable if there are VarNames with multiple symbols in a Vector. -function varname_namedtuple(vns::AbstractVector{<:VarName}) - syms = tuple(unique(map(getsym, vns))...) - elements = tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...) - return NamedTuple{syms}(elements) -end - -# A simpler, type stable implementation when all the VarNames in a Vector have the same -# symbol. -function varname_namedtuple(vns::AbstractVector{<:VarName{T}}) where {T} - return NamedTuple{(T,)}((vns,)) -end - -varname_namedtuple(vns::NamedTuple) = vns - """ all_varnames_namedtuple(vi::AbstractVarInfo) diff --git a/test/utils.jl b/test/utils.jl index 3f435dca4..af7b3ee4d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -48,4 +48,33 @@ x = rand(dist) @test DynamicPPL.tovec(x) == vec(x.UL) end + + @testset "unique_syms" begin + vns = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2])) + @inferred DynamicPPL.unique_syms(vns) + @inferred DynamicPPL.unique_syms(()) + @test DynamicPPL.unique_syms(vns) == (:x, :y, :z) + @test DynamicPPL.unique_syms(()) == () + end + + @testset "varname_namedtuple" begin + vns_tuple = ( + @varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2]) + ) + vns_vec = collect(vns_tuple) + vns_nt = (; + x=[@varname(x), @varname(x.a)], + y=[@varname(y[1]), @varname(y[2])], + z=[@varname(z[15])], + ) + vns_vec_single_symbol = [@varname(x.a), @varname(x.b), @varname(x[1])] + @inferred DynamicPPL.varname_namedtuple(vns_tuple) + @inferred DynamicPPL.varname_namedtuple(vns_nt) + @inferred DynamicPPL.varname_namedtuple(vns_vec_single_symbol) + @test DynamicPPL.varname_namedtuple(vns_tuple) == vns_nt + @test DynamicPPL.varname_namedtuple(vns_vec) == vns_nt + @test DynamicPPL.varname_namedtuple(vns_nt) == vns_nt + @test DynamicPPL.varname_namedtuple(vns_vec_single_symbol) == + (; x=vns_vec_single_symbol) + end end From 41150b5e5fda23a19fdd2ecff1fb0a6847936256 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 12:54:42 +0000 Subject: [PATCH 11/40] Fix a docstring typo --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 74a6e3b8d..9b104a9a6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1412,7 +1412,7 @@ function _link( end """ - all_varnames_namedtuple(vi::AbstractVarInfo) + all_varnames_namedtuple(vi::TypedVarInfo) Return a `NamedTuple` of the variables in `vi` grouped by symbol. """ From 45d1f137dd76f2030594cad97263435a71fad346 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 13:28:48 +0000 Subject: [PATCH 12/40] Various simplification to link/invlink --- src/abstract_varinfo.jl | 20 ++++----- src/simple_varinfo.jl | 4 +- src/threadsafe.jl | 5 --- src/utils.jl | 3 ++ src/varinfo.jl | 92 +++++++++++++++++------------------------ 5 files changed, 53 insertions(+), 71 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 6b7e412a8..b4aed0458 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -540,7 +540,7 @@ function settrans!! end """ link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) Transform the variables in `vi` to their linked space, using the transformation `t`, @@ -561,6 +561,7 @@ end function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) vns = collect(keys(vi)) # In case e.g. vns = Any[]. + # TODO(mhauru) Could we rather fix akeys` so that it would always return VarName[]? if !(eltype(vns) <: VarName) vns = collect(VarName, vns) end @@ -574,7 +575,7 @@ end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. @@ -594,6 +595,7 @@ end function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) vns = collect(keys(vi)) # In case e.g. vns = Any[]. + # TODO(mhauru) Could we rather fix akeys` so that it would always return VarName[]? if !(eltype(vns) <: VarName) vns = collect(VarName, vns) end @@ -607,7 +609,7 @@ end """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) Transform the variables in `vi` to their constrained space, using the (inverse of) @@ -644,7 +646,7 @@ end function link!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, - vns::Union{NTuple{N,VarName} where {N},AbstractVector{<:VarName}}, + vns::VarNameCollection, ::Model, ) # TODO(mhauru) The behavior of this before the removal of indexing with samplers was a @@ -669,7 +671,7 @@ end function invlink!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, - vns::Union{NTuple{N,VarName} where {N},AbstractVector{<:VarName}}, + vns::VarNameCollection, ::Model, ) # TODO(mhauru) See comment in link!! above. @@ -690,7 +692,7 @@ end """ invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N, <:VarName}, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of) @@ -793,11 +795,7 @@ function maybe_invlink_before_eval!!( t::StaticTransformation, vi::AbstractVarInfo, ::AbstractContext, model::Model ) # TODO(mhauru) Why does this function need the context argument? - vns = collect(keys(vi)) - if !(eltype(vns) <: VarName) - vns = collect(VarName, vns) - end - return invlink!!(t, vi, vns, model) + return invlink!!(t, vi, model) end function _default_sampler(context::AbstractContext) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b4e836371..f60c0b0fb 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -680,7 +680,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn function link!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - ::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}}, + ::VarNameCollection, model::Model, ) # TODO: Make sure that `spl` is respected. @@ -695,7 +695,7 @@ end function invlink!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - ::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}}, + ::VarNameCollection, model::Model, ) # TODO: Make sure that `spl` is respected. diff --git a/src/threadsafe.jl b/src/threadsafe.jl index c75ec2291..25aa0d654 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -81,11 +81,6 @@ haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) -SamplerOrVarNameIterator = Union{ - AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} -} -VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},NamedTuple} - function link!!( t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model ) diff --git a/src/utils.jl b/src/utils.jl index b64ae46cc..854ead3fd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,6 +2,9 @@ struct NoDefault end const NO_DEFAULT = NoDefault() +# A short-hand for a type commonly used in type signatures for VarInfo methods. +VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} + """ @addlogprob!(ex) diff --git a/src/varinfo.jl b/src/varinfo.jl index 9b104a9a6..97fb733e2 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1153,16 +1153,12 @@ _isempty(vnv::VarNamedVector) = isempty(vnv) return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...) end -SamplerOrVarNameIterator = Union{ - AbstractSampler,NTuple{N,VarName} where N,AbstractVector{<:VarName} -} -VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},NamedTuple} - -# Specialise link!! without varnames provided for TypedVarInfo. The usual version gets -# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, that helps -# keep the downstream calls to link!! type stable. -function link!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) - return link!!(t, vi, all_varnames_namedtuple(vi), model) +# Specialise link!! without varnames provided for TypedVarInfo. The generic version gets +# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which +# helps keep the downstream call to _link! type stable. +function link!!(::DynamicTransformation, vi::TypedVarInfo, ::Model) + _link!(vi, all_varnames_namedtuple(vi)) + return vi end # X -> R for all variables associated with given sampler @@ -1185,9 +1181,7 @@ function link!!( return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end -function _link!( - vi::UntypedVarInfo, vns::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} -) +function _link!(vi::UntypedVarInfo, vns::VarNameCollection) # TODO: Change to a lazy iterator over `vns` if ~istrans(vi, vns[1]) for vn in vns @@ -1209,35 +1203,8 @@ function filter_subsumed(vns1, vns2) return filter(x -> any(subsumes(y, x) for y in vns1), vns2) end -function _link!(vi::TypedVarInfo, vns::VarNameCollection) - return _link!(vi.metadata, vi, vns) -end -@generated function _link!( - ::NamedTuple{names}, vi, vns::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} -) where {names} - expr = Expr(:block) - for f in names - push!( - expr.args, - quote - f_vns = vi.metadata.$f.vns - f_vns = filter_subsumed(vns, f_vns) - if !isempty(f_vns) - if !istrans(vi, f_vns[1]) - # Iterate over all `f_vns` and transform - for vn in f_vns - f = internal_to_linked_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, true, vn) - end - else - @warn("[DynamicPPL] attempt to link a linked vi") - end - end - end, - ) - end - return expr +function _link!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple}) + return _link!(vi.metadata, vi, varname_namedtuple(vns)) end @generated function _link!( @@ -1271,11 +1238,12 @@ end return expr end -# Specialise invlink!! without varnames provided for TypedVarInfo. The usual version gets -# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, that helps -# keep the downstream calls to link!! type stable. -function invlink!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) - return invlink!!(t, vi, all_varnames_namedtuple(vi), model) +# Specialise invlink!! without varnames provided for TypedVarInfo. The generic version gets +# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which +# helps keep the downstream calls to link!! type stable. +function invlink!!(::DynamicTransformation, vi::TypedVarInfo, ::Model) + _invlink!(vi, all_varnames_namedtuple(vi)) + return vi end # R -> X for all variables associated with given sampler @@ -1308,9 +1276,7 @@ function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, mode return maybe_invlink_before_eval!!(t, vi, context, model) end -function _invlink!( - vi::UntypedVarInfo, vns::Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} -) +function _invlink!(vi::UntypedVarInfo, vns::VarNameCollection) if istrans(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) @@ -1322,7 +1288,7 @@ function _invlink!( end end -function _invlink!(vi::TypedVarInfo, vns::VarNameCollection) +function _invlink!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple}) vns_namedtuple = varname_namedtuple(vns) return _invlink!(vi.metadata, vi, vns_namedtuple) end @@ -1400,6 +1366,13 @@ function link( return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model) end +# Specialise link without varnames provided for TypedVarInfo. The generic version gets +# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which +# helps keep the downstream calls to link!! type stable. +function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) + return _link(model, vi, all_varnames_namedtuple(vi)) +end + function _link( model::Model, varinfo::Union{UntypedVarInfo,VectorVarInfo}, vns::VarNameCollection ) @@ -1426,7 +1399,9 @@ all_varnames_namedtuple(vi::TypedVarInfo) = all_varnames_namedtuple(vi.metadata) return expr end -function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) +function _link( + model::Model, varinfo::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple} +) varinfo = deepcopy(varinfo) vns_namedtuple = varname_namedtuple(vns) md = _link_metadata!(model, varinfo, varinfo.metadata, vns_namedtuple) @@ -1450,6 +1425,7 @@ end return :(NamedTuple{$metadata_names}($vals)) end + function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns @@ -1527,6 +1503,7 @@ function invlink( ) return _invlink(model, varinfo, vns) end + function invlink( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, @@ -1538,6 +1515,13 @@ function invlink( return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model) end +# Specialise invlink without varnames provided for TypedVarInfo. The generic version gets +# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which +# helps keep the downstream calls to link!! type stable. +function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) + return _invlink(model, vi, all_varnames_namedtuple(vi)) +end + function _invlink(model::Model, varinfo::VarInfo, vns::VarNameCollection) varinfo = deepcopy(varinfo) return VarInfo( @@ -1547,7 +1531,9 @@ function _invlink(model::Model, varinfo::VarInfo, vns::VarNameCollection) ) end -function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) +function _invlink( + model::Model, varinfo::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple} +) varinfo = deepcopy(varinfo) vns_namedtuple = varname_namedtuple(vns) md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns_namedtuple) From 98915c2d5751f45287cfaa0bc1620f3098f6bf78 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 13:35:00 +0000 Subject: [PATCH 13/40] Improve a docstring --- src/utils.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 854ead3fd..16aa38e4a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1277,9 +1277,10 @@ _merge(left::NamedTuple{()}, right::AbstractDict) = right Return the unique symbols of the variables in `vns`. -Note that `unique_syms` is only defined for `Tuple`s of `VarName`s. For a `Vector` you can -just use `Base.unique`. The point of `unique_syms` is that it supports constant propagating -the result, which is possible with a `Tuple` but `Base.unique` won't allow it. +Note that `unique_syms` is only defined for `Tuple`s of `VarName`s and, unlike +`Base.unique`, returns a `Tuple`. For an `AbstractVector{<:VarName}` you can use +`Base.unique`. The point of `unique_syms` is that it supports constant propagating +the result, which is possible only when the argument and the return value are `Tuple`s. """ @generated function unique_syms(::T) where {T<:NTuple{N,VarName}} where {N} retval = Expr(:tuple) From f05068daba935ec974fbbe2b1418940b95b0ca20 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 13:48:24 +0000 Subject: [PATCH 14/40] Style improvements --- src/varinfo.jl | 53 +++++++++++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 97fb733e2..3f9d817b7 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -900,6 +900,21 @@ end return :($(exprs...),) end +""" + all_varnames_namedtuple(vi::TypedVarInfo) + +Return a `NamedTuple` of the variables in `vi` grouped by symbol. +""" +all_varnames_namedtuple(vi::TypedVarInfo) = all_varnames_namedtuple(vi.metadata) + +@generated function all_varnames_namedtuple(md::NamedTuple{names}) where {names} + expr = Expr(:tuple) + for f in names + push!(expr.args, :($f = keys(md.$f))) + end + return expr +end + # Get the index (in vals) ranges of all the vns of variables belonging to spl @inline function _getranges(vi::VarInfo, spl::Sampler) ## Uncomment the spl.info stuff when it is concretely typed, not Dict{Symbol, Any} @@ -1194,17 +1209,17 @@ function _link!(vi::UntypedVarInfo, vns::VarNameCollection) end end -""" - filter_subsumed(vns1, vns2) +function _link!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple}) + return _link!(vi.metadata, vi, varname_namedtuple(vns)) +end -Return the subset of `vns2` that are subsumed by any variable in `vns1`. """ -function filter_subsumed(vns1, vns2) - return filter(x -> any(subsumes(y, x) for y in vns1), vns2) -end + filter_subsumed(filter_vns, filtered_vns) -function _link!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple}) - return _link!(vi.metadata, vi, varname_namedtuple(vns)) +Return the subset of `filtered_vns` that are subsumed by any variable in `filter_vns`. +""" +function filter_subsumed(filter_vns, filtered_vns) + return filter(x -> any(subsumes(y, x) for y in filter_vns), filtered_vns) end @generated function _link!( @@ -1240,7 +1255,7 @@ end # Specialise invlink!! without varnames provided for TypedVarInfo. The generic version gets # the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream calls to link!! type stable. +# helps keep the downstream call to _invlink! type stable. function invlink!!(::DynamicTransformation, vi::TypedVarInfo, ::Model) _invlink!(vi, all_varnames_namedtuple(vi)) return vi @@ -1292,6 +1307,7 @@ function _invlink!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple}) vns_namedtuple = varname_namedtuple(vns) return _invlink!(vi.metadata, vi, vns_namedtuple) end + @generated function _invlink!( ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} ) where {metadata_names,vns_names} @@ -1368,7 +1384,7 @@ end # Specialise link without varnames provided for TypedVarInfo. The generic version gets # the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream calls to link!! type stable. +# helps keep the downstream call to _link type stable. function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) return _link(model, vi, all_varnames_namedtuple(vi)) end @@ -1384,21 +1400,6 @@ function _link( ) end -""" - all_varnames_namedtuple(vi::TypedVarInfo) - -Return a `NamedTuple` of the variables in `vi` grouped by symbol. -""" -all_varnames_namedtuple(vi::TypedVarInfo) = all_varnames_namedtuple(vi.metadata) - -@generated function all_varnames_namedtuple(md::NamedTuple{names}) where {names} - expr = Expr(:tuple) - for f in names - push!(expr.args, :($f = keys(md.$f))) - end - return expr -end - function _link( model::Model, varinfo::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple} ) @@ -1517,7 +1518,7 @@ end # Specialise invlink without varnames provided for TypedVarInfo. The generic version gets # the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream calls to link!! type stable. +# helps keep the downstream call to _invlink type stable. function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) return _invlink(model, vi, all_varnames_namedtuple(vi)) end From bc4c42093dafe01c0d7ed3984232e471a1bdcb65 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 14:52:15 +0000 Subject: [PATCH 15/40] Fix broken link/invlink dispatch cascade for VectorVarInfo --- src/varinfo.jl | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 3f9d817b7..f03926051 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1170,14 +1170,18 @@ end # Specialise link!! without varnames provided for TypedVarInfo. The generic version gets # the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream call to _link! type stable. -function link!!(::DynamicTransformation, vi::TypedVarInfo, ::Model) - _link!(vi, all_varnames_namedtuple(vi)) - return vi +# helps keep the downstream call to link!! type stable. +function link!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) + return link!!(t, vi, all_varnames_namedtuple(vi), model) end # X -> R for all variables associated with given sampler -function link!!(t::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model) +function link!!( + t::DynamicTransformation, + vi::VarInfo, + vns::Union{VarNameCollection,NamedTuple}, + model::Model, +) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return link(t, vi, vns, model) # Call `_link!` instead of `link!` to avoid deprecation warning. @@ -1255,15 +1259,17 @@ end # Specialise invlink!! without varnames provided for TypedVarInfo. The generic version gets # the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream call to _invlink! type stable. -function invlink!!(::DynamicTransformation, vi::TypedVarInfo, ::Model) - _invlink!(vi, all_varnames_namedtuple(vi)) - return vi +# helps keep the downstream call to invlink!! type stable. +function invlink!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) + return invlink!!(t, vi, all_varnames_namedtuple(vi), model) end # R -> X for all variables associated with given sampler function invlink!!( - t::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model + t::DynamicTransformation, + vi::VarInfo, + vns::Union{VarNameCollection,NamedTuple}, + model::Model, ) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return invlink(t, vi, vns, model) From 71980baf556c86a2a335a8376b075e726de30f78 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 15:36:05 +0000 Subject: [PATCH 16/40] Fix some more broken dispatch cascades --- src/varinfo.jl | 51 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index f03926051..8b835014d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1183,7 +1183,7 @@ function link!!( model::Model, ) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return link(t, vi, vns, model) + has_varnamedvector(vi) && return _link(model, vi, vns) # Call `_link!` instead of `link!` to avoid deprecation warning. _link!(vi, vns) return vi @@ -1213,8 +1213,14 @@ function _link!(vi::UntypedVarInfo, vns::VarNameCollection) end end -function _link!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple}) - return _link!(vi.metadata, vi, varname_namedtuple(vns)) +# If we try to _link! a TypedVarInfo with a Tuple or Vector of VarNames, first convert +# it to a NamedTuple that matches the structure of the TypedVarInfo. +function _link!(vi::TypedVarInfo, vns::VarNameCollection) + return _link!(vi, varname_namedtuple(vns)) +end + +function _link!(vi::TypedVarInfo, vns::NamedTuple) + return _link!(vi.metadata, vi, vns) end """ @@ -1272,7 +1278,7 @@ function invlink!!( model::Model, ) # If we're working with a `VarNamedVector`, we always use immutable. - has_varnamedvector(vi) && return invlink(t, vi, vns, model) + has_varnamedvector(vi) && return _invlink(model, vi, vns) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. _invlink!(vi, vns) return vi @@ -1309,9 +1315,14 @@ function _invlink!(vi::UntypedVarInfo, vns::VarNameCollection) end end -function _invlink!(vi::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple}) - vns_namedtuple = varname_namedtuple(vns) - return _invlink!(vi.metadata, vi, vns_namedtuple) +# If we try to _invlink! a TypedVarInfo with a Tuple or Vector of VarNames, first convert +# it to a NamedTuple that matches the structure of the TypedVarInfo. +function _invlink!(vi::TypedVarInfo, vns::VarNameCollection) + return _invlink!(vi.metadata, vi, varname_namedtuple(vns)) +end + +function _invlink!(vi::TypedVarInfo, vns::NamedTuple) + return _invlink!(vi.metadata, vi, vns) end @generated function _invlink!( @@ -1406,12 +1417,15 @@ function _link( ) end -function _link( - model::Model, varinfo::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple} -) +# If we try to _invlink! a TypedVarInfo with a Tuple or Vector of VarNames, first convert +# it to a NamedTuple that matches the structure of the TypedVarInfo. +function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) + return _link(model, varinfo, varname_namedtuple(vns)) +end + +function _link(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - vns_namedtuple = varname_namedtuple(vns) - md = _link_metadata!(model, varinfo, varinfo.metadata, vns_namedtuple) + md = _link_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end @@ -1538,12 +1552,15 @@ function _invlink(model::Model, varinfo::VarInfo, vns::VarNameCollection) ) end -function _invlink( - model::Model, varinfo::TypedVarInfo, vns::Union{VarNameCollection,NamedTuple} -) +# If we try to _invlink a TypedVarInfo with a Tuple or Vector of VarNames, first convert +# it to a NamedTuple that matches the structure of the TypedVarInfo. +function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) + return _invlink(model, varinfo, varname_namedtuple(vns)) +end + +function _invlink(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - vns_namedtuple = varname_namedtuple(vns) - md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns_namedtuple) + md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end From 45562a9cacca75439cb422b34e8bc7f011d02090 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 24 Jan 2025 14:34:26 +0000 Subject: [PATCH 17/40] Apply suggestions from code review Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> --- src/abstract_varinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index b4aed0458..a215bbd14 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -561,7 +561,7 @@ end function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) vns = collect(keys(vi)) # In case e.g. vns = Any[]. - # TODO(mhauru) Could we rather fix akeys` so that it would always return VarName[]? + # TODO(mhauru) Could we rather fix `keys` so that it would always return VarName[]? if !(eltype(vns) <: VarName) vns = collect(VarName, vns) end @@ -656,7 +656,7 @@ function link!!( # loudly. all_vns = Set(keys(vi)) if Set(vns) != all_vns - msg = "StaticTransforming only a subset of variables is not supported." + msg = "Statically transforming only a subset of variables is not supported." throw(ArgumentError(msg)) end b = inverse(t.bijector) From db5b8357316f6004e2e10a67ee11d2638e4cdfec Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 24 Jan 2025 14:36:30 +0000 Subject: [PATCH 18/40] Remove comments that messed with docstrings --- src/abstract_varinfo.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index a215bbd14..891218fb6 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -550,7 +550,6 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`invlink!!`](@ref). """ -# Use `default_transformation` to decide which transformation to use if none is specified. function link!!(vi::AbstractVarInfo, model::Model) return link!!(default_transformation(model, vi), vi, model) end @@ -584,7 +583,6 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`invlink`](@ref). """ -# Use `default_transformation` to decide which transformation to use if none is specified. function link(vi::AbstractVarInfo, model::Model) return link(default_transformation(model, vi), vi, model) end @@ -619,7 +617,6 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`link!!`](@ref). """ -# Use `default_transformation` to decide which transformation to use if none is specified. function invlink!!(vi::AbstractVarInfo, model::Model) return invlink!!(default_transformation(model, vi), vi, model) end @@ -702,7 +699,6 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`link`](@ref). """ -# Use `default_transformation` to decide which transformation to use if none is specified. function invlink(vi::AbstractVarInfo, model::Model) return invlink(default_transformation(model, vi), vi, model) end From f99effe14ed5189e8b552984ff29b6cf5e56c6b6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Jan 2025 12:32:51 +0000 Subject: [PATCH 19/40] Apply suggestions from code review Co-authored-by: Penelope Yong --- src/abstract_varinfo.jl | 4 ++-- src/utils.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 891218fb6..c8a2ff17b 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -593,7 +593,7 @@ end function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) vns = collect(keys(vi)) # In case e.g. vns = Any[]. - # TODO(mhauru) Could we rather fix akeys` so that it would always return VarName[]? + # TODO(mhauru) Could we rather fix `keys` so that it would always return VarName[]? if !(eltype(vns) <: VarName) vns = collect(VarName, vns) end @@ -674,7 +674,7 @@ function invlink!!( # TODO(mhauru) See comment in link!! above. all_vns = Set(keys(vi)) if Set(vns) != all_vns - msg = "StaticTransforming only a subset of variables is not supported." + msg = "Statically transforming only a subset of variables is not supported." throw(ArgumentError(msg)) end b = t.bijector diff --git a/src/utils.jl b/src/utils.jl index 16aa38e4a..307bf1f85 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1298,7 +1298,7 @@ end Return a `NamedTuple` of the variables in `vns` grouped by symbol. -`varname_namedtuple` is type table for inputs that are `Tuple`s, and for vectors when all +`varname_namedtuple` is type stable for inputs that are `Tuple`s, and for vectors when all `VarName`s in the vector have the same symbol. For a `NamedTuple` it's a no-op. Example: From 56194cd000636bdde0710011f250795430971667 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Jan 2025 13:09:01 +0000 Subject: [PATCH 20/40] Fix issues surfaced in code review --- docs/src/api.md | 2 +- src/DynamicPPL.jl | 2 +- src/abstract_varinfo.jl | 2 ++ src/threadsafe.jl | 4 ++-- src/transforming.jl | 2 -- src/utils.jl | 2 -- src/varinfo.jl | 1 - test/varinfo.jl | 11 +++++++++++ 8 files changed, 17 insertions(+), 9 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 093cb06a6..36dd24250 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -304,7 +304,7 @@ set_num_produce! increment_num_produce! reset_num_produce! setorder! -set_retained_vns_del_by_spl! +set_retained_vns_del! ``` ```@docs diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c1cdbd94e..55e1f7e88 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -59,7 +59,7 @@ export AbstractVarInfo, set_num_produce!, reset_num_produce!, increment_num_produce!, - set_retained_vns_del_by_spl!, + set_retained_vns_del!, is_flagged, set_flag!, unset_flag!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index c8a2ff17b..c59e2990c 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -561,6 +561,7 @@ function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) vns = collect(keys(vi)) # In case e.g. vns = Any[]. # TODO(mhauru) Could we rather fix `keys` so that it would always return VarName[]? + # See https://github.com/TuringLang/DynamicPPL.jl/issues/791. if !(eltype(vns) <: VarName) vns = collect(VarName, vns) end @@ -594,6 +595,7 @@ function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) vns = collect(keys(vi)) # In case e.g. vns = Any[]. # TODO(mhauru) Could we rather fix `keys` so that it would always return VarName[]? + # See https://github.com/TuringLang/DynamicPPL.jl/issues/791. if !(eltype(vns) <: VarName) vns = collect(VarName, vns) end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 25aa0d654..fae0c1613 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -182,8 +182,8 @@ function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) return vector_getranges(vi.varinfo, vns) end -function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) - return set_retained_vns_del_by_spl!(vi.varinfo, spl) +function set_retained_vns_del!(vi::ThreadSafeVarInfo, spl::Sampler) + return set_retained_vns_del!(vi.varinfo, spl) end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) diff --git a/src/transforming.jl b/src/transforming.jl index 46b42d8ed..f3f4fbba0 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -91,8 +91,6 @@ function dot_tilde_assume( return r, lp, vi end -VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName},NamedTuple} - function link!!( t::DynamicTransformation, vi::AbstractVarInfo, ::VarNameCollection, model::Model ) diff --git a/src/utils.jl b/src/utils.jl index 307bf1f85..0bf9d6d3d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1331,5 +1331,3 @@ end function varname_namedtuple(vns::AbstractVector{<:VarName{T}}) where {T} return NamedTuple{(T,)}((vns,)) end - -varname_namedtuple(vns::NamedTuple) = vns diff --git a/src/varinfo.jl b/src/varinfo.jl index 8b835014d..c49a6ffc3 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -2073,7 +2073,6 @@ function set_retained_vns_del!(vi::UntypedVarInfo) return nothing end function set_retained_vns_del!(vi::TypedVarInfo) - # Get the indices of `vns` that belong to `spl` as a NamedTuple, one entry for each symbol idcs = _getidcs(vi) return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi)) end diff --git a/test/varinfo.jl b/test/varinfo.jl index fd1c9a2e9..99d319425 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -825,6 +825,17 @@ end end end + # The below used to error, testing to avoid regression. + @testset "merge different dimensions" begin + vn = @varname(x) + vi_single = VarInfo() + vi_single = push!!(vi_single, vn, 1.0, Normal()) + vi_double = VarInfo() + vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0)) + @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] + @test merge(vi_double, vi_single)[vn] == 1.0 + end + @testset "sampling from linked varinfo" begin # `~` @model function demo(n=1) From c187c49152a619ec663961c90102509fbd8482ae Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Jan 2025 17:16:11 +0000 Subject: [PATCH 21/40] Simplify link/invlink arguments --- src/abstract_varinfo.jl | 146 ++++++++++------------------------------ src/simple_varinfo.jl | 6 +- src/threadsafe.jl | 44 ++++-------- src/transforming.jl | 20 ++---- src/utils.jl | 2 +- src/varinfo.jl | 132 ++++++++++++++++++++++-------------- test/varinfo.jl | 4 +- 7 files changed, 145 insertions(+), 209 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index c59e2990c..c7afc67a5 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -537,127 +537,77 @@ If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variabl """ function settrans!! end +# For link!!, invlink!!, link, and invlink, we deliberately do not provide a fallback +# method for the case when no `vns` is provided, that would get all the keys from the +# `VarInfo`. Hence each subtype of `AbstractVarInfo` needs to implement separately the case +# where `vns` is provided and the one where it is not. This is because having separate +# implementations is typically much more performant, and because not all AbstractVarInfo +# types support partial linking. + """ link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + +Transform variables in `vi` to their linked space, mutating `vi` if possible. -Transform the variables in `vi` to their linked space, using the transformation `t`, -mutating `vi` if possible. +Either transform all variables, or only ones specified in `vns`. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided. See also: [`default_transformation`](@ref), [`invlink!!`](@ref). """ function link!!(vi::AbstractVarInfo, model::Model) return link!!(default_transformation(model, vi), vi, model) end -function link!!(vi::AbstractVarInfo, vns, model::Model) +function link!!(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end -# If no variable names are provided, link all variables. -function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - vns = collect(keys(vi)) - # In case e.g. vns = Any[]. - # TODO(mhauru) Could we rather fix `keys` so that it would always return VarName[]? - # See https://github.com/TuringLang/DynamicPPL.jl/issues/791. - if !(eltype(vns) <: VarName) - vns = collect(VarName, vns) - end - return link!!(t, vi, vns, model) -end -# Wrap a single VarName in a singleton tuple. -function link!!(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) - return link!!(t, vi, (vn,), model) -end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + +Transform variables in `vi` to their linked space without mutating `vi`. -Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. +Either transform all variables, or only ones specified in `vns`. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided. See also: [`default_transformation`](@ref), [`invlink`](@ref). """ function link(vi::AbstractVarInfo, model::Model) return link(default_transformation(model, vi), vi, model) end -function link(vi::AbstractVarInfo, vns, model::Model) +function link(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) return link(default_transformation(model, vi), vi, vns, model) end -# If no variable names are provided, link all variables. -function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - vns = collect(keys(vi)) - # In case e.g. vns = Any[]. - # TODO(mhauru) Could we rather fix `keys` so that it would always return VarName[]? - # See https://github.com/TuringLang/DynamicPPL.jl/issues/791. - if !(eltype(vns) <: VarName) - vns = collect(VarName, vns) - end - return link(t, vi, vns, model) -end -# Wrap a single VarName in a singleton tuple. -function link(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) - return link(t, vi, (vn,), model) -end """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + +Transform variables in `vi` to their constrained space, mutating `vi` if possible. -Transform the variables in `vi` to their constrained space, using the (inverse of) -transformation `t`, mutating `vi` if possible. +Either transform all variables, or only ones specified in `vns`. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is +not provided. See also: [`default_transformation`](@ref), [`link!!`](@ref). """ function invlink!!(vi::AbstractVarInfo, model::Model) return invlink!!(default_transformation(model, vi), vi, model) end -function invlink!!(vi::AbstractVarInfo, vns, model::Model) +function invlink!!(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end -# If no variable names are provided, invlink!! all variables. -function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - vns = collect(keys(vi)) - # In case e.g. vns = Any[]. - if !(eltype(vns) <: VarName) - vns = collect(VarName, vns) - end - return invlink!!(t, vi, vns, model) -end -# Wrap a single VarName in a singleton tuple. -function invlink!!( - t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model -) - return invlink!!(t, vi, (vn,), model) -end # Vector-based ones. function link!!( - t::StaticTransformation{<:Bijectors.Transform}, - vi::AbstractVarInfo, - vns::VarNameCollection, - ::Model, + t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) - # TODO(mhauru) The behavior of this before the removal of indexing with samplers was a - # bit mixed. For TypedVarInfo you could transform only a subset of the variables, but - # for UntypedVarInfo and SimpleVarInfo it was silently assumed that all variables were - # being set. Unsure if we should support this or not, but at least it now errors - # loudly. - all_vns = Set(keys(vi)) - if Set(vns) != all_vns - msg = "Statically transforming only a subset of variables is not supported." - throw(ArgumentError(msg)) - end b = inverse(t.bijector) x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) @@ -668,17 +618,8 @@ function link!!( end function invlink!!( - t::StaticTransformation{<:Bijectors.Transform}, - vi::AbstractVarInfo, - vns::VarNameCollection, - ::Model, + t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) - # TODO(mhauru) See comment in link!! above. - all_vns = Set(keys(vi)) - if Set(vns) != all_vns - msg = "Statically transforming only a subset of variables is not supported." - throw(ArgumentError(msg)) - end b = t.bijector y = vi[:] x, logjac = with_logabsdet_jacobian(b, y) @@ -690,36 +631,23 @@ end """ invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::VarName, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::Tuple{N,VarName}, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vn::AbstractVector{<:VarName}, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + +Transform variables in `vi` to their constrained space without mutating `vi`. -Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of) -transformation `t`. +Either transform all variables, or only ones specified in `vns`. -If `t` is not provided, `default_transformation(model, vi)` will be used. +Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is +not provided. See also: [`default_transformation`](@ref), [`link`](@ref). """ function invlink(vi::AbstractVarInfo, model::Model) return invlink(default_transformation(model, vi), vi, model) end -function invlink(vi::AbstractVarInfo, vns, model::Model) +function invlink(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) return invlink(default_transformation(model, vi), vi, vns, model) end -# If no variable names are provided, invlink all variables. -function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - vns = collect(keys(vi)) - # In case e.g. vns = Any[]. - if !(eltype(vns) <: VarName) - vns = collect(VarName, vns) - end - return invlink(t, vi, vns, model) -end -# Wrap a single VarName in a singleton tuple. -function invlink(t::AbstractTransformation, vi::AbstractVarInfo, vn::VarName, model::Model) - return invlink(t, vi, (vn,), model) -end """ maybe_invlink_before_eval!!([t::Transformation,] vi, context, model) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index f60c0b0fb..57b167077 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -680,8 +680,7 @@ Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarIn function link!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - ::VarNameCollection, - model::Model, + ::Model, ) # TODO: Make sure that `spl` is respected. b = inverse(t.bijector) @@ -695,8 +694,7 @@ end function invlink!!( t::StaticTransformation{<:Bijectors.NamedTransform}, vi::SimpleVarInfo{<:NamedTuple}, - ::VarNameCollection, - model::Model, + ::Model, ) # TODO: Make sure that `spl` is respected. b = t.bijector diff --git a/src/threadsafe.jl b/src/threadsafe.jl index fae0c1613..bf4817fbd 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -81,59 +81,43 @@ haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) -function link!!( - t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model -) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, vns, model) +function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...) end -function invlink!!( - t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model -) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, vns, model) +function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...) end -function link( - t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model -) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, vns, model) +function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...) end -function invlink( - t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model -) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, vns, model) +function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...) end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. # NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure # consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates # to define `getlogp(vi)`. -function link!!( - t::DynamicTransformation, vi::ThreadSafeVarInfo, ::VarNameCollection, model::Model -) +function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end -function invlink!!( - ::DynamicTransformation, vi::ThreadSafeVarInfo, ::VarNameCollection, model::Model -) +function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), NoTransformation(), ) end -function link( - t::DynamicTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model -) - return link!!(t, deepcopy(vi), vns, model) +function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) + return link!!(t, deepcopy(vi), model) end -function invlink( - t::DynamicTransformation, vi::ThreadSafeVarInfo, vns::VarNameCollection, model::Model -) - return invlink!!(t, deepcopy(vi), vns, model) +function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) + return invlink!!(t, deepcopy(vi), model) end function maybe_invlink_before_eval!!( diff --git a/src/transforming.jl b/src/transforming.jl index f3f4fbba0..1a26d212f 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -91,29 +91,21 @@ function dot_tilde_assume( return r, lp, vi end -function link!!( - t::DynamicTransformation, vi::AbstractVarInfo, ::VarNameCollection, model::Model -) +function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end -function invlink!!( - ::DynamicTransformation, vi::AbstractVarInfo, ::VarNameCollection, model::Model -) +function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) return settrans!!( last(evaluate!!(model, vi, DynamicTransformationContext{true}())), NoTransformation(), ) end -function link( - t::DynamicTransformation, vi::AbstractVarInfo, vns::VarNameCollection, model::Model -) - return link!!(t, deepcopy(vi), vns, model) +function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + return link!!(t, deepcopy(vi), model) end -function invlink( - t::DynamicTransformation, vi::AbstractVarInfo, vns::VarNameCollection, model::Model -) - return invlink!!(t, deepcopy(vi), vns, model) +function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + return invlink!!(t, deepcopy(vi), model) end diff --git a/src/utils.jl b/src/utils.jl index 0bf9d6d3d..265fa773b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -3,7 +3,7 @@ struct NoDefault end const NO_DEFAULT = NoDefault() # A short-hand for a type commonly used in type signatures for VarInfo methods. -VarNameCollection = Union{NTuple{N,VarName} where N,AbstractVector{<:VarName}} +VarNameCollection = NTuple{N,VarName} where {N} """ @addlogprob!(ex) diff --git a/src/varinfo.jl b/src/varinfo.jl index c49a6ffc3..cdf67b019 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1168,20 +1168,30 @@ _isempty(vnv::VarNamedVector) = isempty(vnv) return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...) end -# Specialise link!! without varnames provided for TypedVarInfo. The generic version gets -# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream call to link!! type stable. -function link!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) - return link!!(t, vi, all_varnames_namedtuple(vi), model) +function link!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) + vns = all_varnames_namedtuple(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _link(model, vi, vns) + _link!(vi, vns) + return vi +end + +function link!!(::DynamicTransformation, vi::VarInfo, model::Model) + vns = keys(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _link(model, vi, vns) + _link!(vi, vns) + return vi +end + +function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) end # X -> R for all variables associated with given sampler -function link!!( - t::DynamicTransformation, - vi::VarInfo, - vns::Union{VarNameCollection,NamedTuple}, - model::Model, -) +function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) # Call `_link!` instead of `link!` to avoid deprecation warning. @@ -1195,12 +1205,12 @@ function link!!( vns::VarNameCollection, model::Model, ) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end -function _link!(vi::UntypedVarInfo, vns::VarNameCollection) +function _link!(vi::UntypedVarInfo, vns) # TODO: Change to a lazy iterator over `vns` if ~istrans(vi, vns[1]) for vn in vns @@ -1213,7 +1223,7 @@ function _link!(vi::UntypedVarInfo, vns::VarNameCollection) end end -# If we try to _link! a TypedVarInfo with a Tuple or Vector of VarNames, first convert +# If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. function _link!(vi::TypedVarInfo, vns::VarNameCollection) return _link!(vi, varname_namedtuple(vns)) @@ -1263,19 +1273,32 @@ end return expr end -# Specialise invlink!! without varnames provided for TypedVarInfo. The generic version gets -# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream call to invlink!! type stable. -function invlink!!(t::DynamicTransformation, vi::TypedVarInfo, model::Model) - return invlink!!(t, vi, all_varnames_namedtuple(vi), model) +function invlink!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) + vns = all_varnames_namedtuple(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _invlink(model, vi, vns) + # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. + _invlink!(vi, vns) + return vi +end + +function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) + vns = keys(vi) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return _invlink(model, vi, vns) + _invlink!(vi, vns) + return vi +end + +function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) end # R -> X for all variables associated with given sampler function invlink!!( - t::DynamicTransformation, - vi::VarInfo, - vns::Union{VarNameCollection,NamedTuple}, - model::Model, + ::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model ) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) @@ -1303,7 +1326,7 @@ function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, mode return maybe_invlink_before_eval!!(t, vi, context, model) end -function _invlink!(vi::UntypedVarInfo, vns::VarNameCollection) +function _invlink!(vi::UntypedVarInfo, vns) if istrans(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) @@ -1315,7 +1338,7 @@ function _invlink!(vi::UntypedVarInfo, vns::VarNameCollection) end end -# If we try to _invlink! a TypedVarInfo with a Tuple or Vector of VarNames, first convert +# If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. function _invlink!(vi::TypedVarInfo, vns::VarNameCollection) return _invlink!(vi.metadata, vi, varname_namedtuple(vns)) @@ -1382,6 +1405,20 @@ function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) return map(Returns(nothing), varinfo.metadata) end +function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) + return _link(model, vi, all_varnames_namedtuple(vi)) +end + +function link(::DynamicTransformation, varinfo::VarInfo, model::Model) + return _link(model, varinfo, keys(varinfo)) +end + +function link(::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, model) +end + function link( ::DynamicTransformation, varinfo::VarInfo, vns::VarNameCollection, model::Model ) @@ -1399,22 +1436,10 @@ function link( return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model) end -# Specialise link without varnames provided for TypedVarInfo. The generic version gets -# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream call to _link type stable. -function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) - return _link(model, vi, all_varnames_namedtuple(vi)) -end - -function _link( - model::Model, varinfo::Union{UntypedVarInfo,VectorVarInfo}, vns::VarNameCollection -) +function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - return VarInfo( - _link_metadata!!(model, varinfo, varinfo.metadata, vns), - Base.Ref(getlogp(varinfo)), - Ref(get_num_produce(varinfo)), - ) + md = _link_metadata!!(model, varinfo, varinfo.metadata, vns) + return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end # If we try to _invlink! a TypedVarInfo with a Tuple or Vector of VarNames, first convert @@ -1519,6 +1544,22 @@ function _link_metadata!!( return metadata end +function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) + return _invlink(model, vi, all_varnames_namedtuple(vi)) +end + +function invlink(::DynamicTransformation, vi::VarInfo, model::Model) + return _invlink(model, vi, keys(vi)) +end + +function invlink( + ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, model) +end + function invlink( ::DynamicTransformation, varinfo::VarInfo, vns::VarNameCollection, model::Model ) @@ -1536,14 +1577,7 @@ function invlink( return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model) end -# Specialise invlink without varnames provided for TypedVarInfo. The generic version gets -# the keys of `vi` as a Vector. For TypedVarInfo we can get them as a NamedTuple, which -# helps keep the downstream call to _invlink type stable. -function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) - return _invlink(model, vi, all_varnames_namedtuple(vi)) -end - -function _invlink(model::Model, varinfo::VarInfo, vns::VarNameCollection) +function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) return VarInfo( _invlink_metadata!!(model, varinfo, varinfo.metadata, vns), diff --git a/test/varinfo.jl b/test/varinfo.jl index 99d319425..d689a1bf4 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -437,10 +437,10 @@ end other_vns = filter(x -> !subsumes(vn, x), all_vns) @test !isempty(target_vns) @test !isempty(other_vns) - vi = link!!(vi, vn, model) + vi = link!!(vi, (vn,), model) @test all(x -> istrans(vi, x), target_vns) @test all(x -> !istrans(vi, x), other_vns) - vi = invlink!!(vi, vn, model) + vi = invlink!!(vi, (vn,), model) @test all(x -> !istrans(vi, x), all_vns) @test meta.s.vals ≈ v_s atol = 1e-10 @test meta.m.vals ≈ v_m atol = 1e-10 From 86b25c5ae3ab2d7e4f15816b76d64c9debad9c7f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Jan 2025 17:16:27 +0000 Subject: [PATCH 22/40] Fix a bug in unflatten VarNamedVector --- src/varnamedvector.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index b324e9134..14ef6ce6a 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1066,7 +1066,12 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector) new_ranges = deepcopy(vnv.ranges) recontiguify_ranges!(new_ranges) return VarNamedVector( - vnv.varname_to_index, vnv.varnames, new_ranges, vals, vnv.transforms + vnv.varname_to_index, + vnv.varnames, + new_ranges, + vals, + vnv.transforms, + vnv.is_unconstrained, ) end From 2a6c1bcef4d14c38bb1ce5c07c868e325ba92aae Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Jan 2025 17:20:54 +0000 Subject: [PATCH 23/40] Rename VarNameCollection -> VarNameTuple --- src/abstract_varinfo.jl | 8 ++++---- src/utils.jl | 2 +- src/varinfo.jl | 30 ++++++++++++------------------ 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index c7afc67a5..087affd90 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -560,7 +560,7 @@ See also: [`default_transformation`](@ref), [`invlink!!`](@ref). function link!!(vi::AbstractVarInfo, model::Model) return link!!(default_transformation(model, vi), vi, model) end -function link!!(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) +function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end @@ -580,7 +580,7 @@ See also: [`default_transformation`](@ref), [`invlink`](@ref). function link(vi::AbstractVarInfo, model::Model) return link(default_transformation(model, vi), vi, model) end -function link(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) +function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link(default_transformation(model, vi), vi, vns, model) end @@ -600,7 +600,7 @@ See also: [`default_transformation`](@ref), [`link!!`](@ref). function invlink!!(vi::AbstractVarInfo, model::Model) return invlink!!(default_transformation(model, vi), vi, model) end -function invlink!!(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) +function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end @@ -645,7 +645,7 @@ See also: [`default_transformation`](@ref), [`link`](@ref). function invlink(vi::AbstractVarInfo, model::Model) return invlink(default_transformation(model, vi), vi, model) end -function invlink(vi::AbstractVarInfo, vns::VarNameCollection, model::Model) +function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink(default_transformation(model, vi), vi, vns, model) end diff --git a/src/utils.jl b/src/utils.jl index 265fa773b..de7fe5925 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -3,7 +3,7 @@ struct NoDefault end const NO_DEFAULT = NoDefault() # A short-hand for a type commonly used in type signatures for VarInfo methods. -VarNameCollection = NTuple{N,VarName} where {N} +VarNameTuple = NTuple{N,VarName} where {N} """ @addlogprob!(ex) diff --git a/src/varinfo.jl b/src/varinfo.jl index cdf67b019..887f132c6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1191,7 +1191,7 @@ function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, mode end # X -> R for all variables associated with given sampler -function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model) +function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) # Call `_link!` instead of `link!` to avoid deprecation warning. @@ -1202,7 +1202,7 @@ end function link!!( t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameCollection, + vns::VarNameTuple, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, @@ -1225,7 +1225,7 @@ end # If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. -function _link!(vi::TypedVarInfo, vns::VarNameCollection) +function _link!(vi::TypedVarInfo, vns::VarNameTuple) return _link!(vi, varname_namedtuple(vns)) end @@ -1297,9 +1297,7 @@ function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, m end # R -> X for all variables associated with given sampler -function invlink!!( - ::DynamicTransformation, vi::VarInfo, vns::VarNameCollection, model::Model -) +function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. @@ -1310,7 +1308,7 @@ end function invlink!!( ::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameCollection, + vns::VarNameTuple, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so @@ -1340,7 +1338,7 @@ end # If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. -function _invlink!(vi::TypedVarInfo, vns::VarNameCollection) +function _invlink!(vi::TypedVarInfo, vns::VarNameTuple) return _invlink!(vi.metadata, vi, varname_namedtuple(vns)) end @@ -1419,16 +1417,14 @@ function link(::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, mo return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, model) end -function link( - ::DynamicTransformation, varinfo::VarInfo, vns::VarNameCollection, model::Model -) +function link(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) return _link(model, varinfo, vns) end function link( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameCollection, + vns::VarNameTuple, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so @@ -1444,7 +1440,7 @@ end # If we try to _invlink! a TypedVarInfo with a Tuple or Vector of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. -function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) +function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) return _link(model, varinfo, varname_namedtuple(vns)) end @@ -1560,16 +1556,14 @@ function invlink( return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, model) end -function invlink( - ::DynamicTransformation, varinfo::VarInfo, vns::VarNameCollection, model::Model -) +function invlink(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) return _invlink(model, varinfo, vns) end function invlink( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameCollection, + vns::VarNameTuple, model::Model, ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so @@ -1588,7 +1582,7 @@ end # If we try to _invlink a TypedVarInfo with a Tuple or Vector of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. -function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameCollection) +function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) return _invlink(model, varinfo, varname_namedtuple(vns)) end From 853f47e683428d02d1d87641a5bf2687d611f94c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 28 Jan 2025 17:35:13 +0000 Subject: [PATCH 24/40] Remove test of a removed varname_namedtuple method --- test/utils.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index af7b3ee4d..cdb2af4f7 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -69,11 +69,9 @@ ) vns_vec_single_symbol = [@varname(x.a), @varname(x.b), @varname(x[1])] @inferred DynamicPPL.varname_namedtuple(vns_tuple) - @inferred DynamicPPL.varname_namedtuple(vns_nt) @inferred DynamicPPL.varname_namedtuple(vns_vec_single_symbol) @test DynamicPPL.varname_namedtuple(vns_tuple) == vns_nt @test DynamicPPL.varname_namedtuple(vns_vec) == vns_nt - @test DynamicPPL.varname_namedtuple(vns_nt) == vns_nt @test DynamicPPL.varname_namedtuple(vns_vec_single_symbol) == (; x=vns_vec_single_symbol) end From ed803281da7c4c851be2b86f5a1708ac183a714b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 29 Jan 2025 16:50:00 +0000 Subject: [PATCH 25/40] Apply suggestions from code review Co-authored-by: Penelope Yong --- src/abstract_varinfo.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 087affd90..26785a387 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -546,8 +546,7 @@ function settrans!! end """ link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) Transform variables in `vi` to their linked space, mutating `vi` if possible. @@ -566,8 +565,7 @@ end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) Transform variables in `vi` to their linked space without mutating `vi`. @@ -586,7 +584,7 @@ end """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) Transform variables in `vi` to their constrained space, mutating `vi` if possible. @@ -631,7 +629,7 @@ end """ invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::Tuple{N,VarName}, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) Transform variables in `vi` to their constrained space without mutating `vi`. From d996d0cb45e4d959ca4ac17a329f7bd11de28954 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 29 Jan 2025 17:15:55 +0000 Subject: [PATCH 26/40] Respond to review feedback --- src/utils.jl | 34 +++++++++------------------------- src/varinfo.jl | 23 ++++++++++++----------- test/utils.jl | 10 +++------- 3 files changed, 24 insertions(+), 43 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index de7fe5925..2539b7179 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1278,11 +1278,11 @@ _merge(left::NamedTuple{()}, right::AbstractDict) = right Return the unique symbols of the variables in `vns`. Note that `unique_syms` is only defined for `Tuple`s of `VarName`s and, unlike -`Base.unique`, returns a `Tuple`. For an `AbstractVector{<:VarName}` you can use -`Base.unique`. The point of `unique_syms` is that it supports constant propagating -the result, which is possible only when the argument and the return value are `Tuple`s. +`Base.unique`, returns a `Tuple`. The point of `unique_syms` is that it supports constant +propagating the result, which is possible only when the argument and the return value are +`Tuple`s. """ -@generated function unique_syms(::T) where {T<:NTuple{N,VarName}} where {N} +@generated function unique_syms(::T) where {T<:VarNameTuple} retval = Expr(:tuple) syms = [first(vn.parameters) for vn in T.parameters] for sym in unique(syms) @@ -1292,14 +1292,12 @@ the result, which is possible only when the argument and the return value are `T end """ - varname_namedtuple(vns::NTuple{N,VarName}) where {N} - varname_namedtuple(vns::AbstractVector{<:VarName}) - varname_namedtuple(vns::NamedTuple) + group_varnames_by_symbol(vns::NTuple{N,VarName}) where {N} Return a `NamedTuple` of the variables in `vns` grouped by symbol. -`varname_namedtuple` is type stable for inputs that are `Tuple`s, and for vectors when all -`VarName`s in the vector have the same symbol. For a `NamedTuple` it's a no-op. +Note that `group_varnames_by_symbol` only accepts a `Tuple` of `VarName`s. This allows it to +be type stable. Example: ```julia @@ -1309,25 +1307,11 @@ julia> vns_tuple = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), julia> vns_nt = (; x=[@varname(x), @varname(x.a)], y=[@varname(y[1]), @varname(y[2])], z=[@varname(z[15])]) (x = VarName{:x}[x, x.a], y = VarName{:y, IndexLens{Tuple{Int64}}}[y[1], y[2]], z = VarName{:z, IndexLens{Tuple{Int64}}}[z[15]]) -julia> varname_namedtuple(vns_tuple) == vns_nt +julia> group_varnames_by_symbol(vns_tuple) == vns_nt ``` """ -function varname_namedtuple(vns::NTuple{N,VarName} where {N}) +function group_varnames_by_symbol(vns::VarNameTuple) syms = unique_syms(vns) elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) return NamedTuple{syms}(elements) end - -# This method is type unstable, but that can't be helped: The problem is inherently type -# unstable if there are VarNames with multiple symbols in a Vector. -function varname_namedtuple(vns::AbstractVector{<:VarName}) - syms = tuple(unique(map(getsym, vns))...) - elements = tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...) - return NamedTuple{syms}(elements) -end - -# A simpler, type stable implementation when all the VarNames in a Vector have the same -# symbol. -function varname_namedtuple(vns::AbstractVector{<:VarName{T}}) where {T} - return NamedTuple{(T,)}((vns,)) -end diff --git a/src/varinfo.jl b/src/varinfo.jl index 887f132c6..8836962e5 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -901,13 +901,14 @@ end end """ - all_varnames_namedtuple(vi::TypedVarInfo) + all_varnames_grouped_by_symbol(vi::TypedVarInfo) Return a `NamedTuple` of the variables in `vi` grouped by symbol. """ -all_varnames_namedtuple(vi::TypedVarInfo) = all_varnames_namedtuple(vi.metadata) +all_varnames_grouped_by_symbol(vi::TypedVarInfo) = + all_varnames_grouped_by_symbol(vi.metadata) -@generated function all_varnames_namedtuple(md::NamedTuple{names}) where {names} +@generated function all_varnames_grouped_by_symbol(md::NamedTuple{names}) where {names} expr = Expr(:tuple) for f in names push!(expr.args, :($f = keys(md.$f))) @@ -1169,7 +1170,7 @@ _isempty(vnv::VarNamedVector) = isempty(vnv) end function link!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) - vns = all_varnames_namedtuple(vi) + vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) _link!(vi, vns) @@ -1226,7 +1227,7 @@ end # If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. function _link!(vi::TypedVarInfo, vns::VarNameTuple) - return _link!(vi, varname_namedtuple(vns)) + return _link!(vi, group_varnames_by_symbol(vns)) end function _link!(vi::TypedVarInfo, vns::NamedTuple) @@ -1274,7 +1275,7 @@ end end function invlink!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) - vns = all_varnames_namedtuple(vi) + vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. @@ -1339,7 +1340,7 @@ end # If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. function _invlink!(vi::TypedVarInfo, vns::VarNameTuple) - return _invlink!(vi.metadata, vi, varname_namedtuple(vns)) + return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) end function _invlink!(vi::TypedVarInfo, vns::NamedTuple) @@ -1404,7 +1405,7 @@ function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) end function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) - return _link(model, vi, all_varnames_namedtuple(vi)) + return _link(model, vi, all_varnames_grouped_by_symbol(vi)) end function link(::DynamicTransformation, varinfo::VarInfo, model::Model) @@ -1441,7 +1442,7 @@ end # If we try to _invlink! a TypedVarInfo with a Tuple or Vector of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) - return _link(model, varinfo, varname_namedtuple(vns)) + return _link(model, varinfo, group_varnames_by_symbol(vns)) end function _link(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) @@ -1541,7 +1542,7 @@ function _link_metadata!!( end function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) - return _invlink(model, vi, all_varnames_namedtuple(vi)) + return _invlink(model, vi, all_varnames_grouped_by_symbol(vi)) end function invlink(::DynamicTransformation, vi::VarInfo, model::Model) @@ -1583,7 +1584,7 @@ end # If we try to _invlink a TypedVarInfo with a Tuple or Vector of VarNames, first convert # it to a NamedTuple that matches the structure of the TypedVarInfo. function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) - return _invlink(model, varinfo, varname_namedtuple(vns)) + return _invlink(model, varinfo, group_varnames_by_symbol(vns)) end function _invlink(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) diff --git a/test/utils.jl b/test/utils.jl index cdb2af4f7..d683f132d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -57,7 +57,7 @@ @test DynamicPPL.unique_syms(()) == () end - @testset "varname_namedtuple" begin + @testset "group_varnames_by_symbol" begin vns_tuple = ( @varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2]) ) @@ -68,11 +68,7 @@ z=[@varname(z[15])], ) vns_vec_single_symbol = [@varname(x.a), @varname(x.b), @varname(x[1])] - @inferred DynamicPPL.varname_namedtuple(vns_tuple) - @inferred DynamicPPL.varname_namedtuple(vns_vec_single_symbol) - @test DynamicPPL.varname_namedtuple(vns_tuple) == vns_nt - @test DynamicPPL.varname_namedtuple(vns_vec) == vns_nt - @test DynamicPPL.varname_namedtuple(vns_vec_single_symbol) == - (; x=vns_vec_single_symbol) + @inferred DynamicPPL.group_varnames_by_symbol(vns_tuple) + @test DynamicPPL.group_varnames_by_symbol(vns_tuple) == vns_nt end end From 20831485880a7503792f78794b3d4f9751f42c84 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 29 Jan 2025 17:19:05 +0000 Subject: [PATCH 27/40] Remove _default_sampler and a dead argument of maybe_invlink_before_eval --- src/abstract_varinfo.jl | 28 ++++++++-------------------- src/model.jl | 8 ++++---- src/threadsafe.jl | 13 +++++-------- src/varinfo.jl | 4 ++-- test/simple_varinfo.jl | 4 +--- 5 files changed, 20 insertions(+), 37 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 26785a387..26c4268d8 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -648,7 +648,7 @@ function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) end """ - maybe_invlink_before_eval!!([t::Transformation,] vi, context, model) + maybe_invlink_before_eval!!([t::Transformation,] vi, model) Return a possibly invlinked version of `vi`. @@ -699,37 +699,25 @@ julia> # Now performs a single `invlink!!` before model evaluation. -1001.4189385332047 ``` """ -function maybe_invlink_before_eval!!( - vi::AbstractVarInfo, context::AbstractContext, model::Model -) - return maybe_invlink_before_eval!!(transformation(vi), vi, context, model) +function maybe_invlink_before_eval!!(vi::AbstractVarInfo, model::Model) + return maybe_invlink_before_eval!!(transformation(vi), vi, model) end -function maybe_invlink_before_eval!!( - ::NoTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model -) +function maybe_invlink_before_eval!!(::NoTransformation, vi::AbstractVarInfo, model::Model) return vi end function maybe_invlink_before_eval!!( - ::DynamicTransformation, vi::AbstractVarInfo, context::AbstractContext, model::Model + ::DynamicTransformation, vi::AbstractVarInfo, model::Model ) - # `DynamicTransformation` is meant to _not_ do the transformation statically, hence we do nothing. + # `DynamicTransformation` is meant to _not_ do the transformation statically, hence we + # do nothing. return vi end function maybe_invlink_before_eval!!( - t::StaticTransformation, vi::AbstractVarInfo, ::AbstractContext, model::Model + t::StaticTransformation, vi::AbstractVarInfo, model::Model ) - # TODO(mhauru) Why does this function need the context argument? return invlink!!(t, vi, model) end -function _default_sampler(context::AbstractContext) - return _default_sampler(NodeTrait(_default_sampler, context), context) -end -_default_sampler(::IsLeaf, context::AbstractContext) = SampleFromPrior() -function _default_sampler(::IsParent, context::AbstractContext) - return _default_sampler(childcontext(context)) -end - # Utilities """ unflatten(vi::AbstractVarInfo[, context::AbstractContext], x::AbstractVector) diff --git a/src/model.jl b/src/model.jl index 6fb0b40b0..462db7397 100644 --- a/src/model.jl +++ b/src/model.jl @@ -971,7 +971,7 @@ Return the arguments and keyword arguments to be passed to the evaluator of the # lazy `invlink`-ing of the parameters. This can be useful for # speeding up computation. See docs for `maybe_invlink_before_eval!!` # for more information. - maybe_invlink_before_eval!!(varinfo, context_new, model), + maybe_invlink_before_eval!!(varinfo, model), context_new, $(unwrap_args...), ) @@ -1169,10 +1169,10 @@ end """ predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) -Generate samples from the posterior predictive distribution by evaluating `model` at each set -of parameter values provided in `chain`. The number of posterior predictive samples matches +Generate samples from the posterior predictive distribution by evaluating `model` at each set +of parameter values provided in `chain`. The number of posterior predictive samples matches the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values -and the predicted values. +and the predicted values. """ function predict( rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo} diff --git a/src/threadsafe.jl b/src/threadsafe.jl index bf4817fbd..b4403c46f 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -120,15 +120,12 @@ function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) return invlink!!(t, deepcopy(vi), model) end -function maybe_invlink_before_eval!!( - vi::ThreadSafeVarInfo, context::AbstractContext, model::Model -) +function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model) # Defer to the wrapped `AbstractVarInfo` object. - # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the `getlogp(vi.varinfo)` - # hence the log-absdet-jacobian term will correctly be included in the `getlogp(vi)`. - return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!( - vi.varinfo, context, model - ) + # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the + # `getlogp(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in + # the `getlogp(vi)`. + return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model) end # `getindex` diff --git a/src/varinfo.jl b/src/varinfo.jl index 8836962e5..9516745f2 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1317,12 +1317,12 @@ function invlink!!( return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model) end -function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model) +function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) # Because `VarInfo` does not contain any information about what the transformation # other than whether or not it has actually been transformed, the best we can do # is just assume that `default_transformation` is the correct one if `istrans(vi)`. t = istrans(vi) ? default_transformation(model, vi) : NoTransformation() - return maybe_invlink_before_eval!!(t, vi, context, model) + return maybe_invlink_before_eval!!(t, vi, model) end function _invlink!(vi::UntypedVarInfo, vns) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 4343563eb..137c791c2 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -275,9 +275,7 @@ # Make sure `maybe_invlink_before_eval!!` results in `invlink!!`. @test !DynamicPPL.istrans( - DynamicPPL.maybe_invlink_before_eval!!( - deepcopy(vi), SamplingContext(), model - ), + DynamicPPL.maybe_invlink_before_eval!!(deepcopy(vi), model) ) # Resulting varinfo should no longer be transformed. From 39fa6476ad2a3c76b65efecf0c20d04d2cfd4401 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 29 Jan 2025 17:22:33 +0000 Subject: [PATCH 28/40] Fix a typo in a comment --- src/varinfo.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 9516745f2..09f5960c1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1224,8 +1224,8 @@ function _link!(vi::UntypedVarInfo, vns) end end -# If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert -# it to a NamedTuple that matches the structure of the TypedVarInfo. +# If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. function _link!(vi::TypedVarInfo, vns::VarNameTuple) return _link!(vi, group_varnames_by_symbol(vns)) end @@ -1337,8 +1337,8 @@ function _invlink!(vi::UntypedVarInfo, vns) end end -# If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert -# it to a NamedTuple that matches the structure of the TypedVarInfo. +# If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. function _invlink!(vi::TypedVarInfo, vns::VarNameTuple) return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) end @@ -1428,8 +1428,8 @@ function link( vns::VarNameTuple, model::Model, ) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model) end @@ -1439,8 +1439,8 @@ function _link(model::Model, varinfo::VarInfo, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -# If we try to _invlink! a TypedVarInfo with a Tuple or Vector of VarNames, first convert -# it to a NamedTuple that matches the structure of the TypedVarInfo. +# If we try to _link a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) return _link(model, varinfo, group_varnames_by_symbol(vns)) end @@ -1581,8 +1581,8 @@ function _invlink(model::Model, varinfo::VarInfo, vns) ) end -# If we try to _invlink a TypedVarInfo with a Tuple or Vector of VarNames, first convert -# it to a NamedTuple that matches the structure of the TypedVarInfo. +# If we try to _invlink a TypedVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the TypedVarInfo. function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) return _invlink(model, varinfo, group_varnames_by_symbol(vns)) end From 2c73de570b8a9f6d8d9dcf9f8a5f5454b3f12e73 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Jan 2025 11:58:39 +0000 Subject: [PATCH 29/40] Add HISTORY entry, fix one set_retained_vns_del! method --- HISTORY.md | 9 +++++++++ src/threadsafe.jl | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index f77d3fa74..eea7435c9 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,15 @@ **Breaking** +### Remove indexing by samplers + +This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers. In particular, + + - `link` and `invlink`, and their `!!` versions, no longer accept a sampler as an argument to specify which variables to (inv)link. The `link(varinfo, model)` methods remain in place, and as a new addition one can give a `Tuple` of `VarName`s to (inv)link only select variables, as in `link(varinfo, varname_tuple, model)`. + - `set_retained_vns_del_by_spl!` has been replaced by `set_retained_vns_del!` which applies to all variables. + +### Reverse prefixing order + - For submodels constructed using `to_submodel`, the order in which nested prefixes are applied has been changed. Previously, the order was that outer prefixes were applied first, then inner ones. This version reverses that. diff --git a/src/threadsafe.jl b/src/threadsafe.jl index b4403c46f..69be5dcb1 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -163,8 +163,8 @@ function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) return vector_getranges(vi.varinfo, vns) end -function set_retained_vns_del!(vi::ThreadSafeVarInfo, spl::Sampler) - return set_retained_vns_del!(vi.varinfo, spl) +function set_retained_vns_del!(vi::ThreadSafeVarInfo) + return set_retained_vns_del!(vi.varinfo) end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) From 1c50d0cd3efe1e414ed594ec7794b21745990eac Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 23 Jan 2025 17:47:34 +0000 Subject: [PATCH 30/40] Remove some VarInfo getindex with samplers stuff --- src/abstract_varinfo.jl | 7 +++--- src/compiler.jl | 53 ++++++++++++++--------------------------- src/model.jl | 4 ++-- src/sampler.jl | 31 ++++++++++-------------- src/simple_varinfo.jl | 6 +---- src/varinfo.jl | 22 +++++------------ test/model.jl | 4 ++-- test/sampler.jl | 4 ++-- 8 files changed, 46 insertions(+), 85 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 26c4268d8..9f744984b 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -159,7 +159,6 @@ The default implementation is to call [`values_as`](@ref) with `Vector` as the t See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref) """ Base.getindex(vi::AbstractVarInfo, ::Colon) = values_as(vi, Vector) -Base.getindex(vi::AbstractVarInfo, ::AbstractSampler) = vi[:] """ getindex_internal(vi::AbstractVarInfo, vn::VarName) @@ -352,13 +351,13 @@ Determine the default `eltype` of the values returned by `vi[spl]`. This method is considered legacy, and is likely to be deprecated in the future. """ -function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior}) - T = Base.promote_op(getindex, typeof(vi), typeof(spl)) +function Base.eltype(vi::AbstractVarInfo) + T = Base.promote_op(getindex, typeof(vi), Colon) if T === Union{} # In this case `getindex(vi, spl)` errors # Let us throw a more descriptive error message # Ref https://github.com/TuringLang/Turing.jl/issues/2151 - return eltype(vi[spl]) + return eltype(vi[:]) end return eltype(T) end diff --git a/src/compiler.jl b/src/compiler.jl index c67da6f95..551f49266 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3,7 +3,7 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) """ need_concretize(expr) -Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or +Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or requires a dynamic optic. # Examples @@ -731,18 +731,14 @@ function warn_empty(body) end """ - matchingvalue(sampler, vi, value) - matchingvalue(context::AbstractContext, vi, value) + matchingvalue(vi, value) -Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object. - -For a `context` that is _not_ a `SamplingContext`, we fall back to -`matchingvalue(SampleFromPrior(), vi, value)`. +Convert the `value` to the correct type for the `vi` object. """ -function matchingvalue(sampler, vi, value) +function matchingvalue(vi, value) T = typeof(value) if hasmissing(T) - _value = convert(get_matching_type(sampler, vi, T), value) + _value = convert(get_matching_type(vi, T), value) if _value === value return deepcopy(_value) else @@ -753,24 +749,11 @@ function matchingvalue(sampler, vi, value) end end # If we hit `Type` or `TypeWrap`, we immediately jump to `get_matching_type`. -function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType) - return get_matching_type(sampler, vi, value) -end -function matchingvalue(sampler::AbstractSampler, vi, value::TypeWrap{T}) where {T} - return TypeWrap{get_matching_type(sampler, vi, T)}() -end - -function matchingvalue(context::AbstractContext, vi, value) - return matchingvalue(NodeTrait(matchingvalue, context), context, vi, value) -end -function matchingvalue(::IsLeaf, context::AbstractContext, vi, value) - return matchingvalue(SampleFromPrior(), vi, value) -end -function matchingvalue(::IsParent, context::AbstractContext, vi, value) - return matchingvalue(childcontext(context), vi, value) +function matchingvalue(vi, value::FloatOrArrayType) + return get_matching_type(vi, value) end -function matchingvalue(context::SamplingContext, vi, value) - return matchingvalue(context.sampler, vi, value) +function matchingvalue(vi, ::TypeWrap{T}) where {T} + return TypeWrap{get_matching_type(vi, T)}() end """ @@ -781,16 +764,16 @@ Get the specialized version of type `T` for sampler `spl`. For example, if `T === Float64` and `spl::Hamiltonian`, the matching type is `eltype(vi[spl])`. """ -get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} = T -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Union{Missing,AbstractFloat}}) - return Union{Missing,float_type_with_fallback(eltype(vi, spl))} +get_matching_type(_, ::Type{T}) where {T} = T +function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) + return Union{Missing,float_type_with_fallback(eltype(vi))} end -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:AbstractFloat}) - return float_type_with_fallback(eltype(vi, spl)) +function get_matching_type(vi, ::Type{<:AbstractFloat}) + return float_type_with_fallback(eltype(vi)) end -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(spl, vi, T),N} +function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} + return Array{get_matching_type(vi, T),N} end -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where {T} - return Array{get_matching_type(spl, vi, T)} +function get_matching_type(vi, ::Type{<:Array{T}}) where {T} + return Array{get_matching_type(vi, T)} end diff --git a/src/model.jl b/src/model.jl index 462db7397..3601d77fd 100644 --- a/src/model.jl +++ b/src/model.jl @@ -948,9 +948,9 @@ Return the arguments and keyword arguments to be passed to the evaluator of the ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) - :($matchingvalue(context_new, varinfo, model.args.$var)...) + :($matchingvalue(varinfo, model.args.$var)...) else - :($matchingvalue(context_new, varinfo, model.args.$var)) + :($matchingvalue(varinfo, model.args.$var)) end for var in argnames ] diff --git a/src/sampler.jl b/src/sampler.jl index 974828e8b..56cd8404e 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -118,7 +118,7 @@ function AbstractMCMC.step( # Update the parameters if provided. if initial_params !== nothing - vi = initialize_parameters!!(vi, initial_params, spl, model) + vi = initialize_parameters!!(vi, initial_params, model) # Update joint log probability. # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 @@ -156,9 +156,7 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). """ initialsampler(spl::Sampler) = SampleFromPrior() -function set_values!!( - varinfo::AbstractVarInfo, initial_params::AbstractVector, spl::AbstractSampler -) +function set_values!!(varinfo::AbstractVarInfo, initial_params::AbstractVector) throw( ArgumentError( "`initial_params` must be a vector of type `Union{Real,Missing}`. " * @@ -168,11 +166,9 @@ function set_values!!( end function set_values!!( - varinfo::AbstractVarInfo, - initial_params::AbstractVector{<:Union{Real,Missing}}, - spl::AbstractSampler, + varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}} ) - flattened_param_vals = varinfo[spl] + flattened_param_vals = varinfo[:] length(flattened_param_vals) == length(initial_params) || throw( DimensionMismatch( "Provided initial value size ($(length(initial_params))) doesn't match " * @@ -189,12 +185,11 @@ function set_values!!( end # Update in `varinfo`. - return setindex!!(varinfo, flattened_param_vals, spl) + setall!(varinfo, flattened_param_vals) + return varinfo end -function set_values!!( - varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler -) +function set_values!!(varinfo::AbstractVarInfo, initial_params::NamedTuple) vars_in_varinfo = keys(varinfo) for v in keys(initial_params) vn = VarName{v}() @@ -219,23 +214,21 @@ function set_values!!( ) end -function initialize_parameters!!( - vi::AbstractVarInfo, initial_params, spl::AbstractSampler, model::Model -) +function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model) @debug "Using passed-in initial variable values" initial_params # `link` the varinfo if needed. - linked = islinked(vi, spl) + linked = islinked(vi) if linked - vi = invlink!!(vi, spl, model) + vi = invlink!!(vi, model) end # Set the values in `vi`. - vi = set_values!!(vi, initial_params, spl) + vi = set_values!!(vi, initial_params) # `invlink` if needed. if linked - vi = link!!(vi, spl, model) + vi = link!!(vi, model) end return vi diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 57b167077..b45d0dcc8 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -428,11 +428,7 @@ const SimpleOrThreadSafeSimple{T,V,C} = Union{ } # Necessary for `matchingvalue` to work properly. -function Base.eltype( - vi::SimpleOrThreadSafeSimple{<:Any,V}, spl::Union{AbstractSampler,SampleFromPrior} -) where {V} - return V -end +Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V # `subset` function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) diff --git a/src/varinfo.jl b/src/varinfo.jl index 09f5960c1..bbcb638bb 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1677,6 +1677,8 @@ function _invlink_metadata!!( return metadata end +# TODO(mhauru) We have varying conventions below for what to do if some variables are linked +# and others are not. """ islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) @@ -1703,6 +1705,10 @@ end return Expr(:||, false, out...) end +function islinked(vi::VarInfo) + return any(istrans(vi, vn) for vn in keys(vi)) +end + function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName) return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) end @@ -1788,22 +1794,6 @@ function getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) return recombine(dist, vals_linked, length(vns)) end -""" - getindex(vi::VarInfo, spl::Union{SampleFromPrior, Sampler}) - -Return the current value(s) of the random variables sampled by `spl` in `vi`. - -The value(s) may or may not be transformed to Euclidean space. -""" -getindex(vi::UntypedVarInfo, spl::Sampler) = - copy(getindex(vi.metadata.vals, _getranges(vi, spl))) -getindex(vi::VarInfo, spl::Sampler) = copy(getindex_internal(vi, _getranges(vi, spl))) -function getindex(vi::TypedVarInfo, spl::Sampler) - # Gets the ranges as a NamedTuple - ranges = _getranges(vi, spl) - # Calling getfield(ranges, f) gives all the indices in `vals` of the `vn`s with symbol `f` sampled by `spl` in `vi` - return reduce(vcat, _getindex(vi.metadata, ranges)) -end # Recursively builds a tuple of the `vals` of all the symbols @generated function _getindex(metadata, ranges::NamedTuple{names}) where {names} expr = Expr(:tuple) diff --git a/test/model.jl b/test/model.jl index e91de4bd2..256ada0ad 100644 --- a/test/model.jl +++ b/test/model.jl @@ -230,8 +230,8 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() for i in 1:10 # Sample with large variations. - r_raw = randn(length(vi[spl])) * 10 - vi[spl] = r_raw + r_raw = randn(length(vi[:])) * 10 + DynamicPPL.setall!(vi, r_raw) @test vi[@varname(m)] == r_raw[1] @test vi[@varname(x)] != r_raw[2] model(vi) diff --git a/test/sampler.jl b/test/sampler.jl index 3b5424671..50111b1fd 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -196,11 +196,11 @@ vi = VarInfo(model) @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, [initial_z, initial_x], DynamicPPL.SampleFromPrior(), model + vi, [initial_z, initial_x], model ) @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, (X=initial_x, Z=initial_z), DynamicPPL.SampleFromPrior(), model + vi, (X=initial_x, Z=initial_z), model ) end end From b919fe4851a469869f6d5e0872715060d78d5281 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 24 Jan 2025 15:04:27 +0000 Subject: [PATCH 31/40] Remove some index setting with samplers --- src/varinfo.jl | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index bbcb638bb..54d846f07 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1096,12 +1096,6 @@ Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[] return expr end -# FIXME(torfjelde): Don't use `_getvns`. -Base.keys(vi::UntypedVarInfo, spl::AbstractSampler) = _getvns(vi, spl) -function Base.keys(vi::TypedVarInfo, spl::AbstractSampler) - return mapreduce(values, vcat, _getvns(vi, spl)) -end - """ setgid!(vi::VarInfo, gid::Selector, vn::VarName) @@ -1825,35 +1819,7 @@ Set the current value(s) of the random variables sampled by `spl` in `vi` to `va The value(s) may or may not be transformed to Euclidean space. """ -setindex!(vi::VarInfo, val, spl::SampleFromPrior) = setall!(vi, val) setindex!(vi::UntypedVarInfo, val, spl::Sampler) = setval!(vi, val, _getranges(vi, spl)) -function setindex!(vi::TypedVarInfo, val, spl::Sampler) - # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` - ranges = _getranges(vi, spl) - _setindex!(vi.metadata, val, ranges) - return nothing -end - -function BangBang.setindex!!(vi::VarInfo, val, spl::AbstractSampler) - setindex!(vi, val, spl) - return vi -end - -# Recursively writes the entries of `val` to the `vals` fields of all the symbols as if they were a contiguous vector. -@generated function _setindex!(metadata, val, ranges::NamedTuple{names}) where {names} - expr = Expr(:block) - offset = :(0) - for f in names - f_vals = :(metadata.$f.vals) - f_range = :(ranges.$f) - start = :($offset + 1) - len = :(length($f_range)) - finish = :($offset + $len) - push!(expr.args, :(@views $f_vals[$f_range] .= val[($start):($finish)])) - offset = :($offset + $len) - end - return expr -end @inline function findvns(vi, f_vns) if length(f_vns) == 0 From fc0e0642d6fc5c90d4fb2a4262c91ba0f6cb9717 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 24 Jan 2025 15:09:01 +0000 Subject: [PATCH 32/40] Remove more sampler indexing --- src/simple_varinfo.jl | 6 +----- src/threadsafe.jl | 13 +------------ src/varinfo.jl | 23 +++++------------------ 3 files changed, 7 insertions(+), 35 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b45d0dcc8..db0609592 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -342,10 +342,6 @@ function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) return Accessors.@set vi.values = set!!(vi.values, vn, val) end -function BangBang.setindex!!(vi::SimpleVarInfo, val, spl::AbstractSampler) - return unflatten(vi, spl, val) -end - # TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with # same symbol and same type of, say, `IndexLens`, for improved `.~` performance. function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) @@ -558,7 +554,7 @@ istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) -islinked(vi::SimpleVarInfo, ::Union{Sampler,SampleFromPrior}) = istrans(vi) +islinked(vi::SimpleVarInfo) = istrans(vi) values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 69be5dcb1..0ca5d2e56 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -79,7 +79,7 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) -islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) +islinked(vi::ThreadSafeVarInfo) = islinked(vi.varinfo) function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...) @@ -138,17 +138,6 @@ end function getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::Distribution) return getindex(vi.varinfo, vns, dist) end -getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) - -function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) -end -function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromPrior) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) -end -function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromUniform) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) -end function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName) return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn) diff --git a/src/varinfo.jl b/src/varinfo.jl index 54d846f07..710d02c5d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1674,31 +1674,18 @@ end # TODO(mhauru) We have varying conventions below for what to do if some variables are linked # and others are not. """ - islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) + islinked(vi::VarInfo) -Check whether `vi` is in the transformed space for a particular sampler `spl`. +Check whether `vi` is in the transformed space. Turing's Hamiltonian samplers use the `link` and `invlink` functions from [Bijectors.jl](https://github.com/TuringLang/Bijectors.jl) to map a constrained variable (for example, one bounded to the space `[0, 1]`) from its constrained space to the set of real numbers. `islinked` checks if the number is in the constrained space or the real space. -""" -function islinked(vi::UntypedVarInfo, spl::Union{Sampler,SampleFromPrior}) - vns = _getvns(vi, spl) - return istrans(vi, vns[1]) -end -function islinked(vi::TypedVarInfo, spl::Union{Sampler,SampleFromPrior}) - vns = _getvns(vi, spl) - return _islinked(vi, vns) -end -@generated function _islinked(vi, vns::NamedTuple{names}) where {names} - out = [] - for f in names - push!(out, :(isempty(vns.$f) ? false : istrans(vi, vns.$f[1]))) - end - return Expr(:||, false, out...) -end +If some but only some of the variables in `vi` are linked, this function will return `true`. +This behavior will likely change in the future. +""" function islinked(vi::VarInfo) return any(istrans(vi, vn) for vn in keys(vi)) end From 5fbe0160a739bafff35f2846717afea4d2e87feb Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Jan 2025 12:37:25 +0000 Subject: [PATCH 33/40] Remove unflatten with samplers --- src/abstract_varinfo.jl | 18 ++---------------- src/logdensityfunction.jl | 2 +- src/simple_varinfo.jl | 1 - src/threadsafe.jl | 7 ------- src/utils.jl | 6 +++--- src/varinfo.jl | 38 +++++++++++++------------------------- src/varnamedvector.jl | 15 --------------- 7 files changed, 19 insertions(+), 68 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 9f744984b..019cf885c 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -719,25 +719,11 @@ end # Utilities """ - unflatten(vi::AbstractVarInfo[, context::AbstractContext], x::AbstractVector) + unflatten(vi::AbstractVarInfo, x::AbstractVector) Return a new instance of `vi` with the values of `x` assigned to the variables. - -If `context` is provided, `x` is assumed to be realizations only for variables not -filtered out by `context`. """ -function unflatten(varinfo::AbstractVarInfo, context::AbstractContext, θ) - if hassampler(context) - unflatten(getsampler(context), varinfo, context, θ) - else - DynamicPPL.unflatten(varinfo, θ) - end -end - -# TODO: deprecate this once `sampler` is no longer the main way of filtering out variables. -function unflatten(sampler::AbstractSampler, varinfo::AbstractVarInfo, ::AbstractContext, θ) - return unflatten(varinfo, sampler, θ) -end +function unflatten end """ to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 214369ab0..5c1f28f80 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -136,7 +136,7 @@ getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(getcontext(f))] # LogDensityProblems interface function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) context = getcontext(f) - vi_new = unflatten(f.varinfo, context, θ) + vi_new = unflatten(f.varinfo, θ) return getlogp(last(evaluate!!(f.model, vi_new, context))) end function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index db0609592..07296c3f7 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -258,7 +258,6 @@ function typed_simple_varinfo(model::Model) return last(evaluate!!(model, varinfo, SamplingContext())) end -unflatten(svi::SimpleVarInfo, spl::AbstractSampler, x::AbstractVector) = unflatten(svi, x) function unflatten(svi::SimpleVarInfo, x::AbstractVector) logp = getlogp(svi) vals = unflatten(svi.values, x) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 0ca5d2e56..4367ff06d 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -173,13 +173,9 @@ function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) return is_flagged(vi.varinfo, vn, flag) end -# Transformations. function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName) return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) end -function settrans!!(vi::ThreadSafeVarInfo, spl::AbstractSampler, dist::Distribution) - return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, spl, dist) -end istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns) @@ -189,9 +185,6 @@ getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.var function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector) return Accessors.@set vi.varinfo = unflatten(vi.varinfo, x) end -function unflatten(vi::ThreadSafeVarInfo, spl::AbstractSampler, x::AbstractVector) - return Accessors.@set vi.varinfo = unflatten(vi.varinfo, spl, x) -end function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) return Accessors.@set varinfo.varinfo = subset(varinfo.varinfo, vns) diff --git a/src/utils.jl b/src/utils.jl index 2539b7179..c4ef0ab09 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -710,6 +710,9 @@ function unflatten(original, x::AbstractVector) return unflatten(v, @view(x[start_idx:end_idx])) end end +function unflatten(original::NamedTuple{names}, x::AbstractVector) where {names} + return NamedTuple{names}(unflatten(values(original), x)) +end unflatten(::Real, x::Real) = x unflatten(::Real, x::AbstractVector) = only(x) @@ -728,9 +731,6 @@ function unflatten(original::Tuple, x::AbstractVector) return unflatten(v, @view(x[start_idx:end_idx])) end end -function unflatten(original::NamedTuple{names}, x::AbstractVector) where {names} - return NamedTuple{names}(unflatten(values(original), x)) -end function unflatten(original::AbstractDict, x::AbstractVector) D = ConstructionBase.constructorof(typeof(original)) return D(zip(keys(original), unflatten(collect(values(original)), x))) diff --git a/src/varinfo.jl b/src/varinfo.jl index 710d02c5d..d960c21b6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -217,45 +217,33 @@ vector_length(varinfo::VarInfo) = length(varinfo.metadata) vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata) vector_length(md::Metadata) = sum(length, md.ranges) -unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x) - -# TODO: deprecate. -function unflatten(vi::VarInfo, spl::AbstractSampler, x::AbstractVector) - md = unflatten(vi.metadata, spl, x) - return VarInfo(md, Base.RefValue{eltype(x)}(getlogp(vi)), Ref(get_num_produce(vi))) -end - -# The Val(getspace(spl)) is used to dispatch into the below generated function. -function unflatten(metadata::NamedTuple, spl::AbstractSampler, x::AbstractVector) - return unflatten(metadata, Val(getspace(spl)), x) +function unflatten(vi::VarInfo, x::AbstractVector) + md = unflatten_metadata(vi.metadata, x) + return VarInfo(md, Ref(getlogp(vi)), Ref(get_num_produce(vi))) end -@generated function unflatten( - metadata::NamedTuple{names}, ::Val{space}, x -) where {names,space} +# This must not be called `unflatten` because of the `unflatten` methods in utils.jl. +@generated function unflatten_metadata( + metadata::NamedTuple{names}, x::AbstractVector +) where {names} exprs = [] offset = :(0) for f in names mdf = :(metadata.$f) - if inspace(f, space) || length(space) == 0 - len = :(sum(length, $mdf.ranges)) - push!(exprs, :($f = unflatten($mdf, x[($offset + 1):($offset + $len)]))) - offset = :($offset + $len) - else - push!(exprs, :($f = $mdf)) - end + len = :(sum(length, $mdf.ranges)) + push!(exprs, :($f = unflatten_metadata($mdf, x[($offset + 1):($offset + $len)]))) + offset = :($offset + $len) end length(exprs) == 0 && return :(NamedTuple()) return :($(exprs...),) end # For Metadata unflatten and replace_values are the same. For VarNamedVector they are not. -function unflatten(md::Metadata, x::AbstractVector) +function unflatten_metadata(md::Metadata, x::AbstractVector) return replace_values(md, x) end -function unflatten(md::Metadata, spl::AbstractSampler, x::AbstractVector) - return replace_values(md, spl, x) -end + +unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) # without AbstractSampler function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 7da126321..3b3f0ce42 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -510,12 +510,6 @@ function getindex_internal(vnv::VarNamedVector, ::Colon) end end -# TODO(mhauru): Remove this as soon as possible. Only needed because of the old Gibbs -# sampler. -function Base.getindex(vnv::VarNamedVector, spl::AbstractSampler) - throw(ErrorException("Cannot index a VarNamedVector with a sampler.")) -end - function Base.setindex!(vnv::VarNamedVector, val, vn::VarName) if haskey(vnv, vn) return update!(vnv, val, vn) @@ -1077,15 +1071,6 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector) ) end -# TODO(mhauru) To be removed once the old Gibbs sampler is removed. -function unflatten(vnv::VarNamedVector, spl::AbstractSampler, vals::AbstractVector) - if length(getspace(spl)) > 0 - msg = "Selecting values in a VarNamedVector with a space is not supported." - throw(ArgumentError(msg)) - end - return unflatten(vnv, vals) -end - function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) # Return early if possible. isempty(left_vnv) && return deepcopy(right_vnv) From cb5c79d408b84e644476dadd16897251d7feff5b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Jan 2025 14:09:15 +0000 Subject: [PATCH 34/40] Clean up some setindex stuff --- src/abstract_varinfo.jl | 5 ++--- src/varinfo.jl | 9 --------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 019cf885c..6ad27acb5 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -149,7 +149,6 @@ If `dist` is specified, the value(s) will be massaged into the representation ex """ getindex(vi::AbstractVarInfo, ::Colon) - getindex(vi::AbstractVarInfo, ::AbstractSampler) Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) distribution(s) as a flattened `Vector`. @@ -340,9 +339,9 @@ julia> values_as(vi, Vector) function values_as end """ - eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior} + eltype(vi::AbstractVarInfo) -Determine the default `eltype` of the values returned by `vi[spl]`. +Return the `eltype` of the values returned by `vi[:]`. !!! warning This should generally not be called explicitly, as it's only used in diff --git a/src/varinfo.jl b/src/varinfo.jl index d960c21b6..f3b629bd0 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1787,15 +1787,6 @@ function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) return vi end -""" - setindex!(vi::VarInfo, val, spl::Union{SampleFromPrior, Sampler}) - -Set the current value(s) of the random variables sampled by `spl` in `vi` to `val`. - -The value(s) may or may not be transformed to Euclidean space. -""" -setindex!(vi::UntypedVarInfo, val, spl::Sampler) = setval!(vi, val, _getranges(vi, spl)) - @inline function findvns(vi, f_vns) if length(f_vns) == 0 throw("Unidentified error, please report this error in an issue.") From 414f58efcc35c7709a4cd751f22515665f3c8f6b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Jan 2025 14:46:19 +0000 Subject: [PATCH 35/40] Remove a bunch of varinfo.jl internal functions that used samplers/space, update HISTORY.md --- HISTORY.md | 6 ++ src/varinfo.jl | 170 ++++--------------------------------------------- 2 files changed, 17 insertions(+), 159 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 03c564b64..bed1a50a0 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -10,6 +10,12 @@ This release removes the feature of `VarInfo` where it kept track of which varia - `link` and `invlink`, and their `!!` versions, no longer accept a sampler as an argument to specify which variables to (inv)link. The `link(varinfo, model)` methods remain in place, and as a new addition one can give a `Tuple` of `VarName`s to (inv)link only select variables, as in `link(varinfo, varname_tuple, model)`. - `set_retained_vns_del_by_spl!` has been replaced by `set_retained_vns_del!` which applies to all variables. + - `getindex`, `setindex!`, and `setindex!!` no longer accept samplers as arguments + - `unflatten` no longer accepts a sampler as an argument + - `eltype(::VarInfo)` no longer accepts a sampler as an argument + - `keys(::VarInfo)` no longer accepts a sampler as an argument + - `push!!` and `push!` no longer accept samplers or `Selector`s as arguments + - `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument. ### Reverse prefixing order diff --git a/src/varinfo.jl b/src/varinfo.jl index f3b629bd0..0cba85700 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -111,10 +111,11 @@ const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ # NOTE: This is kind of weird, but it effectively preserves the "old" # behavior where we're allowed to call `link!` on the same `VarInfo` # multiple times. -transformation(vi::VarInfo) = DynamicTransformation() +transformation(::VarInfo) = DynamicTransformation() -function VarInfo(old_vi::VarInfo, spl, x::AbstractVector) - md = replace_values(old_vi.metadata, Val(getspace(spl)), x) +# TODO(mhauru) Isn't this the same as unflatten and/or replace_values? +function VarInfo(old_vi::VarInfo, x::AbstractVector) + md = replace_values(old_vi.metadata, x) return VarInfo( md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi)) ) @@ -250,8 +251,6 @@ function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext return VarInfo(rng, model, SampleFromPrior(), context) end -# TODO: Remove `space` argument when no longer needed. Ref: https://github.com/TuringLang/DynamicPPL.jl/issues/573 -replace_values(metadata::Metadata, space, x) = replace_values(metadata, x) function replace_values(metadata::Metadata, x) return Metadata( metadata.idcs, @@ -265,20 +264,14 @@ function replace_values(metadata::Metadata, x) ) end -@generated function replace_values( - metadata::NamedTuple{names}, ::Val{space}, x -) where {names,space} +@generated function replace_values(metadata::NamedTuple{names}, x) where {names} exprs = [] offset = :(0) for f in names mdf = :(metadata.$f) - if inspace(f, space) || length(space) == 0 - len = :(sum(length, $mdf.ranges)) - push!(exprs, :($f = replace_values($mdf, x[($offset + 1):($offset + $len)]))) - offset = :($offset + $len) - else - push!(exprs, :($f = $mdf)) - end + len = :(sum(length, $mdf.ranges)) + push!(exprs, :($f = replace_values($mdf, x[($offset + 1):($offset + $len)]))) + offset = :($offset + $len) end length(exprs) == 0 && return :(NamedTuple()) return :($(exprs...),) @@ -774,7 +767,7 @@ settrans!!(vi::VarInfo, trans::AbstractTransformation) = settrans!!(vi, true) """ syms(vi::VarInfo) -Returns a tuple of the unique symbols of random variables sampled in `vi`. +Returns a tuple of the unique symbols of random variables in `vi`. """ syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols syms(vi::TypedVarInfo) = keys(vi.metadata) @@ -782,16 +775,6 @@ syms(vi::TypedVarInfo) = keys(vi.metadata) _getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs) _getidcs(vi::TypedVarInfo) = _getidcs(vi.metadata) -# Get all indices of variables belonging to SampleFromPrior: -# if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to -# the SampleFromPrior sampler -@inline function _getidcs(vi::UntypedVarInfo, ::SampleFromPrior) - return filter(i -> isempty(vi.metadata.gids[i]), 1:length(vi.metadata.gids)) -end -# Get a NamedTuple of all the indices belonging to SampleFromPrior, one for each symbol -@inline function _getidcs(vi::TypedVarInfo, ::SampleFromPrior) - return _getidcs(vi.metadata) -end @generated function _getidcs(metadata::NamedTuple{names}) where {names} exprs = [] for f in names @@ -801,93 +784,15 @@ end return :($(exprs...),) end -# Get all indices of variables belonging to a given sampler -@inline function _getidcs(vi::VarInfo, spl::Sampler) - # NOTE: 0b00 is the sanity flag for - # |\____ getidcs (mask = 0b10) - # \_____ getranges (mask = 0b01) - #if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end - # Checks if cache is valid, i.e. no new pushes were made, to return the cached idcs - # Otherwise, it recomputes the idcs and caches it - #if haskey(spl.info, :idcs) && (spl.info[:cache_updated] & CACHEIDCS) > 0 - # spl.info[:idcs] - #else - #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHEIDCS - idcs = _getidcs(vi, spl.selector, Val(getspace(spl))) - #spl.info[:idcs] = idcs - #end - return idcs -end -@inline _getidcs(vi::UntypedVarInfo, s::Selector, space) = findinds(vi.metadata, s, space) -@inline _getidcs(vi::TypedVarInfo, s::Selector, space) = _getidcs(vi.metadata, s, space) -# Get a NamedTuple for all the indices belonging to a given selector for each symbol -@generated function _getidcs( - metadata::NamedTuple{names}, s::Selector, ::Val{space} -) where {names,space} - exprs = [] - # Iterate through each varname in metadata. - for f in names - # If the varname is in the sampler space - # or the sample space is empty (all variables) - # then return the indices for that variable. - if inspace(f, space) || length(space) == 0 - push!(exprs, :($f = findinds(metadata.$f, s, Val($space)))) - end - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end -@inline function findinds(f_meta::Metadata, s, ::Val{space}) where {space} - # Get all the idcs of the vns in `space` and that belong to the selector `s` - return filter( - (i) -> - (s in f_meta.gids[i] || isempty(f_meta.gids[i])) && - (isempty(space) || inspace(f_meta.vns[i], space)), - 1:length(f_meta.gids), - ) -end @inline function findinds(f_meta::Metadata) # Get all the idcs of the vns return filter((i) -> isempty(f_meta.gids[i]), 1:length(f_meta.gids)) end -function findinds(vnv::VarNamedVector, ::Selector, ::Val{space}) where {space} - # New Metadata objects are created with an empty list of gids, which is intrepreted as - # all Selectors applying to all variables. We assume the same behavior for - # VarNamedVector, and thus ignore the Selector argument. - if space !== () - msg = "VarNamedVector does not support selecting variables based on samplers" - throw(ErrorException(msg)) - else - return findinds(vnv) - end -end - function findinds(vnv::VarNamedVector) return 1:length(vnv.varnames) end -# Get all vns of variables belonging to spl -_getvns(vi::VarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) -function _getvns(vi::VarInfo, spl::Union{SampleFromPrior,SampleFromUniform}) - return _getvns(vi, Selector(), Val(())) -end -function _getvns(vi::UntypedVarInfo, s::Selector, space) - return view(vi.metadata.vns, _getidcs(vi, s, space)) -end -function _getvns(vi::TypedVarInfo, s::Selector, space) - return _getvns(vi.metadata, _getidcs(vi, s, space)) -end -# Get a NamedTuple for all the `vns` of indices `idcs`, one entry for each symbol -@generated function _getvns(metadata, idcs::NamedTuple{names}) where {names} - exprs = [] - for f in names - push!(exprs, :($f = Base.keys(metadata.$f)[idcs.$f])) - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end - """ all_varnames_grouped_by_symbol(vi::TypedVarInfo) @@ -904,47 +809,6 @@ all_varnames_grouped_by_symbol(vi::TypedVarInfo) = return expr end -# Get the index (in vals) ranges of all the vns of variables belonging to spl -@inline function _getranges(vi::VarInfo, spl::Sampler) - ## Uncomment the spl.info stuff when it is concretely typed, not Dict{Symbol, Any} - #if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end - #if haskey(spl.info, :ranges) && (spl.info[:cache_updated] & CACHERANGES) > 0 - # spl.info[:ranges] - #else - #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHERANGES - ranges = _getranges(vi, spl.selector, Val(getspace(spl))) - #spl.info[:ranges] = ranges - return ranges - #end -end -# Get the index (in vals) ranges of all the vns of variables belonging to selector `s` in `space` -@inline function _getranges(vi::VarInfo, s::Selector, space) - return _getranges(vi, _getidcs(vi, s, space)) -end -@inline function _getranges(vi::VarInfo, idcs::Vector{Int}) - return mapreduce(i -> vi.metadata.ranges[i], vcat, idcs; init=Int[]) -end -@inline _getranges(vi::TypedVarInfo, idcs::NamedTuple) = _getranges(vi.metadata, idcs) - -@generated function _getranges(metadata::NamedTuple, idcs::NamedTuple{names}) where {names} - exprs = [] - for f in names - push!(exprs, :($f = findranges(metadata.$f.ranges, idcs.$f))) - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end - -@inline function findranges(f_ranges, f_idcs) - # Old implementation was using `mapreduce` but turned out - # to be type-unstable. - results = Int[] - for i in f_idcs - append!(results, f_ranges[i]) - end - return results -end - # TODO(mhauru) These set_flag! methods return the VarInfo. They should probably be called # set_flag!!. """ @@ -1173,7 +1037,6 @@ function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, mode return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) end -# X -> R for all variables associated with given sampler function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) @@ -1279,7 +1142,6 @@ function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, m return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) end -# R -> X for all variables associated with given sampler function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) @@ -1376,16 +1238,6 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) return vi end -# HACK: We need `SampleFromPrior` to result in ALL values which are in need -# of a transformation to be transformed. `_getvns` will by default return -# an empty iterable for `SampleFromPrior`, so we need to override it here. -# This is quite hacky, but seems safer than changing the behavior of `_getvns`. -_getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) -_getvns_link(varinfo::VarInfo, spl::SampleFromPrior) = nothing -function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) - return map(Returns(nothing), varinfo.metadata) -end - function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) return _link(model, vi, all_varnames_grouped_by_symbol(vi)) end @@ -1599,7 +1451,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn # Return early if we're already in constrained space OR if we're not - # supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler. + # supposed to touch this `vn`. # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] @@ -1799,7 +1651,7 @@ Base.haskey(metadata::Metadata, vn::VarName) = haskey(metadata.idcs, vn) """ haskey(vi::VarInfo, vn::VarName) -Check whether `vn` has been sampled in `vi`. +Check whether `vn` has a value in `vi`. """ Base.haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn) function Base.haskey(vi::TypedVarInfo, vn::VarName) From 800b91aaa7c9d32f717ab98426372d82df89f1aa Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Jan 2025 14:50:32 +0000 Subject: [PATCH 36/40] Fix HISTORY.md --- HISTORY.md | 1 - 1 file changed, 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index bed1a50a0..6b7247c8d 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -14,7 +14,6 @@ This release removes the feature of `VarInfo` where it kept track of which varia - `unflatten` no longer accepts a sampler as an argument - `eltype(::VarInfo)` no longer accepts a sampler as an argument - `keys(::VarInfo)` no longer accepts a sampler as an argument - - `push!!` and `push!` no longer accept samplers or `Selector`s as arguments - `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument. ### Reverse prefixing order From e65777e9bc2c0c32e5ce962e9a6d54ece87ed071 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Jan 2025 15:18:42 +0000 Subject: [PATCH 37/40] Miscalleanous small fixes --- src/abstract_varinfo.jl | 2 +- src/compiler.jl | 14 ++++++++------ src/utils.jl | 10 +++++----- src/varinfo.jl | 10 +++++++--- 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 6ad27acb5..4e9e5c554 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -353,7 +353,7 @@ Return the `eltype` of the values returned by `vi[:]`. function Base.eltype(vi::AbstractVarInfo) T = Base.promote_op(getindex, typeof(vi), Colon) if T === Union{} - # In this case `getindex(vi, spl)` errors + # In this case `getindex(vi, :)` errors # Let us throw a more descriptive error message # Ref https://github.com/TuringLang/Turing.jl/issues/2151 return eltype(vi[:]) diff --git a/src/compiler.jl b/src/compiler.jl index 551f49266..ddf1b6787 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -730,6 +730,8 @@ function warn_empty(body) return nothing end +# TODO(mhauru) matchinvalue has methods that can accept both types and values. Why? +# TODO(mhauru) This function needs a more comprehensive docstring. """ matchingvalue(vi, value) @@ -739,6 +741,8 @@ function matchingvalue(vi, value) T = typeof(value) if hasmissing(T) _value = convert(get_matching_type(vi, T), value) + # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we + # are happy to return `value` as-is? if _value === value return deepcopy(_value) else @@ -748,7 +752,7 @@ function matchingvalue(vi, value) return value end end -# If we hit `Type` or `TypeWrap`, we immediately jump to `get_matching_type`. + function matchingvalue(vi, value::FloatOrArrayType) return get_matching_type(vi, value) end @@ -756,13 +760,11 @@ function matchingvalue(vi, ::TypeWrap{T}) where {T} return TypeWrap{get_matching_type(vi, T)}() end +# TODO(mhauru) This function needs a more comprehensive docstring. What is it for? """ - get_matching_type(spl::AbstractSampler, vi, ::TypeWrap{T}) where {T} - -Get the specialized version of type `T` for sampler `spl`. + get_matching_type(vi, ::TypeWrap{T}) where {T} -For example, if `T === Float64` and `spl::Hamiltonian`, the matching type is -`eltype(vi[spl])`. +Get the specialized version of type `T` for `vi`. """ get_matching_type(_, ::Type{T}) where {T} = T function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) diff --git a/src/utils.jl b/src/utils.jl index c4ef0ab09..d64f6dc66 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -710,9 +710,6 @@ function unflatten(original, x::AbstractVector) return unflatten(v, @view(x[start_idx:end_idx])) end end -function unflatten(original::NamedTuple{names}, x::AbstractVector) where {names} - return NamedTuple{names}(unflatten(values(original), x)) -end unflatten(::Real, x::Real) = x unflatten(::Real, x::AbstractVector) = only(x) @@ -731,6 +728,9 @@ function unflatten(original::Tuple, x::AbstractVector) return unflatten(v, @view(x[start_idx:end_idx])) end end +function unflatten(original::NamedTuple{names}, x::AbstractVector) where {names} + return NamedTuple{names}(unflatten(values(original), x)) +end function unflatten(original::AbstractDict, x::AbstractVector) D = ConstructionBase.constructorof(typeof(original)) return D(zip(keys(original), unflatten(collect(values(original)), x))) @@ -942,9 +942,9 @@ function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) end """ - float_type_with_fallback(x) + float_type_with_fallback(T::DataType) -Return type corresponding to `float(typeof(x))` if possible; otherwise return `float(Real)`. +Return `float(T)` if possible; otherwise return `float(Real)`. """ float_type_with_fallback(::Type) = float(Real) float_type_with_fallback(::Type{Union{}}) = float(Real) diff --git a/src/varinfo.jl b/src/varinfo.jl index 0cba85700..01b9f9753 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -223,7 +223,8 @@ function unflatten(vi::VarInfo, x::AbstractVector) return VarInfo(md, Ref(getlogp(vi)), Ref(get_num_produce(vi))) end -# This must not be called `unflatten` because of the `unflatten` methods in utils.jl. +# We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in +# utils.jl. @generated function unflatten_metadata( metadata::NamedTuple{names}, x::AbstractVector ) where {names} @@ -1511,8 +1512,11 @@ function _invlink_metadata!!( return metadata end -# TODO(mhauru) We have varying conventions below for what to do if some variables are linked -# and others are not. +# TODO(mhauru) The treatment of the case when some variables are linked and others are not +# should be revised. It used to be the case that for UntypedVarInfo `islinked` returned +# whether the first variable was linked. For TypedVarInfo we did an OR over the first +# variables under each symbol. We now more consistently use OR, but I'm not convinced this +# is really the right thing to do. """ islinked(vi::VarInfo) From 8fcc289af4ee9c6c309e0aeb3025b626216fd66d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Jan 2025 15:45:18 +0000 Subject: [PATCH 38/40] Fix a bug in VarInfo constructor --- src/varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 01b9f9753..8f7f7b6c1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -220,7 +220,9 @@ vector_length(md::Metadata) = sum(length, md.ranges) function unflatten(vi::VarInfo, x::AbstractVector) md = unflatten_metadata(vi.metadata, x) - return VarInfo(md, Ref(getlogp(vi)), Ref(get_num_produce(vi))) + # Note that use of RefValue{eltype(x)} rather than Ref is necessary to deal with cases + # where e.g. x is a type gradient of some AD backend. + return VarInfo(md, Base.RefValue{eltype(x)}(getlogp(vi)), Ref(get_num_produce(vi))) end # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in From c59aafe69d2aba01776df9ad9b9fc850db6c627f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Jan 2025 15:49:59 +0000 Subject: [PATCH 39/40] Fix getparams(::LogDensityFunction) --- src/logdensityfunction.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 5c1f28f80..29f591cc3 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -121,17 +121,12 @@ end getsampler(f::LogDensityFunction) = getsampler(getcontext(f)) hassampler(f::LogDensityFunction) = hassampler(getcontext(f)) -_get_indexer(ctx::AbstractContext) = _get_indexer(NodeTrait(ctx), ctx) -_get_indexer(ctx::SamplingContext) = ctx.sampler -_get_indexer(::IsParent, ctx::AbstractContext) = _get_indexer(childcontext(ctx)) -_get_indexer(::IsLeaf, ctx::AbstractContext) = Colon() - """ getparams(f::LogDensityFunction) Return the parameters of the wrapped varinfo as a vector. """ -getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(getcontext(f))] +getparams(f::LogDensityFunction) = f.varinfo[:] # LogDensityProblems interface function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) From 934fb7998d3accbc18cf6882408415c3c32d2ec7 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 3 Feb 2025 15:28:02 +0000 Subject: [PATCH 40/40] Apply suggestions from code review Co-authored-by: Penelope Yong --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index ddf1b6787..8743641af 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -730,7 +730,7 @@ function warn_empty(body) return nothing end -# TODO(mhauru) matchinvalue has methods that can accept both types and values. Why? +# TODO(mhauru) matchingvalue has methods that can accept both types and values. Why? # TODO(mhauru) This function needs a more comprehensive docstring. """ matchingvalue(vi, value)