From eb2836dd08289600831c5b3062e716122cbdfaa2 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 14 Oct 2025 14:27:51 +0100 Subject: [PATCH 01/18] Sketching VarNamedTuple and its VarInfo --- src/DynamicPPL.jl | 2 ++ src/varinfo.jl | 54 +++++++++++++++++++++++++++++++++++++++++++++++ src/varname.jl | 21 ++++++++++++++++++ 3 files changed, 77 insertions(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f5bd33d6d..4e2f43445 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -178,6 +178,8 @@ include("contexts/prefix.jl") include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl include("model.jl") include("varname.jl") +include("varnamedtuple.jl") +using .VarNamedTuples: VarNamedTuple include("distribution_wrappers.jl") include("submodel.jl") include("varnamedvector.jl") diff --git a/src/varinfo.jl b/src/varinfo.jl index 734bf3db5..0fb77aa7a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -154,6 +154,7 @@ const NTVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } +const TupleVarInfo = VarInfo{<:VarNamedTuple} function Base.:(==)(vi1::VarInfo, vi2::VarInfo) return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs) @@ -356,6 +357,28 @@ function typed_vector_varinfo( return typed_vector_varinfo(Random.default_rng(), model, init_strategy) end +function make_leaf_metadata((r, dist), optic) + md = Metadata() + vn = VarName{:_}(optic) + push!(md, vn, r, dist) + return md +end + +function tuple_varinfo() + metadata = VarNamedTuple((;), make_leaf_metadata) + return VarInfo(metadata, copy(default_accumulators())) +end +function tuple_varinfo( + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), +) + return last(init!!(rng, model, tuple_varinfo(), init_strategy)) +end +function tuple_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return tuple_varinfo(Random.default_rng(), model, 
init_strategy) +end + """ vector_length(varinfo::VarInfo) @@ -639,6 +662,9 @@ Return the metadata in `vi` that belongs to `vn`. """ getmetadata(vi::VarInfo, vn::VarName) = vi.metadata getmetadata(vi::NTVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) +function getmetadata(vi::TupleVarInfo, vn::VarName) + return getindex(vi.metadata, remove_trailing_index(vn)) +end """ getidx(vi::VarInfo, vn::VarName) @@ -744,6 +770,10 @@ end Return the distribution from which `vn` was sampled in `vi`. """ getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn) +function getdist(vi::TupleVarInfo, vn::VarName) + main_vn, optic = split_trailing_index(vn) + return getdist(getindex(vi.metadata, main_vn), VarName{:_}(optic)) +end getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)] # TODO(mhauru) Remove this once the old Gibbs sampler stuff is gone. function getdist(::VarNamedVector, ::VarName) @@ -782,6 +812,10 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`. The values may or may not be transformed to Euclidean space. """ setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn) +function setval!(vi::TupleVarInfo, val, vn::VarName) + main_vn, optic = split_trailing_index(vn) + return setval!(getindex(vi.metadata, main_vn), VarName{:_}(optic)) +end function setval!(md::Metadata, val::AbstractVector, vn::VarName) return md.vals[getrange(md, vn)] = val end @@ -1579,6 +1613,7 @@ function Base.haskey(vi::NTVarInfo, vn::VarName) end return any(md_haskey) end +Base.haskey(vi::TupleVarInfo, vn::VarName) = haskey(vi.metadata, vn) function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) lines = Tuple{String,Any}[ @@ -1673,6 +1708,25 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) return vi end +function BangBang.push!!(vi::TupleVarInfo, vn::VarName, r, dist::Distribution) + @assert ~(haskey(vi, vn)) "[push!!] 
attempt to add an existing variable $(getsym(vn)) ($(vn)) to TupleVarInfo with dist=$dist" + return VarInfo(setindex!!(vi.metadata, (r, dist), vn), vi.accs) +end + +# TODO(mhauru) Implement properly +function is_transformed(vi::TupleVarInfo, vn::VarName) + return false +end + +function getindex(vi::TupleVarInfo, vn::VarName) + main_vn, optic = split_trailing_index(vn) + return getindex(getindex(vi.metadata, main_vn), VarName{:_}(optic)) +end +function getindex_internal(vi::TupleVarInfo, vn::VarName) + main_vn, optic = split_trailing_index(vn) + return getindex_internal(getindex(vi.metadata, main_vn), VarName{:_}(optic)) +end + function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...) push!(getmetadata(vi, vn), vn, val, args...) return vi diff --git a/src/varname.jl b/src/varname.jl index 3eb1f2460..687427f6e 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -41,3 +41,24 @@ Possibly existing indices of `varname` are neglected. ) where {s,missings,_F,_a,_T} return s in missings end + +function remove_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic} + return if Optic === typeof(identity) + vn + elseif Optic isa IndexLens + VarName{sym}() + else + prefix(remove_trailing_index(unprefix(vn, sym)), sym) + end +end + +function split_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic} + return if Optic === typeof(identity) + (vn, identity) + elseif Optic isa IndexLens + (VarName{sym}(), Optic.index) + else + (prefix, index) = split_trailing_index(unprefix(vn, sym)) + (prefix(prefix, sym), index) + end +end From fd458e6ba2c06c4e0fb12eddb8f4d285cd2e6f48 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 21 Oct 2025 16:33:18 +0100 Subject: [PATCH 02/18] Start a sketch doc for VarNamedTuple design --- docs/src/internals/varnamedtuple.md | 106 ++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 docs/src/internals/varnamedtuple.md diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md 
new file mode 100644 index 000000000..202028e55 --- /dev/null +++ b/docs/src/internals/varnamedtuple.md @@ -0,0 +1,106 @@ +# VarNamedTuple as the basis of VarInfo + +This document collects thoughts and ideas for how to unify our multitude of AbstractVarInfo types using a VarNamedTuple type. It may eventually turn into a draft design document, but for now it is more raw than that. + +## The current situation + +We currently have the following AbstractVarInfo types: + + - A: VarInfo with Metadata + - B: VarInfo with VarNamedTuple + - C: VarInfo with NamedTuple, with values being Metadata + - D: VarInfo with NamedTuple, with values being VarNamedTuple + - E: SimpleVarInfo with NamedTuples + - F: SimpleVarInfo with OrderedDict + +A and C are the classic ones, and the defaults. C groups the Metadata objects by the lead Symbol of the VarName of a variable, e.g. `x` in `@varname(x.y[1].z)`, which allows different lead Symbols to have different element types and for the VarInfo to still be type stable. B and D were created to simplify A and C, give them a nicer interface, and make them deal better with changing variable sizes, but according to recent (Oct 2025) benchmarks they are quite a lot slower, which needs work. + +E and F are entirely distinct in implementation from the others. E is simply a mapping from Symbols to values, with each VarName being converted to a single symbol, e.g. `Symbol("a[1]")`. F is a mapping from VarNames to values as an OrderedDict, with VarName as the key type. + +A-D carry within them values for variables, but also their bijectors/distributions, and store all values vectorised, using the bijectors to map to the original values. They also store for each variable a flag for whether the variable has been linked. E-F store only the raw values, and a global flag for the whole SimpleVarInfo for whether it's linked. The link transform itself is implicit. + +TODO: Write a better summary of pros and cons of each approach.
+ +## VarNamedTuple + +VarNamedTuple has been discussed as a possible data structure to generalise the structure used in VarInfo to achieve type stability, i.e. grouping VarNames by their lead Symbol. The same NamedTuple structure has been used elsewhere, too, e.g. in Turing.GibbsContext. The idea was to encapsulate this structure into its own type, reducing code duplication and making the design more robust and powerful. See https://github.com/TuringLang/DynamicPPL.jl/issues/900 for the discussion. + +An AbstractVarInfo type could be only one application of VarNamedTuple, but here I'll focus on it exclusively. If we can make VarNamedTuple work for an AbstractVarInfo, I bet we can make it work for other purposes (condition, fix, Gibbs) as well. + +Without going into full detail, here's @mhauru's current proposal for what it would look like. This proposal remains in constant flux as I develop the code. + +A VarNamedTuple is a mapping of VarNames to values. Values can be anything. In the case of using VarNamedTuple to implement an AbstractVarInfo, the values would be random samples for random variables. However, they could hold with them extra information. For instance, we might use a value that is a tuple of a vectorised value, a bijector, and a flag for whether the variable is linked. + +I sometimes shorten VarNamedTuple to VNT. + +Internally, a VarNamedTuple consists of nested NamedTuples. For instance, the mapping `@varname(x) => 1, @varname(y.z) => 2` would be stored as + +``` +(; x=1, y=(; z=2)) +``` + +(This is a slight simplification, really it would be nested VarNamedTuples rather than NamedTuples, but I omit this detail.) +This forms a tree, with each node being a NamedTuple, like so: + +``` + NT +x / \ y + 1 NT + \ z + 2 +``` + +Each `NT` marks a NamedTuple, and the labels on the edges its keys. Here the root node has the keys `x` and `y`. 
This is like with the type stable VarInfo in our current design, except with possibly more levels (our current one only has the root node). Each nested `PropertyLens`, i.e. each `.` in a VarName like `@varname(a.b.c.e)`, creates a new layer of the tree. + +For simplicity, at least for now, we ban any VarNames where an `IndexLens` precedes a `PropertyLens`. That is, we ban any VarNames like `@varname(a.b[1].c)`. Recall that VarNames allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and `identity` (the trivial lens). Thus the only allowed VarName types are `@varname(a.b.c.d)` and `@varname(a.b.c.d[i,j,k])`. + +This means that we can add levels to the NamedTuple tree until all `PropertyLenses` have been covered. The leaves of the tree are then of two kinds: They are either the raw value itself if the last lens of the VarName is an `identity`, or otherwise they are something that can be indexed with an `IndexLens`, such as an `Array`. + +To get a value from a VarNamedTuple is very simple: For `getindex(vnt::VNT, vn::VarName{S})` (`S` being the lead Symbol) you recurse into `getindex(vnt[S], unprefix(vn, S))`. If the last lens of `vn` is an `IndexLens`, we assume that the leaf of the NamedTuple tree we've reached contains something that can be indexed with it. + +Setting values in a VNT is equally simple if there are no `IndexLenses`: For `setindex!!(vnt::VNT, value::Any, vn::VarName)` one simply finds the leaf of the `vnt` tree corresponding to `vn` and sets its value to `value`. + +The tricky part is what to do when setting values with `IndexLenses`. There are three possible situations. Say one calls `setindex!!(vnt, 3.0, @varname(a.b[3]))`. + + 1. If `getindex(vnt, @varname(a.b))` is already a vector of length at least 3, this is easy: Just set the third element. + 2. If `getindex(vnt, @varname(a.b))` is a vector of length less than 3, what should we do? Do we error? Do we extend that vector? + 3. 
If `getindex(vnt, @varname(a.b))` isn't even set, what do we do? Say for instance that `vnt` is currently empty. We should set `vnt` to be something like `(; a=(; b=x))`, where `x` is such that `x[3] = 3.0`, but what exactly should `x` be? Is it a dictionary? A vector of length 3? If the latter, what are `x[2]` and `x[1]`? Or should this `setindex!!` call simply error? + +A note at this point: VarNamedTuples must always use `setindex!!`, the `!!` version that may or may not operate in place. The NamedTuples can't be modified in place, but the values at the leaves may be. Always using a `!!` function makes type stability easier, and makes structures like the type unstable old VarInfo with Metadata unnecessary: Any value can be set into any VarNamedTuple. The type parameters of the VNT will simply expand as necessary. + +To solve the problem of points 2. and 3. above I propose expanding the definition of VNT a bit. This will also help make VNT more flexible, which may help performance or allow more use cases. The modification is this: + +Unlike I said above, let's say that VNT isn't just nested NamedTuples with some values at the leaves. Let's say it also has a field called `make_leaf`. `make_leaf(value, lens)` is a function that takes any value, and a lens that is either `identity` or an `IndexLens`, and returns the value wrapped in some suitable struct that can be stored in the leaf of the NamedTuple tree. The values should always be such that `make_leaf(value, lens)[lens] == value`. + +Our earlier example of `VarNamedTuple(@varname(x) => 1, @varname(y.z) => 2; make_leaf=f)` would be stored as a tree like + +``` + ==NT== + x / \ y +f(1, identity) NT + \ z + f(2, identity) +``` + +The above, first draft of VNT which did not include `make_leaf` is equivalent to the trivial choice `make_leaf(value, lens) = lens === identity ? value : error("Don't know how to deal IndexLenses")`. The problems 2. and 3. 
above are "solved" by making it `make_leaf`'s problem to figure out what to do. For instance, `make_leaf` can always return a `Dict` that maps lenses to values. This is probably slow, but works for any lens. Or it can initialise a vector type that can grow as needed when indexed into. + +The idea would be to use `make_leaf` to try out different ways of implementing a VarInfo, find a good default, and, if necessary, leave the option for power users to customise behaviour. The first ones to implement would be + + - `make_leaf` that returns a Metadata object. This would be a direct replacement for type stable VarInfo that uses Metadata, except now with more nested levels of NamedTuple. + - `make_leaf` that returns an `OrderedDict`. This would be a direct replacement for SimpleVarInfo with OrderedDict. + +You may ask, have we simply gone from too many VarInfo types to too many `make_leaf` functions? Yes we have. But hopefully we have gained something in the process: + + - The leaf types can be simpler. They do not need to deal with VarNames any more, they only need to deal with `identity` lenses and `IndexLenses`. + - All AbstractVarInfos are as type stable as their leaf types allow. There is no more notion of an untyped VarInfo being converted to a typed one. + - Type stability is maintained even with nested `PropertyLenses` like `@varname(a.b)`, which happens a lot with submodels. + - Many functions that are currently implemented individually for each AbstractVarInfo type would now have a single implementation for the VarNamedTuple-based AbstractVarInfo type, reducing code duplication. I would also hope to get rid of most of the generated functions in `varinfo.jl`. + +My guess is that the eventual One AbstractVarInfo To Rule Them All would have a `make_leaf` function that stores the raw values when the lens is an `identity`, and uses a flexible Vector, a lot like VarNamedVector, when the lens is an IndexLens.
However, I could be wrong on that being the best option. Implementing and benchmarking is the only way to know. + +I think the two big questions are: + + - Will we run into some big, unanticipated blockers when we start to implement this. + - Will the nesting of NamedTuples cause performance regressions, if the compiler either chokes or gives up. + +I'll try to derisk these early on in this PR. From 297c07c585f40a3f4c03c050714b6733c34bf96f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 21 Oct 2025 16:34:53 +0100 Subject: [PATCH 03/18] Better ASCII graphics --- docs/src/internals/varnamedtuple.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index 202028e55..1cba2adfb 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -75,7 +75,7 @@ Unlike I said above, let's say that VNT isn't just nested NamedTuples with some Our earlier example of `VarNamedTuple(@varname(x) => 1, @varname(y.z) => 2; make_leaf=f)` would be stored as a tree like ``` - ==NT== + --NT-- x / \ y f(1, identity) NT \ z From 4cfd2a6f40fd37baafcb04b6f67db8d95ef65f4e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 22 Oct 2025 10:26:34 +0100 Subject: [PATCH 04/18] Fix typos VarNamedTuple -> VarNamedVector Co-authored-by: Penelope Yong --- docs/src/internals/varnamedtuple.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index 1cba2adfb..c52bed974 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -7,9 +7,9 @@ This document collects thoughts and ideas for how to unify our multitude of Abst We currently have the following AbstractVarInfo types: - A: VarInfo with Metadata - - B: VarInfo with VarNamedTuple + - B: VarInfo with VarNamedVector - C: VarInfo with NamedTuple, with values being Metadata - - D: VarInfo with 
NamedTuple, with values being VarNamedTuple + - D: VarInfo with NamedTuple, with values being VarNamedVector - E: SimpleVarInfo with NamedTuples - F: SimpleVarInfo with OrderedDict From 2aa9b5ebdcc8c47208b0732f67871db9084d126a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 21 Oct 2025 17:25:02 +0100 Subject: [PATCH 05/18] Fix type instability --- src/varinfo.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 0fb77aa7a..853b737eb 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -358,7 +358,7 @@ function typed_vector_varinfo( end function make_leaf_metadata((r, dist), optic) - md = Metadata() + md = Metadata(Float64) vn = VarName{:_}(optic) push!(md, vn, r, dist) return md @@ -439,8 +439,8 @@ unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) Construct an empty type unstable instance of `Metadata`. """ -function Metadata() - vals = Vector{Real}() +function Metadata(eltype=Real) + vals = Vector{eltype}() is_transformed = BitVector() return Metadata( From cf2d3234371db590d44ef19cad5820bf7f9e4b3d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 22 Oct 2025 11:04:55 +0100 Subject: [PATCH 06/18] Various fixes --- src/contexts/transformation.jl | 11 +++-- src/varinfo.jl | 87 ++++++++++++++++++++++++++++++++-- src/varname.jl | 8 ++-- 3 files changed, 92 insertions(+), 14 deletions(-) diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl index 5153f7857..c9b024674 100644 --- a/src/contexts/transformation.jl +++ b/src/contexts/transformation.jl @@ -21,11 +21,12 @@ function tilde_assume!!( # vi[vn, right] always provides the value in unlinked space. 
x = vi[vn, right] - if is_transformed(vi, vn) - isinverse || @warn "Trying to link an already transformed variable ($vn)" - else - isinverse && @warn "Trying to invlink a non-transformed variable ($vn)" - end + # TODO(mhauru) Warnings disabled for benchmarking purposes + # if is_transformed(vi, vn) + # isinverse || @warn "Trying to link an already transformed variable ($vn)" + # else + # isinverse && @warn "Trying to invlink a non-transformed variable ($vn)" + # end transform = isinverse ? identity : link_transform(right) y, logjac = with_logabsdet_jacobian(transform, x) diff --git a/src/varinfo.jl b/src/varinfo.jl index 853b737eb..f1f6513bf 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -358,7 +358,7 @@ function typed_vector_varinfo( end function make_leaf_metadata((r, dist), optic) - md = Metadata(Float64) + md = Metadata(Float64, VarName{:_}) vn = VarName{:_}(optic) push!(md, vn, r, dist) return md @@ -439,13 +439,13 @@ unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) Construct an empty type unstable instance of `Metadata`. """ -function Metadata(eltype=Real) +function Metadata(eltype=Real, vntype=VarName) vals = Vector{eltype}() is_transformed = BitVector() return Metadata( - Dict{VarName,Int}(), - Vector{VarName}(), + Dict{vntype,Int}(), + Vector{vntype}(), Vector{UnitRange{Int}}(), vals, Vector{Distribution}(), @@ -814,7 +814,7 @@ The values may or may not be transformed to Euclidean space. 
setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn) function setval!(vi::TupleVarInfo, val, vn::VarName) main_vn, optic = split_trailing_index(vn) - return setval!(getindex(vi.metadata, main_vn), VarName{:_}(optic)) + return setval!(getindex(vi.metadata, main_vn), val, VarName{:_}(optic)) end function setval!(md::Metadata, val::AbstractVector, vn::VarName) return md.vals[getrange(md, vn)] = val @@ -1914,3 +1914,80 @@ end function from_linked_internal_transform(::VarNamedVector, ::VarName, dist) return from_linked_vec_transform(dist) end + +function link(vi::TupleVarInfo, model::Model) + metadata = map(value -> link(value, model), vi.metadata) + return VarInfo(metadata, vi.accs) +end + +function link(metadata::Metadata, model::Model) + vns = metadata.vns + cumulative_logjac = zero(LogProbType) + + # Construct the new transformed values, and keep track of their lengths. + vals_new = map(vns) do vn + # Return early if we're already in unconstrained space. + # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. + if is_transformed(metadata, vn) + return metadata.vals[getrange(metadata, vn)] + end + + # Transform to constrained space. + x = getindex_internal(metadata, vn) + dist = getdist(metadata, vn) + f_from_internal = from_internal_transform(metadata, vn, dist) + f_to_linked_internal = inverse(from_linked_internal_transform(metadata, vn, dist)) + f = f_to_linked_internal ∘ f_from_internal + y, logjac = with_logabsdet_jacobian(f, x) + # Vectorize value. + yvec = tovec(y) + # Accumulate the log-abs-det jacobian correction. + cumulative_logjac += logjac + # Return the vectorized transformed value. + return yvec + end + + # Determine new ranges. + ranges_new = similar(metadata.ranges) + offset = 0 + for (i, v) in enumerate(vals_new) + r_start, r_end = offset + 1, length(v) + offset + offset = r_end + ranges_new[i] = r_start:r_end + end + + # Now we just create a new metadata with the new `vals` and `ranges`. 
+ return Metadata( + metadata.idcs, + metadata.vns, + ranges_new, + reduce(vcat, vals_new), + metadata.dists, + BitVector(fill(true, length(metadata.vns))), + ) +end + +function Base.haskey(vi::TupleVarInfo, vn::VarName) + # TODO(mhauru) Fix this to account for the index. + main_vn, optic = split_trailing_index(vn) + haskey(vi.metadata, main_vn) || return false + value = getindex(vi.metadata, main_vn) + if value isa Metadata + return haskey(value, VarName{:_}(optic)) + else + error("TODO(mhauru) Implement me") + end +end + +function BangBang.setindex!!(metadata::Metadata, val, optic) + return setindex!!(metadata, val, VarName{:_}(optic)) +end + +function BangBang.setindex!!(metadata::Metadata, (r, dist), vn::VarName) + if haskey(metadata, vn) + setval!(metadata, r, vn) + else + push!(metadata, vn, r, dist) + end + return metadata +end diff --git a/src/varname.jl b/src/varname.jl index 687427f6e..bd02e0195 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -45,7 +45,7 @@ end function remove_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic} return if Optic === typeof(identity) vn - elseif Optic isa IndexLens + elseif Optic <: Accessors.IndexLens VarName{sym}() else prefix(remove_trailing_index(unprefix(vn, sym)), sym) @@ -55,10 +55,10 @@ end function split_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic} return if Optic === typeof(identity) (vn, identity) - elseif Optic isa IndexLens - (VarName{sym}(), Optic.index) + elseif Optic <: Accessors.IndexLens + (VarName{sym}(), getoptic(vn)) else - (prefix, index) = split_trailing_index(unprefix(vn, sym)) + (prefix, index) = split_trailing_index(unprefix(vn, VarName{sym}())) (prefix(prefix, sym), index) end end From 4c16894fdca4572e397d9a27c642e3c6127e2088 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 22 Oct 2025 16:39:08 +0100 Subject: [PATCH 07/18] Bug fixes --- src/varinfo.jl | 7 ++++++- src/varname.jl | 8 +++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/varinfo.jl 
b/src/varinfo.jl index f1f6513bf..85e7a5d3f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1916,10 +1916,15 @@ function from_linked_internal_transform(::VarNamedVector, ::VarName, dist) end function link(vi::TupleVarInfo, model::Model) - metadata = map(value -> link(value, model), vi.metadata) + metadata = link(vi.metadata, model) return VarInfo(metadata, vi.accs) end +function link(vnt::VarNamedTuple, model::Model) + new_vnt = map(value -> link(value, model), vnt) + return new_vnt +end + function link(metadata::Metadata, model::Model) vns = metadata.vns cumulative_logjac = zero(LogProbType) diff --git a/src/varname.jl b/src/varname.jl index bd02e0195..85434dd48 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -48,7 +48,9 @@ function remove_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic} elseif Optic <: Accessors.IndexLens VarName{sym}() else - prefix(remove_trailing_index(unprefix(vn, sym)), sym) + AbstractPPL.prefix( + remove_trailing_index(AbstractPPL.unprefix(vn, VarName{sym}())), VarName{sym}() + ) end end @@ -58,7 +60,7 @@ function split_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic} elseif Optic <: Accessors.IndexLens (VarName{sym}(), getoptic(vn)) else - (prefix, index) = split_trailing_index(unprefix(vn, VarName{sym}())) - (prefix(prefix, sym), index) + (pref, index) = split_trailing_index(AbstractPPL.unprefix(vn, VarName{sym}())) + (AbstractPPL.prefix(pref, VarName{sym}()), index) end end From 0e51e41c5e04d0a0057a952135a243fdabbefde5 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 27 Oct 2025 14:03:12 +0000 Subject: [PATCH 08/18] Add some questions/issues to the doc page --- docs/src/internals/varnamedtuple.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index c52bed974..cb0a9c4c3 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -104,3 +104,9 @@ I think the two big questions are: - Will the 
nesting of NamedTuples cause performance regressions, if the compiler either chokes or gives up. I'll try to derisk these early on in this PR. + +## Questions / issues + +* People might really need IndexLenses in the middle of VarNames. The one place this comes up is submodels within a loop. I'm still inclined to keep designing without allowing for that, for now, but should keep in mind that that needs to be relaxed eventually. If it makes it easier, we can require that users explicitly tell us the size of any arrays for which this is done. +* When storing values for nested NamedTuples, the actual variable may be a struct. Do we need to be able to reconstruct the struct from the NamedTuple? If so, how do we do that? +* Do `Colon` indices cause any extra trouble for the leafnodes? From 6beca1a74f53a8bc8f5a2e0fb57770761879ff8e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 12 Nov 2025 20:33:46 +0000 Subject: [PATCH 09/18] Actually add the src/varnamedtuple.jl file --- docs/src/internals/varnamedtuple.md | 6 +- src/varnamedtuple.jl | 270 ++++++++++++++++++++++++++++ 2 files changed, 273 insertions(+), 3 deletions(-) create mode 100644 src/varnamedtuple.jl diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index cb0a9c4c3..9f7a84cdb 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -107,6 +107,6 @@ I'll try to derisk these early on in this PR. ## Questions / issues -* People might really need IndexLenses in the middle of VarNames. The one place this comes up is submodels within a loop. I'm still inclined to keep designing without allowing for that, for now, but should keep in mind that that needs to be relaxed eventually. If it makes it easier, we can require that users explicitly tell us the size of any arrays for which this is done. -* When storing values for nested NamedTuples, the actual variable may be a struct. 
Do we need to be able to reconstruct the struct from the NamedTuple? If so, how do we do that? -* Do `Colon` indices cause any extra trouble for the leafnodes? + - People might really need IndexLenses in the middle of VarNames. The one place this comes up is submodels within a loop. I'm still inclined to keep designing without allowing for that, for now, but should keep in mind that that needs to be relaxed eventually. If it makes it easier, we can require that users explicitly tell us the size of any arrays for which this is done. + - When storing values for nested NamedTuples, the actual variable may be a struct. Do we need to be able to reconstruct the struct from the NamedTuple? If so, how do we do that? + - Do `Colon` indices cause any extra trouble for the leafnodes? diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl new file mode 100644 index 000000000..1a38b4782 --- /dev/null +++ b/src/varnamedtuple.jl @@ -0,0 +1,270 @@ +module VarNamedTuples + +using AbstractPPL +using BangBang +using Accessors +using DynamicPPL: _compose_no_identity + +export VarNamedTuple + +# @varname(a.b[3].c[:].d) +# +# VarNamedTuple( +# (; a=(; b=[ +# (; c=[ +# (; d=...), +# (; d=...), +# (; d=...), +# ]), +# (; c=[ +# (; d=...), +# (; d=...), +# (; d=...), +# ]), +# (; c=[ +# (; d=...), +# (; d=...), +# (; d=...), +# ]), +# )) +#) + +struct VarNamedTuple{T<:Function,Names,Values} + data::NamedTuple{Names,Values} + make_leaf::T +end + +struct IndexDict{T<:Function,Keys,Values} + data::Dict{Keys,Values} + make_leaf::T +end + +function make_leaf_raw(value, ::PropertyLens{S}) where {S} + return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_raw) +end +make_leaf_raw(value, ::typeof(identity)) = value +function make_leaf_raw(value, optic::IndexLens) + return IndexDict(Dict(optic.indices => value), make_leaf_raw) +end +function make_leaf_raw(value, optic::ComposedFunction) + sub = make_leaf_raw(value, optic.outer) + return make_leaf_raw(sub, optic.inner) +end + +VarNamedTuple() 
= VarNamedTuple((;), make_leaf_raw) + +function Base.show(io::IO, vnt::VarNamedTuple) + print(io, "(") + for (i, (name, value)) in enumerate(pairs(vnt.data)) + if i > 1 + print(io, ", ") + end + print(io, name, " -> ") + print(io, value) + end + return print(io, ")") +end + +function Base.show(io::IO, vnt::IndexDict) + return print(io, vnt.data) +end + +Base.getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] + +function varname_to_lens(name::VarName{S}) where {S} + return _compose_no_identity(getoptic(name), PropertyLens{S}()) +end +function Base.getindex(vnt::VarNamedTuple, name::VarName) + return getindex(vnt, varname_to_lens(name)) +end +function Base.getindex(vnt::VarNamedTuple, lens::ComposedFunction) + subdata = getindex(vnt, lens.inner) + return getindex(subdata, lens.outer) +end +function Base.getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} + return getindex(vnt.data, S) +end +function Base.getindex(vnt::IndexDict, lens::IndexLens) + return getindex(vnt.data, lens.indices) +end + +function Base.haskey(vnt::VarNamedTuple, name::VarName) + return haskey(vnt, varname_to_lens(name)) +end + +Base.haskey(vnt::VarNamedTuple, ::typeof(identity)) = true + +function Base.haskey(vnt::VarNamedTuple, lens::ComposedFunction) + return haskey(vnt, lens.inner) && haskey(getindex(vnt, lens.inner), lens.outer) +end + +Base.haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S) +Base.haskey(vnt::IndexDict, lens::IndexLens) = haskey(vnt.data, lens.indices) +Base.haskey(vnt::VarNamedTuple, lens::IndexLens) = false +Base.haskey(vnt::IndexDict, lens::PropertyLens) = false + +# TODO(mhauru) This is type piracy. +Base.getindex(arr::AbstractArray, lens::IndexLens) = getindex(arr, lens.indices...) + +# TODO(mhauru) This is type piracy. +function BangBang.setindex!!(arr::AbstractArray, value, lens::IndexLens) + return BangBang.setindex!!(arr, value, lens.indices...) 
+end + +function BangBang.setindex!!(vnt::VarNamedTuple, value, name::VarName) + return BangBang.setindex!!(vnt, value, varname_to_lens(name)) +end + +function BangBang.setindex!!( + vnt::VarNamedTuple, value, lens::ComposedFunction{Outer,Inner} +) where {Outer,Inner} + sub = if haskey(vnt, lens.inner) + BangBang.setindex!!(lens.inner(vnt.data), value, lens.outer) + else + vnt.make_leaf(value, lens.outer) + end + return BangBang.setindex!!(vnt, sub, lens.inner) +end + +function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} + return VarNamedTuple(BangBang.setindex!!(vnt.data, value, S), vnt.make_leaf) +end + +function BangBang.setindex!!(vnt::IndexDict, value, lens::IndexLens) + return setindex!(vnt.data, value, lens.indices) +end + +# function BangBang.setindex!!( +# vnt::VarNamedTuple, value, name::{S,Optic} +# ) where {S,Optic} +# new_data = if haskey(vnt.data, S) +# if Optic === typeof(identity) +# BangBang.setindex!!(vnt.data, vnt.make_leaf(value, getoptic(name)), S) +# elseif Optic <: IndexLens +# new_subdata = BangBang.setindex!!(vnt.data[S], value, getoptic(name)) +# BangBang.setindex!!(vnt.data, new_subdata, S) +# else +# new_subdata = BangBang.setindex!!( +# vnt.data[S], value, AbstractPPL.unprefix(name, VarName{S}()) +# ) +# BangBang.setindex!!(vnt.data, new_subdata, S) +# end +# else +# new_subdata = if Optic === typeof(identity) || Optic <: IndexLens +# vnt.make_leaf(value, getoptic(name)) +# # if Optic === typeof(identity) +# # BangBang.setindex!!(vnt.data, value, S) +# # elseif Optic <: IndexLens +# # new_subdata = BangBang.setindex!!(Dict{Union{},Union{}}(), value, getoptic(name).indices...) 
+# # BangBang.setindex!!(vnt.data, new_subdata, S) +# else +# BangBang.setindex!!( +# VarNamedTuple((;), vnt.make_leaf), +# value, +# AbstractPPL.unprefix(name, VarName{S}()), +# ) +# end +# BangBang.setindex!!(vnt.data, new_subdata, S) +# end +# return VarNamedTuple(new_data, vnt.make_leaf) +# end + +function apply(func, vnt::VarNamedTuple, name::VarName) + if !haskey(vnt.data, name.name) + throw(KeyError(repr(name))) + end + subdata = getindex(vnt, name) + new_subdata = func(subdata) + return BangBang.setindex!!(vnt, new_subdata, name) +end + +function Base.map(func, vnt::VarNamedTuple) + new_data = NamedTuple{keys(vnt.data)}(map(func, values(vnt.data))) + return VarNamedTuple(new_data, vnt.make_leaf) +end + +function Base.keys(vnt::VarNamedTuple) + result = () + for sym in keys(vnt.data) + subdata = vnt.data[sym] + if subdata isa VarNamedTuple + subkeys = keys(subdata) + result = ( + (AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys)..., result... + ) + else + result = (VarName{sym}(), result...) 
+ end + subkeys = keys(vnt.data[sym]) + end + return result +end + +function Base.haskey(vnt::VarNamedTuple, name::VarName{S,Optic}) where {S,Optic} + if !haskey(vnt.data, S) + return false + end + subdata = vnt.data[S] + return if Optic === typeof(identity) + true + elseif Optic <: IndexLens + try + AbstractPPL.getoptic(name)(subdata) + true + catch e + if e isa BoundsError || e isa KeyError + false + else + rethrow(e) + end + end + else + haskey(subdata, AbstractPPL.unprefix(name, VarName{S}())) + end +end + +end + +# module AdHocTests +# +# using DynamicPPL: VarNamedTuples, @varname +# using BangBang +# +# vnt = VarNamedTuples.VarNamedTuple() +# display(vnt) +# +# vnt = setindex!!(vnt, 32.0, @varname(a)) +# println("a = $(vnt[@varname(a)])") +# display(vnt) +# +# vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) +# println("b[2] = $(vnt[@varname(b[2])])") +# display(vnt) +# +# vnt = setindex!!(vnt, 64.0, @varname(a)) +# display(vnt) +# +# vnt = setindex!!(vnt, 15, @varname(b[2])) +# display(vnt) +# +# vnt = setindex!!(vnt, [10], @varname(c.x.y)) +# println("c.x = $(vnt[@varname(c.x)])") +# display(vnt) +# +# vnt = setindex!!(vnt, 11, @varname(c.x.y[1])) +# display(vnt) +# +# vnt = setindex!!(vnt, -1.0, @varname(d[4])) +# display(vnt) +# +# vnt = setindex!!(vnt, -2.0, @varname(d[4])) +# display(vnt) +# +# vnt = setindex!!(vnt, -3.0, @varname(d[5])) +# display(vnt) +# +# println("d = $(vnt[@varname(d)])") +# +# vnt = setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i)) +# display(vnt) +# end From f8665f96600e11209c335117fd9baf7546f26b89 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 12 Nov 2025 20:51:58 +0000 Subject: [PATCH 10/18] Clean up, varnamedtuple tests --- src/varnamedtuple.jl | 126 +++++------------------------------------- test/runtests.jl | 1 + test/varnamedtuple.jl | 42 ++++++++++++++ 3 files changed, 57 insertions(+), 112 deletions(-) create mode 100644 test/varnamedtuple.jl diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 
1a38b4782..17244b3bc 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -1,3 +1,4 @@ +# TODO(mhauru) This module should probably be moved to AbstractPPL. module VarNamedTuples using AbstractPPL @@ -7,28 +8,6 @@ using DynamicPPL: _compose_no_identity export VarNamedTuple -# @varname(a.b[3].c[:].d) -# -# VarNamedTuple( -# (; a=(; b=[ -# (; c=[ -# (; d=...), -# (; d=...), -# (; d=...), -# ]), -# (; c=[ -# (; d=...), -# (; d=...), -# (; d=...), -# ]), -# (; c=[ -# (; d=...), -# (; d=...), -# (; d=...), -# ]), -# )) -#) - struct VarNamedTuple{T<:Function,Names,Values} data::NamedTuple{Names,Values} make_leaf::T @@ -65,8 +44,8 @@ function Base.show(io::IO, vnt::VarNamedTuple) return print(io, ")") end -function Base.show(io::IO, vnt::IndexDict) - return print(io, vnt.data) +function Base.show(io::IO, id::IndexDict) + return print(io, id.data) end Base.getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] @@ -74,18 +53,19 @@ Base.getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] function varname_to_lens(name::VarName{S}) where {S} return _compose_no_identity(getoptic(name), PropertyLens{S}()) end + function Base.getindex(vnt::VarNamedTuple, name::VarName) return getindex(vnt, varname_to_lens(name)) end -function Base.getindex(vnt::VarNamedTuple, lens::ComposedFunction) - subdata = getindex(vnt, lens.inner) +function Base.getindex(x::Union{VarNamedTuple,IndexDict}, lens::ComposedFunction) + subdata = getindex(x, lens.inner) return getindex(subdata, lens.outer) end function Base.getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} return getindex(vnt.data, S) end -function Base.getindex(vnt::IndexDict, lens::IndexLens) - return getindex(vnt.data, lens.indices) +function Base.getindex(id::IndexDict, lens::IndexLens) + return getindex(id.data, lens.indices) end function Base.haskey(vnt::VarNamedTuple, name::VarName) @@ -99,9 +79,9 @@ function Base.haskey(vnt::VarNamedTuple, lens::ComposedFunction) end Base.haskey(vnt::VarNamedTuple, 
::PropertyLens{S}) where {S} = haskey(vnt.data, S) -Base.haskey(vnt::IndexDict, lens::IndexLens) = haskey(vnt.data, lens.indices) -Base.haskey(vnt::VarNamedTuple, lens::IndexLens) = false -Base.haskey(vnt::IndexDict, lens::PropertyLens) = false +Base.haskey(id::IndexDict, lens::IndexLens) = haskey(id.data, lens.indices) +Base.haskey(::VarNamedTuple, ::IndexLens) = false +Base.haskey(::IndexDict, ::PropertyLens) = false # TODO(mhauru) This is type piracy. Base.getindex(arr::AbstractArray, lens::IndexLens) = getindex(arr, lens.indices...) @@ -130,45 +110,11 @@ function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where return VarNamedTuple(BangBang.setindex!!(vnt.data, value, S), vnt.make_leaf) end -function BangBang.setindex!!(vnt::IndexDict, value, lens::IndexLens) - return setindex!(vnt.data, value, lens.indices) +function BangBang.setindex!!(id::IndexDict, value, lens::IndexLens) + setindex!(id.data, value, lens.indices) + return id end -# function BangBang.setindex!!( -# vnt::VarNamedTuple, value, name::{S,Optic} -# ) where {S,Optic} -# new_data = if haskey(vnt.data, S) -# if Optic === typeof(identity) -# BangBang.setindex!!(vnt.data, vnt.make_leaf(value, getoptic(name)), S) -# elseif Optic <: IndexLens -# new_subdata = BangBang.setindex!!(vnt.data[S], value, getoptic(name)) -# BangBang.setindex!!(vnt.data, new_subdata, S) -# else -# new_subdata = BangBang.setindex!!( -# vnt.data[S], value, AbstractPPL.unprefix(name, VarName{S}()) -# ) -# BangBang.setindex!!(vnt.data, new_subdata, S) -# end -# else -# new_subdata = if Optic === typeof(identity) || Optic <: IndexLens -# vnt.make_leaf(value, getoptic(name)) -# # if Optic === typeof(identity) -# # BangBang.setindex!!(vnt.data, value, S) -# # elseif Optic <: IndexLens -# # new_subdata = BangBang.setindex!!(Dict{Union{},Union{}}(), value, getoptic(name).indices...) 
-# # BangBang.setindex!!(vnt.data, new_subdata, S) -# else -# BangBang.setindex!!( -# VarNamedTuple((;), vnt.make_leaf), -# value, -# AbstractPPL.unprefix(name, VarName{S}()), -# ) -# end -# BangBang.setindex!!(vnt.data, new_subdata, S) -# end -# return VarNamedTuple(new_data, vnt.make_leaf) -# end - function apply(func, vnt::VarNamedTuple, name::VarName) if !haskey(vnt.data, name.name) throw(KeyError(repr(name))) @@ -224,47 +170,3 @@ function Base.haskey(vnt::VarNamedTuple, name::VarName{S,Optic}) where {S,Optic} end end - -# module AdHocTests -# -# using DynamicPPL: VarNamedTuples, @varname -# using BangBang -# -# vnt = VarNamedTuples.VarNamedTuple() -# display(vnt) -# -# vnt = setindex!!(vnt, 32.0, @varname(a)) -# println("a = $(vnt[@varname(a)])") -# display(vnt) -# -# vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) -# println("b[2] = $(vnt[@varname(b[2])])") -# display(vnt) -# -# vnt = setindex!!(vnt, 64.0, @varname(a)) -# display(vnt) -# -# vnt = setindex!!(vnt, 15, @varname(b[2])) -# display(vnt) -# -# vnt = setindex!!(vnt, [10], @varname(c.x.y)) -# println("c.x = $(vnt[@varname(c.x)])") -# display(vnt) -# -# vnt = setindex!!(vnt, 11, @varname(c.x.y[1])) -# display(vnt) -# -# vnt = setindex!!(vnt, -1.0, @varname(d[4])) -# display(vnt) -# -# vnt = setindex!!(vnt, -2.0, @varname(d[4])) -# display(vnt) -# -# vnt = setindex!!(vnt, -3.0, @varname(d[5])) -# display(vnt) -# -# println("d = $(vnt[@varname(d)])") -# -# vnt = setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i)) -# display(vnt) -# end diff --git a/test/runtests.jl b/test/runtests.jl index b6a3f7bf6..b56c00b68 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -58,6 +58,7 @@ include("test_util.jl") include("accumulators.jl") include("compiler.jl") include("varnamedvector.jl") + include("varnamedtuple.jl") include("varinfo.jl") include("simple_varinfo.jl") include("model.jl") diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl new file mode 100644 index 000000000..a80861ccb --- /dev/null +++ 
b/test/varnamedtuple.jl @@ -0,0 +1,42 @@ +module VarNamedTupleTests + +using Test: @testset, @test +using DynamicPPL: @varname, VarNamedTuple +using BangBang: setindex!! + +@testset "Basic sets and gets" begin + vnt = VarNamedTuple() + vnt = setindex!!(vnt, 32.0, @varname(a)) + @test getindex(vnt, @varname(a)) == 32.0 + + vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) + @test getindex(vnt, @varname(b)) == [1, 2, 3] + + vnt = setindex!!(vnt, 64.0, @varname(a)) + @test getindex(vnt, @varname(a)) == 64.0 + + vnt = setindex!!(vnt, 15, @varname(b[2])) + @test getindex(vnt, @varname(b)) == [1, 15, 3] + @test getindex(vnt, @varname(b[2])) == 15 + + vnt = setindex!!(vnt, [10], @varname(c.x.y)) + @test getindex(vnt, @varname(c.x.y)) == [10] + + vnt = setindex!!(vnt, 11, @varname(c.x.y[1])) + @test getindex(vnt, @varname(c.x.y)) == [11] + @test getindex(vnt, @varname(c.x.y[1])) == 11 + + vnt = setindex!!(vnt, -1.0, @varname(d[4])) + @test getindex(vnt, @varname(d[4])) == -1.0 + + vnt = setindex!!(vnt, -2.0, @varname(d[4])) + @test getindex(vnt, @varname(d[4])) == -2.0 + + vnt = setindex!!(vnt, -3.0, @varname(d[5])) + @test getindex(vnt, @varname(d[5])) == -3.0 + + vnt = setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i)) + @test getindex(vnt, @varname(e.f[3].g.h[2].i)) == 1.0 +end + +end From e666bd2b8fb2fabfd1471c6a7ca1029c6e647f29 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 12 Nov 2025 21:14:17 +0000 Subject: [PATCH 11/18] Test and fix type stability --- src/varnamedtuple.jl | 10 ++++++---- test/varnamedtuple.jl | 46 +++++++++++++++++++++---------------------- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 17244b3bc..19f610080 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -95,9 +95,7 @@ function BangBang.setindex!!(vnt::VarNamedTuple, value, name::VarName) return BangBang.setindex!!(vnt, value, varname_to_lens(name)) end -function BangBang.setindex!!( - vnt::VarNamedTuple, value, 
lens::ComposedFunction{Outer,Inner} -) where {Outer,Inner} +function BangBang.setindex!!(vnt::VarNamedTuple, value, lens::ComposedFunction) sub = if haskey(vnt, lens.inner) BangBang.setindex!!(lens.inner(vnt.data), value, lens.outer) else @@ -107,7 +105,11 @@ function BangBang.setindex!!( end function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} - return VarNamedTuple(BangBang.setindex!!(vnt.data, value, S), vnt.make_leaf) + # I would like this to just read + # return VarNamedTuple(BangBang.setindex!!(vnt.data, value, S), vnt.make_leaf) + # but that seems to be type unstable. Why? Shouldn't it obviously be the same as the + # below? + return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,))), vnt.make_leaf) end function BangBang.setindex!!(id::IndexDict, value, lens::IndexLens) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index a80861ccb..124105027 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -1,42 +1,42 @@ module VarNamedTupleTests -using Test: @testset, @test +using Test: @inferred, @testset, @test using DynamicPPL: @varname, VarNamedTuple using BangBang: setindex!! 
@testset "Basic sets and gets" begin vnt = VarNamedTuple() - vnt = setindex!!(vnt, 32.0, @varname(a)) - @test getindex(vnt, @varname(a)) == 32.0 + vnt = @inferred setindex!!(vnt, 32.0, @varname(a)) + @test @inferred getindex(vnt, @varname(a)) == 32.0 - vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) - @test getindex(vnt, @varname(b)) == [1, 2, 3] + vnt = @inferred setindex!!(vnt, [1, 2, 3], @varname(b)) + @test @inferred getindex(vnt, @varname(b)) == [1, 2, 3] - vnt = setindex!!(vnt, 64.0, @varname(a)) - @test getindex(vnt, @varname(a)) == 64.0 + vnt = @inferred setindex!!(vnt, 64.0, @varname(a)) + @test @inferred getindex(vnt, @varname(a)) == 64.0 - vnt = setindex!!(vnt, 15, @varname(b[2])) - @test getindex(vnt, @varname(b)) == [1, 15, 3] - @test getindex(vnt, @varname(b[2])) == 15 + vnt = @inferred setindex!!(vnt, 15, @varname(b[2])) + @test @inferred getindex(vnt, @varname(b)) == [1, 15, 3] + @test @inferred getindex(vnt, @varname(b[2])) == 15 - vnt = setindex!!(vnt, [10], @varname(c.x.y)) - @test getindex(vnt, @varname(c.x.y)) == [10] + vnt = @inferred setindex!!(vnt, [10], @varname(c.x.y)) + @test @inferred getindex(vnt, @varname(c.x.y)) == [10] - vnt = setindex!!(vnt, 11, @varname(c.x.y[1])) - @test getindex(vnt, @varname(c.x.y)) == [11] - @test getindex(vnt, @varname(c.x.y[1])) == 11 + vnt = @inferred setindex!!(vnt, 11, @varname(c.x.y[1])) + @test @inferred getindex(vnt, @varname(c.x.y)) == [11] + @test @inferred getindex(vnt, @varname(c.x.y[1])) == 11 - vnt = setindex!!(vnt, -1.0, @varname(d[4])) - @test getindex(vnt, @varname(d[4])) == -1.0 + vnt = @inferred setindex!!(vnt, -1.0, @varname(d[4])) + @test @inferred getindex(vnt, @varname(d[4])) == -1.0 - vnt = setindex!!(vnt, -2.0, @varname(d[4])) - @test getindex(vnt, @varname(d[4])) == -2.0 + vnt = @inferred setindex!!(vnt, -2.0, @varname(d[4])) + @test @inferred getindex(vnt, @varname(d[4])) == -2.0 - vnt = setindex!!(vnt, -3.0, @varname(d[5])) - @test getindex(vnt, @varname(d[5])) == -3.0 + vnt = 
@inferred setindex!!(vnt, -3.0, @varname(d[5])) + @test @inferred getindex(vnt, @varname(d[5])) == -3.0 - vnt = setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i)) - @test getindex(vnt, @varname(e.f[3].g.h[2].i)) == 1.0 + vnt = @inferred setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i)) + @test @inferred getindex(vnt, @varname(e.f[3].g.h[2].i)) == 1.0 end end From 747d2e41e72cb696f7c82ab48c45a92d99e0a731 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 12 Nov 2025 21:21:43 +0000 Subject: [PATCH 12/18] Fixes --- src/varnamedtuple.jl | 3 +-- test/varnamedtuple.jl | 46 ++++++++++++++++++++++--------------------- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 19f610080..a04705462 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -113,8 +113,7 @@ function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where end function BangBang.setindex!!(id::IndexDict, value, lens::IndexLens) - setindex!(id.data, value, lens.indices) - return id + return IndexDict(setindex!!(id.data, value, lens.indices), id.make_leaf) end function apply(func, vnt::VarNamedTuple, name::VarName) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 124105027..50c0ea2ac 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -6,37 +6,39 @@ using BangBang: setindex!! 
@testset "Basic sets and gets" begin vnt = VarNamedTuple() - vnt = @inferred setindex!!(vnt, 32.0, @varname(a)) - @test @inferred getindex(vnt, @varname(a)) == 32.0 + vnt = @inferred(setindex!!(vnt, 32.0, @varname(a))) + @test @inferred(getindex(vnt, @varname(a))) == 32.0 - vnt = @inferred setindex!!(vnt, [1, 2, 3], @varname(b)) - @test @inferred getindex(vnt, @varname(b)) == [1, 2, 3] + vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b))) + @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] - vnt = @inferred setindex!!(vnt, 64.0, @varname(a)) - @test @inferred getindex(vnt, @varname(a)) == 64.0 + vnt = @inferred(setindex!!(vnt, 64.0, @varname(a))) + @test @inferred(getindex(vnt, @varname(a))) == 64.0 - vnt = @inferred setindex!!(vnt, 15, @varname(b[2])) - @test @inferred getindex(vnt, @varname(b)) == [1, 15, 3] - @test @inferred getindex(vnt, @varname(b[2])) == 15 + vnt = @inferred(setindex!!(vnt, 15, @varname(b[2]))) + @test @inferred(getindex(vnt, @varname(b))) == [1, 15, 3] + @test @inferred(getindex(vnt, @varname(b[2]))) == 15 - vnt = @inferred setindex!!(vnt, [10], @varname(c.x.y)) - @test @inferred getindex(vnt, @varname(c.x.y)) == [10] + vnt = @inferred(setindex!!(vnt, [10], @varname(c.x.y))) + @test @inferred(getindex(vnt, @varname(c.x.y))) == [10] - vnt = @inferred setindex!!(vnt, 11, @varname(c.x.y[1])) - @test @inferred getindex(vnt, @varname(c.x.y)) == [11] - @test @inferred getindex(vnt, @varname(c.x.y[1])) == 11 + vnt = @inferred(setindex!!(vnt, 11, @varname(c.x.y[1]))) + @test @inferred(getindex(vnt, @varname(c.x.y))) == [11] + @test @inferred(getindex(vnt, @varname(c.x.y[1]))) == 11 - vnt = @inferred setindex!!(vnt, -1.0, @varname(d[4])) - @test @inferred getindex(vnt, @varname(d[4])) == -1.0 + vnt = @inferred(setindex!!(vnt, -1.0, @varname(d[4]))) + @test @inferred(getindex(vnt, @varname(d[4]))) == -1.0 - vnt = @inferred setindex!!(vnt, -2.0, @varname(d[4])) - @test @inferred getindex(vnt, @varname(d[4])) == -2.0 + vnt = 
@inferred(setindex!!(vnt, -2.0, @varname(d[4]))) + @test @inferred(getindex(vnt, @varname(d[4]))) == -2.0 - vnt = @inferred setindex!!(vnt, -3.0, @varname(d[5])) - @test @inferred getindex(vnt, @varname(d[5])) == -3.0 + # These can't be @inferred because `d` now has an abstract element type. Note that this + # does not ruin type instability for other varnames that don't involve `d`. + vnt = setindex!!(vnt, "a", @varname(d[5])) + @test getindex(vnt, @varname(d[5])) == "a" - vnt = @inferred setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i)) - @test @inferred getindex(vnt, @varname(e.f[3].g.h[2].i)) == 1.0 + vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 end end From 79fb5e19f39c0ca28fe7711b8b2c303bdb6f209d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 12 Nov 2025 21:45:16 +0000 Subject: [PATCH 13/18] Add a missing method --- src/varnamedtuple.jl | 6 ++++-- test/varnamedtuple.jl | 5 ++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index a04705462..4dec92154 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -95,9 +95,11 @@ function BangBang.setindex!!(vnt::VarNamedTuple, value, name::VarName) return BangBang.setindex!!(vnt, value, varname_to_lens(name)) end -function BangBang.setindex!!(vnt::VarNamedTuple, value, lens::ComposedFunction) +function BangBang.setindex!!( + vnt::Union{VarNamedTuple,IndexDict}, value, lens::ComposedFunction +) sub = if haskey(vnt, lens.inner) - BangBang.setindex!!(lens.inner(vnt.data), value, lens.outer) + BangBang.setindex!!(getindex(vnt, lens.inner), value, lens.outer) else vnt.make_leaf(value, lens.outer) end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 50c0ea2ac..2b99b2ed0 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -33,12 +33,15 @@ using BangBang: setindex!! 
@test @inferred(getindex(vnt, @varname(d[4]))) == -2.0 # These can't be @inferred because `d` now has an abstract element type. Note that this - # does not ruin type instability for other varnames that don't involve `d`. + # does not ruin type stability for other varnames that don't involve `d`. vnt = setindex!!(vnt, "a", @varname(d[5])) @test getindex(vnt, @varname(d[5])) == "a" vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 + + vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 end end From 9f4d47680014f8d0f9f719c4fbd0856c2c6e79f4 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 19 Nov 2025 12:48:11 +0000 Subject: [PATCH 14/18] Introduce IndexArray --- src/varnamedtuple.jl | 94 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 83 insertions(+), 11 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 4dec92154..63bc5d551 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -18,19 +18,91 @@ struct IndexDict{T<:Function,Keys,Values} make_leaf::T end -function make_leaf_raw(value, ::PropertyLens{S}) where {S} - return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_raw) +struct IndexArray{T<:Function,ElType,numdims} + data::Array{ElType,numdims} + mask::Array{Bool,numdims} + make_leaf::T +end + +function IndexArray(eltype, num_dims, make_leaf) + dims = ntuple(_ -> 0, num_dims) + data = Array{eltype,num_dims}(undef, dims) + mask = fill(false, dims) + return IndexArray(data, mask, make_leaf) +end + +_length_needed(i::Integer) = i +_length_needed(r::UnitRange) = last(r) +_length_needed(::Colon) = 0 + +function _resize_indexarray(iarr::IndexArray, inds) + # Resize arrays to accommodate new indices. 
+ new_sizes = ntuple(i -> max(size(iarr.data, i), _length_needed(inds[i])), length(inds)) + # Generic multidimensional Arrays can not be resized, so we need to make a new one. + # See https://github.com/JuliaLang/julia/issues/37900 + new_data = Array{eltype(iarr.data),ndims(iarr.data)}(undef, new_sizes) + new_mask = fill(false, new_sizes) + for i in eachindex(iarr.data) + @inbounds new_data[i] = iarr.data[i] + @inbounds new_mask[i] = iarr.mask[i] + end + return IndexArray(new_data, new_mask, iarr.make_leaf) +end + +function BangBang.setindex!!(iarr::IndexArray, value, lens::IndexLens) + inds = lens.indices + iarr = if checkbounds(Bool, iarr.mask, inds...) + iarr + else + _resize_indexarray(iarr, inds) + end + new_data = setindex!!(iarr.data, value, inds...) + new_mask = setindex!!(iarr.mask, true, inds...) + return IndexArray(new_data, new_mask, iarr.make_leaf) +end + +function Base.getindex(iarr::IndexArray, lens::IndexLens) + if !haskey(iarr, lens) + throw(BoundsError("No value set at indices $(lens)")) + end + inds = lens.indices + return getindex(iarr.data, inds...) +end + +function Base.haskey(iarr::IndexArray, lens::IndexLens) + inds = lens.indices + return checkbounds(Bool, iarr.mask, inds...) 
&& + all(@inbounds(getindex(iarr.mask, inds...))) +end + +function make_leaf_array(value, ::PropertyLens{S}) where {S} + return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_array) +end +make_leaf_array(value, ::typeof(identity)) = value +function make_leaf_array(value, optic::ComposedFunction) + sub = make_leaf_array(value, optic.outer) + return make_leaf_array(sub, optic.inner) +end + +function make_leaf_array(value, optic::IndexLens) + num_inds = length(optic.indices) + iarr = IndexArray(typeof(value), num_inds, make_leaf_array) + return setindex!!(iarr, value, optic) +end + +function make_leaf_dict(value, ::PropertyLens{S}) where {S} + return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_dict) end -make_leaf_raw(value, ::typeof(identity)) = value -function make_leaf_raw(value, optic::IndexLens) - return IndexDict(Dict(optic.indices => value), make_leaf_raw) +make_leaf_dict(value, ::typeof(identity)) = value +function make_leaf_dict(value, optic::ComposedFunction) + sub = make_leaf_dict(value, optic.outer) + return make_leaf_dict(sub, optic.inner) end -function make_leaf_raw(value, optic::ComposedFunction) - sub = make_leaf_raw(value, optic.outer) - return make_leaf_raw(sub, optic.inner) +function make_leaf_dict(value, optic::IndexLens) + return IndexDict(Dict(optic.indices => value), make_leaf_dict) end -VarNamedTuple() = VarNamedTuple((;), make_leaf_raw) +VarNamedTuple() = VarNamedTuple((;), make_leaf_array) function Base.show(io::IO, vnt::VarNamedTuple) print(io, "(") @@ -57,7 +129,7 @@ end function Base.getindex(vnt::VarNamedTuple, name::VarName) return getindex(vnt, varname_to_lens(name)) end -function Base.getindex(x::Union{VarNamedTuple,IndexDict}, lens::ComposedFunction) +function Base.getindex(x::Union{VarNamedTuple,IndexDict,IndexArray}, lens::ComposedFunction) subdata = getindex(x, lens.inner) return getindex(subdata, lens.outer) end @@ -96,7 +168,7 @@ function BangBang.setindex!!(vnt::VarNamedTuple, value, name::VarName) end function 
BangBang.setindex!!( - vnt::Union{VarNamedTuple,IndexDict}, value, lens::ComposedFunction + vnt::Union{VarNamedTuple,IndexDict,IndexArray}, value, lens::ComposedFunction ) sub = if haskey(vnt, lens.inner) BangBang.setindex!!(getindex(vnt, lens.inner), value, lens.outer) From cbe29930cb51127ddd3bb019f9631c486ca76d8f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 19 Nov 2025 15:22:03 +0000 Subject: [PATCH 15/18] Ban colons and improve VNT --- src/varnamedtuple.jl | 89 ++++++++++++++++++++++++++++--------------- test/varnamedtuple.jl | 20 ++++++++++ 2 files changed, 78 insertions(+), 31 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 63bc5d551..5a34a3b8e 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -8,6 +8,12 @@ using DynamicPPL: _compose_no_identity export VarNamedTuple +_has_colon(::IndexLens{T}) where {T} = any(x <: Colon for x in T.parameters) + +function _is_multiindex(::IndexLens{T}) where {T} + return any(x <: UnitRange || x <: Colon for x in T.parameters) +end + struct VarNamedTuple{T<:Function,Names,Values} data::NamedTuple{Names,Values} make_leaf::T @@ -35,6 +41,7 @@ _length_needed(i::Integer) = i _length_needed(r::UnitRange) = last(r) _length_needed(::Colon) = 0 +# TODO(mhauru) Implement a simpler version of this for Vectors as a performance optimization. function _resize_indexarray(iarr::IndexArray, inds) # Resize arrays to accommodate new indices. new_sizes = ntuple(i -> max(size(iarr.data, i), _length_needed(inds[i])), length(inds)) @@ -49,28 +56,42 @@ function _resize_indexarray(iarr::IndexArray, inds) return IndexArray(new_data, new_mask, iarr.make_leaf) end -function BangBang.setindex!!(iarr::IndexArray, value, lens::IndexLens) - inds = lens.indices +function BangBang.setindex!!(iarr::IndexArray, value, optic::IndexLens) + if _has_colon(optic) + # TODO(mhauru) This could be implemented, by getting size information from `value`. 
+ throw(ArgumentError("Indexing with colons is not supported")) + end + inds = optic.indices iarr = if checkbounds(Bool, iarr.mask, inds...) iarr else _resize_indexarray(iarr, inds) end new_data = setindex!!(iarr.data, value, inds...) - new_mask = setindex!!(iarr.mask, true, inds...) - return IndexArray(new_data, new_mask, iarr.make_leaf) + if _is_multiindex(optic) + iarr.mask[inds...] .= true + else + iarr.mask[inds...] = true + end + return IndexArray(new_data, iarr.mask, iarr.make_leaf) end -function Base.getindex(iarr::IndexArray, lens::IndexLens) - if !haskey(iarr, lens) - throw(BoundsError("No value set at indices $(lens)")) +function Base.getindex(iarr::IndexArray, optic::IndexLens) + if _has_colon(optic) + throw(ArgumentError("Indexing with colons is not supported")) + end + if !haskey(iarr, optic) + throw(BoundsError("No value set at indices $(optic)")) end - inds = lens.indices + inds = optic.indices return getindex(iarr.data, inds...) end -function Base.haskey(iarr::IndexArray, lens::IndexLens) - inds = lens.indices +function Base.haskey(iarr::IndexArray, optic::IndexLens) + if _has_colon(optic) + throw(ArgumentError("Indexing with colons is not supported")) + end + inds = optic.indices return checkbounds(Bool, iarr.mask, inds...) && all(@inbounds(getindex(iarr.mask, inds...))) end @@ -84,9 +105,13 @@ function make_leaf_array(value, optic::ComposedFunction) return make_leaf_array(sub, optic.inner) end -function make_leaf_array(value, optic::IndexLens) - num_inds = length(optic.indices) - iarr = IndexArray(typeof(value), num_inds, make_leaf_array) +function make_leaf_array(value, optic::IndexLens{T}) where {T} + inds = optic.indices + num_inds = length(inds) + # Check if any of the indices are ranges or colons. If yes, value needs to be an + # AbstractArray. Otherwise it needs to be an individual value. + et = _is_multiindex(optic) ? 
eltype(value) : typeof(value) + iarr = IndexArray(et, num_inds, make_leaf_array) return setindex!!(iarr, value, optic) end @@ -129,15 +154,17 @@ end function Base.getindex(vnt::VarNamedTuple, name::VarName) return getindex(vnt, varname_to_lens(name)) end -function Base.getindex(x::Union{VarNamedTuple,IndexDict,IndexArray}, lens::ComposedFunction) - subdata = getindex(x, lens.inner) - return getindex(subdata, lens.outer) +function Base.getindex( + x::Union{VarNamedTuple,IndexDict,IndexArray}, optic::ComposedFunction +) + subdata = getindex(x, optic.inner) + return getindex(subdata, optic.outer) end function Base.getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} return getindex(vnt.data, S) end -function Base.getindex(id::IndexDict, lens::IndexLens) - return getindex(id.data, lens.indices) +function Base.getindex(id::IndexDict, optic::IndexLens) + return getindex(id.data, optic.indices) end function Base.haskey(vnt::VarNamedTuple, name::VarName) @@ -146,21 +173,21 @@ end Base.haskey(vnt::VarNamedTuple, ::typeof(identity)) = true -function Base.haskey(vnt::VarNamedTuple, lens::ComposedFunction) - return haskey(vnt, lens.inner) && haskey(getindex(vnt, lens.inner), lens.outer) +function Base.haskey(vnt::VarNamedTuple, optic::ComposedFunction) + return haskey(vnt, optic.inner) && haskey(getindex(vnt, optic.inner), optic.outer) end Base.haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S) -Base.haskey(id::IndexDict, lens::IndexLens) = haskey(id.data, lens.indices) +Base.haskey(id::IndexDict, optic::IndexLens) = haskey(id.data, optic.indices) Base.haskey(::VarNamedTuple, ::IndexLens) = false Base.haskey(::IndexDict, ::PropertyLens) = false # TODO(mhauru) This is type piracy. -Base.getindex(arr::AbstractArray, lens::IndexLens) = getindex(arr, lens.indices...) +Base.getindex(arr::AbstractArray, optic::IndexLens) = getindex(arr, optic.indices...) # TODO(mhauru) This is type piracy. 
-function BangBang.setindex!!(arr::AbstractArray, value, lens::IndexLens) - return BangBang.setindex!!(arr, value, lens.indices...) +function BangBang.setindex!!(arr::AbstractArray, value, optic::IndexLens) + return BangBang.setindex!!(arr, value, optic.indices...) end function BangBang.setindex!!(vnt::VarNamedTuple, value, name::VarName) @@ -168,14 +195,14 @@ function BangBang.setindex!!(vnt::VarNamedTuple, value, name::VarName) end function BangBang.setindex!!( - vnt::Union{VarNamedTuple,IndexDict,IndexArray}, value, lens::ComposedFunction + vnt::Union{VarNamedTuple,IndexDict,IndexArray}, value, optic::ComposedFunction ) - sub = if haskey(vnt, lens.inner) - BangBang.setindex!!(getindex(vnt, lens.inner), value, lens.outer) + sub = if haskey(vnt, optic.inner) + BangBang.setindex!!(getindex(vnt, optic.inner), value, optic.outer) else - vnt.make_leaf(value, lens.outer) + vnt.make_leaf(value, optic.outer) end - return BangBang.setindex!!(vnt, sub, lens.inner) + return BangBang.setindex!!(vnt, sub, optic.inner) end function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} @@ -186,8 +213,8 @@ function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,))), vnt.make_leaf) end -function BangBang.setindex!!(id::IndexDict, value, lens::IndexLens) - return IndexDict(setindex!!(id.data, value, lens.indices), id.make_leaf) +function BangBang.setindex!!(id::IndexDict, value, optic::IndexLens) + return IndexDict(setindex!!(id.data, value, optic.indices), id.make_leaf) end function apply(func, vnt::VarNamedTuple, name::VarName) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 2b99b2ed0..cfa4d9d3c 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -42,6 +42,26 @@ using BangBang: setindex!! 
vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 + + vec = fill(1.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[1:4]))) + @test @inferred(getindex(vnt, @varname(j[1:4]))) == vec + + vec = fill(2.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[2:5]))) + @test @inferred(getindex(vnt, @varname(j[1]))) == 1.0 + @test @inferred(getindex(vnt, @varname(j[2:5]))) == vec + + arr = fill(2.0, (4, 2)) + vn = @varname(k.l[2:5, 3, 1:2, 10]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 10]))) == fill(2.0, 2) + + vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[2]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3]))) + @test @inferred(getindex(vnt, @varname(m[2:3]))) == [1.0, 1.0] + @test !haskey(vnt, @varname(m[1])) end end From 7107c0bb82440458def39fd93ac3a5ee73cc0e3a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 19 Nov 2025 16:18:10 +0000 Subject: [PATCH 16/18] More tests, performance, errors --- src/varnamedtuple.jl | 37 +++++++++++++++++++++++++++++-------- test/varnamedtuple.jl | 28 +++++++++++++++++++++++++--- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 5a34a3b8e..10d7ea289 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -41,27 +41,45 @@ _length_needed(i::Integer) = i _length_needed(r::UnitRange) = last(r) _length_needed(::Colon) = 0 -# TODO(mhauru) Implement a simpler version of this for Vectors as a performance optimization. function _resize_indexarray(iarr::IndexArray, inds) - # Resize arrays to accommodate new indices. new_sizes = ntuple(i -> max(size(iarr.data, i), _length_needed(inds[i])), length(inds)) # Generic multidimensional Arrays can not be resized, so we need to make a new one. 
# See https://github.com/JuliaLang/julia/issues/37900 new_data = Array{eltype(iarr.data),ndims(iarr.data)}(undef, new_sizes) new_mask = fill(false, new_sizes) - for i in eachindex(iarr.data) - @inbounds new_data[i] = iarr.data[i] - @inbounds new_mask[i] = iarr.mask[i] + # Note that we have to use CartesianIndices instead of eachindex, because the latter + # may use a linear index that does not match between the old and the new arrays. + for i in CartesianIndices(iarr.data) + mask_val = iarr.mask[i] + @inbounds new_mask[i] = mask_val + if mask_val + @inbounds new_data[i] = iarr.data[i] + end end return IndexArray(new_data, new_mask, iarr.make_leaf) end +# The below implements the same functionality as above, but more performantly for 1D arrays. +function _resize_indexarray(iarr::IndexArray{T,Eltype,1}, (ind,)) where {T,Eltype} + # Resize arrays to accommodate new indices. + old_size = size(iarr.data, 1) + new_size = max(old_size, _length_needed(ind)) + resize!(iarr.data, new_size) + resize!(iarr.mask, new_size) + @inbounds iarr.mask[(old_size + 1):new_size] .= false + return iarr +end + function BangBang.setindex!!(iarr::IndexArray, value, optic::IndexLens) if _has_colon(optic) - # TODO(mhauru) This could be implemented, by getting size information from `value`. + # TODO(mhauru) This could be implemented by getting size information from `value`. + # However, the corresponding getindex is more fundamentally ill-defined. throw(ArgumentError("Indexing with colons is not supported")) end inds = optic.indices + if length(inds) != ndims(iarr.data) + throw(ArgumentError("Invalid index $(inds)")) + end iarr = if checkbounds(Bool, iarr.mask, inds...) 
iarr else @@ -80,10 +98,13 @@ function Base.getindex(iarr::IndexArray, optic::IndexLens) if _has_colon(optic) throw(ArgumentError("Indexing with colons is not supported")) end + inds = optic.indices + if length(inds) != ndims(iarr.data) + throw(ArgumentError("Invalid index $(inds)")) + end if !haskey(iarr, optic) - throw(BoundsError("No value set at indices $(optic)")) + throw(BoundsError(iarr, inds)) end - inds = optic.indices return getindex(iarr.data, inds...) end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index cfa4d9d3c..85b824ffc 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -1,6 +1,6 @@ module VarNamedTupleTests -using Test: @inferred, @testset, @test +using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: @varname, VarNamedTuple using BangBang: setindex!! @@ -11,9 +11,11 @@ using BangBang: setindex!! vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b))) @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + @test @inferred(getindex(vnt, @varname(b[2]))) == 2 vnt = @inferred(setindex!!(vnt, 64.0, @varname(a))) @test @inferred(getindex(vnt, @varname(a))) == 64.0 + @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] vnt = @inferred(setindex!!(vnt, 15, @varname(b[2]))) @test @inferred(getindex(vnt, @varname(b))) == [1, 15, 3] @@ -46,17 +48,37 @@ using BangBang: setindex!! 
vec = fill(1.0, 4) vnt = @inferred(setindex!!(vnt, vec, @varname(j[1:4]))) @test @inferred(getindex(vnt, @varname(j[1:4]))) == vec + @test @inferred(getindex(vnt, @varname(j[2]))) == vec[2] + @test haskey(vnt, @varname(j[4])) + @test !haskey(vnt, @varname(j[5])) + @test_throws BoundsError getindex(vnt, @varname(j[5])) vec = fill(2.0, 4) vnt = @inferred(setindex!!(vnt, vec, @varname(j[2:5]))) @test @inferred(getindex(vnt, @varname(j[1]))) == 1.0 @test @inferred(getindex(vnt, @varname(j[2:5]))) == vec + @test haskey(vnt, @varname(j[5])) arr = fill(2.0, (4, 2)) - vn = @varname(k.l[2:5, 3, 1:2, 10]) + vn = @varname(k.l[2:5, 3, 1:2, 2]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + + # Not enough, or too many, indices. + @test_throws "Invalid index" setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3])) + @test_throws "Invalid index" setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3, 4, 5])) + + arr = fill(3.0, (3, 3)) + vn = @varname(k.l[1, 1:3, 1:3, 1]) vnt = @inferred(setindex!!(vnt, arr, vn)) @test @inferred(getindex(vnt, vn)) == arr - @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 10]))) == fill(2.0, 2) + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[1, 1:2, 1:2, 1]))) == fill(3.0, 2, 2) + # A subset of the elements set previously. 
+ @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + @test !haskey(vnt, @varname(k.l[2, 3, 3, 2])) vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[2]))) vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3]))) From 8aa217851e3ffee1fccafd8cdd4e8dbe9680ef4c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 19 Nov 2025 16:19:39 +0000 Subject: [PATCH 17/18] Rename IndexArray to PartialArray --- src/varnamedtuple.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 10d7ea289..1e800cb35 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -24,24 +24,24 @@ struct IndexDict{T<:Function,Keys,Values} make_leaf::T end -struct IndexArray{T<:Function,ElType,numdims} +struct PartialArray{T<:Function,ElType,numdims} data::Array{ElType,numdims} mask::Array{Bool,numdims} make_leaf::T end -function IndexArray(eltype, num_dims, make_leaf) +function PartialArray(eltype, num_dims, make_leaf) dims = ntuple(_ -> 0, num_dims) data = Array{eltype,num_dims}(undef, dims) mask = fill(false, dims) - return IndexArray(data, mask, make_leaf) + return PartialArray(data, mask, make_leaf) end _length_needed(i::Integer) = i _length_needed(r::UnitRange) = last(r) _length_needed(::Colon) = 0 -function _resize_indexarray(iarr::IndexArray, inds) +function _resize_partialarray(iarr::PartialArray, inds) new_sizes = ntuple(i -> max(size(iarr.data, i), _length_needed(inds[i])), length(inds)) # Generic multidimensional Arrays can not be resized, so we need to make a new one. # See https://github.com/JuliaLang/julia/issues/37900 @@ -56,11 +56,11 @@ function _resize_indexarray(iarr::IndexArray, inds) @inbounds new_data[i] = iarr.data[i] end end - return IndexArray(new_data, new_mask, iarr.make_leaf) + return PartialArray(new_data, new_mask, iarr.make_leaf) end # The below implements the same functionality as above, but more performantly for 1D arrays. 
-function _resize_indexarray(iarr::IndexArray{T,Eltype,1}, (ind,)) where {T,Eltype} +function _resize_partialarray(iarr::PartialArray{T,Eltype,1}, (ind,)) where {T,Eltype} # Resize arrays to accommodate new indices. old_size = size(iarr.data, 1) new_size = max(old_size, _length_needed(ind)) @@ -70,7 +70,7 @@ function _resize_indexarray(iarr::IndexArray{T,Eltype,1}, (ind,)) where {T,Eltyp return iarr end -function BangBang.setindex!!(iarr::IndexArray, value, optic::IndexLens) +function BangBang.setindex!!(iarr::PartialArray, value, optic::IndexLens) if _has_colon(optic) # TODO(mhauru) This could be implemented by getting size information from `value`. # However, the corresponding getindex is more fundamentally ill-defined. @@ -83,7 +83,7 @@ function BangBang.setindex!!(iarr::IndexArray, value, optic::IndexLens) iarr = if checkbounds(Bool, iarr.mask, inds...) iarr else - _resize_indexarray(iarr, inds) + _resize_partialarray(iarr, inds) end new_data = setindex!!(iarr.data, value, inds...) if _is_multiindex(optic) @@ -91,10 +91,10 @@ function BangBang.setindex!!(iarr::IndexArray, value, optic::IndexLens) else iarr.mask[inds...] = true end - return IndexArray(new_data, iarr.mask, iarr.make_leaf) + return PartialArray(new_data, iarr.mask, iarr.make_leaf) end -function Base.getindex(iarr::IndexArray, optic::IndexLens) +function Base.getindex(iarr::PartialArray, optic::IndexLens) if _has_colon(optic) throw(ArgumentError("Indexing with colons is not supported")) end @@ -108,7 +108,7 @@ function Base.getindex(iarr::IndexArray, optic::IndexLens) return getindex(iarr.data, inds...) end -function Base.haskey(iarr::IndexArray, optic::IndexLens) +function Base.haskey(iarr::PartialArray, optic::IndexLens) if _has_colon(optic) throw(ArgumentError("Indexing with colons is not supported")) end @@ -132,7 +132,7 @@ function make_leaf_array(value, optic::IndexLens{T}) where {T} # Check if any of the indices are ranges or colons. If yes, value needs to be an # AbstractArray. 
Otherwise it needs to be an individual value. et = _is_multiindex(optic) ? eltype(value) : typeof(value) - iarr = IndexArray(et, num_inds, make_leaf_array) + iarr = PartialArray(et, num_inds, make_leaf_array) return setindex!!(iarr, value, optic) end @@ -176,7 +176,7 @@ function Base.getindex(vnt::VarNamedTuple, name::VarName) return getindex(vnt, varname_to_lens(name)) end function Base.getindex( - x::Union{VarNamedTuple,IndexDict,IndexArray}, optic::ComposedFunction + x::Union{VarNamedTuple,IndexDict,PartialArray}, optic::ComposedFunction ) subdata = getindex(x, optic.inner) return getindex(subdata, optic.outer) @@ -216,7 +216,7 @@ function BangBang.setindex!!(vnt::VarNamedTuple, value, name::VarName) end function BangBang.setindex!!( - vnt::Union{VarNamedTuple,IndexDict,IndexArray}, value, optic::ComposedFunction + vnt::Union{VarNamedTuple,IndexDict,PartialArray}, value, optic::ComposedFunction ) sub = if haskey(vnt, optic.inner) BangBang.setindex!!(getindex(vnt, optic.inner), value, optic.outer) From 10c23ef56e6138f6a385a5bd2dcd2bae207b1206 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 19 Nov 2025 16:54:00 +0000 Subject: [PATCH 18/18] Grow the sizes of PartialArrays in jumps --- src/varnamedtuple.jl | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 1e800cb35..448ae4636 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -8,6 +8,9 @@ using DynamicPPL: _compose_no_identity export VarNamedTuple +"""The factor by which we increase the dimensions of PartialArrays when resizing them.""" +const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 + _has_colon(::IndexLens{T}) where {T} = any(x <: Colon for x in T.parameters) function _is_multiindex(::IndexLens{T}) where {T} @@ -31,7 +34,7 @@ struct PartialArray{T<:Function,ElType,numdims} end function PartialArray(eltype, num_dims, make_leaf) - dims = ntuple(_ -> 0, num_dims) + dims = ntuple(_ -> 
PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) data = Array{eltype,num_dims}(undef, dims) mask = fill(false, dims) return PartialArray(data, mask, make_leaf) @@ -41,8 +44,19 @@ _length_needed(i::Integer) = i _length_needed(r::UnitRange) = last(r) _length_needed(::Colon) = 0 +"""Take the minimum size that a dimension of a PartialArray needs to be, and return the size +we choose it to be. This size will be the smallest possible power of +PARTIAL_ARRAY_DIM_GROWTH_FACTOR. Growing PartialArrays in big jumps like this helps reduce +data copying, as resizes aren't needed as often. +""" +function _partial_array_dim_size(min_dim) + factor = PARTIAL_ARRAY_DIM_GROWTH_FACTOR + return factor^(Int(ceil(log(factor, min_dim)))) +end + function _resize_partialarray(iarr::PartialArray, inds) - new_sizes = ntuple(i -> max(size(iarr.data, i), _length_needed(inds[i])), length(inds)) + min_sizes = ntuple(i -> max(size(iarr.data, i), _length_needed(inds[i])), length(inds)) + new_sizes = map(_partial_array_dim_size, min_sizes) # Generic multidimensional Arrays can not be resized, so we need to make a new one. # See https://github.com/JuliaLang/julia/issues/37900 new_data = Array{eltype(iarr.data),ndims(iarr.data)}(undef, new_sizes) @@ -63,7 +77,8 @@ end function _resize_partialarray(iarr::PartialArray{T,Eltype,1}, (ind,)) where {T,Eltype} # Resize arrays to accommodate new indices. old_size = size(iarr.data, 1) - new_size = max(old_size, _length_needed(ind)) + min_size = max(old_size, _length_needed(ind)) + new_size = _partial_array_dim_size(min_size) resize!(iarr.data, new_size) resize!(iarr.mask, new_size) @inbounds iarr.mask[(old_size + 1):new_size] .= false