From eab71317d406a56fe06df7c8f944a4063e564112 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Wed, 19 Nov 2025 17:08:31 +0000
Subject: [PATCH 1/4] Add VarNamedTuple, tests, and WIP docs

---
 docs/src/internals/varnamedtuple.md | 112 ++++++++++
 src/varnamedtuple.jl                | 310 ++++++++++++++++++++++++++++
 test/varnamedtuple.jl               |  89 ++++++++
 3 files changed, 511 insertions(+)
 create mode 100644 docs/src/internals/varnamedtuple.md
 create mode 100644 src/varnamedtuple.jl
 create mode 100644 test/varnamedtuple.jl

diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md
new file mode 100644
index 000000000..9f7a84cdb
--- /dev/null
+++ b/docs/src/internals/varnamedtuple.md
@@ -0,0 +1,112 @@
+# VarNamedTuple as the basis of VarInfo
+
+This document collects thoughts and ideas for how to unify our multitude of AbstractVarInfo types using a VarNamedTuple type. It may eventually turn into a draft design document, but for now it is more raw than that.
+
+## The current situation
+
+We currently have the following AbstractVarInfo types:
+
+  - A: VarInfo with Metadata
+  - B: VarInfo with VarNamedVector
+  - C: VarInfo with NamedTuple, with values being Metadata
+  - D: VarInfo with NamedTuple, with values being VarNamedVector
+  - E: SimpleVarInfo with NamedTuples
+  - F: SimpleVarInfo with OrderedDict
+
+A and C are the classic ones, and the defaults. C groups the Metadata objects by the lead Symbol of the VarName of a variable, e.g. `x` in `@varname(x.y[1].z)`, which allows different lead Symbols to have different element types while keeping the VarInfo type stable. B and D were created to simplify A and C, give them a nicer interface, and make them deal better with changing variable sizes, but according to recent (Oct 2025) benchmarks they are quite a lot slower; this needs work.
+
+E and F are entirely distinct in implementation from the others.
E is simply a mapping from Symbols to values, with each VarName being converted to a single Symbol, e.g. `Symbol("a[1]")`. F is a mapping from VarNames to values as an OrderedDict, with VarName as the key type.
+
+A-D carry not only the values of variables but also their bijectors/distributions, and store all values vectorised, using the bijectors to map back to the original values. They also store for each variable a flag for whether the variable has been linked. E-F store only the raw values, and a global flag for the whole SimpleVarInfo for whether it's linked. The link transform itself is implicit.
+
+TODO: Write a better summary of pros and cons of each approach.
+
+## VarNamedTuple
+
+VarNamedTuple has been discussed as a possible data structure to generalise the structure used in VarInfo to achieve type stability, i.e. grouping VarNames by their lead Symbol. The same NamedTuple structure has been used elsewhere, too, e.g. in Turing.GibbsContext. The idea was to encapsulate this structure into its own type, reducing code duplication and making the design more robust and powerful. See https://github.com/TuringLang/DynamicPPL.jl/issues/900 for the discussion.
+
+An AbstractVarInfo type could be only one application of VarNamedTuple, but here I'll focus on it exclusively. If we can make VarNamedTuple work for an AbstractVarInfo, I bet we can make it work for other purposes (condition, fix, Gibbs) as well.
+
+Without going into full detail, here's @mhauru's current proposal for what it would look like. This proposal remains in constant flux as I develop the code.
+
+A VarNamedTuple is a mapping of VarNames to values. Values can be anything. In the case of using VarNamedTuple to implement an AbstractVarInfo, the values would be random samples for random variables. However, they could also carry extra information. For instance, we might use a value that is a tuple of a vectorised value, a bijector, and a flag for whether the variable is linked.
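The kind of metadata-carrying leaf value described above can be sketched in plain Julia. The field names (`vectorised`, `bijector`, `linked`) and the use of a NamedTuple here are purely illustrative assumptions, not an existing DynamicPPL API:

```julia
# A hypothetical leaf value: the vectorised sample is stored together with
# its bijector and a flag recording whether the variable is linked.
leaf = (vectorised=[0.3], bijector=exp, linked=true)

# The user-facing value is recovered by mapping the stored bijector over the
# vectorised representation whenever the variable is linked.
value = leaf.linked ? map(leaf.bijector, leaf.vectorised) : leaf.vectorised
```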
+
+I sometimes shorten VarNamedTuple to VNT.
+
+Internally, a VarNamedTuple consists of nested NamedTuples. For instance, the mapping `@varname(x) => 1, @varname(y.z) => 2` would be stored as
+
+```
+(; x=1, y=(; z=2))
+```
+
+(This is a slight simplification: really it would be nested VarNamedTuples rather than NamedTuples, but I omit this detail.)
+This forms a tree, with each node being a NamedTuple, like so:
+
+```
+      NT
+   x /  \ y
+    1    NT
+          \ z
+           2
+```
+
+Each `NT` marks a NamedTuple, and the labels on the edges its keys. Here the root node has the keys `x` and `y`. This is similar to the type stable VarInfo in our current design, except with possibly more levels (our current one only has the root node). Each nested `PropertyLens`, i.e. each `.` in a VarName like `@varname(a.b.c.e)`, creates a new layer of the tree.
+
+For simplicity, at least for now, we ban any VarNames where an `IndexLens` precedes a `PropertyLens`. That is, we ban any VarNames like `@varname(a.b[1].c)`. Recall that VarNames allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and `identity` (the trivial lens). Thus the only allowed VarName types are `@varname(a.b.c.d)` and `@varname(a.b.c.d[i,j,k])`.
+
+This means that we can add levels to the NamedTuple tree until all `PropertyLenses` have been covered. The leaves of the tree are then of two kinds: they are either the raw value itself, if the last lens of the VarName is an `identity`, or otherwise something that can be indexed with an `IndexLens`, such as an `Array`.
+
+Getting a value from a VarNamedTuple is simple: for `getindex(vnt::VNT, vn::VarName{S})` (`S` being the lead Symbol) you recurse into `getindex(vnt[S], unprefix(vn, S))`. If the last lens of `vn` is an `IndexLens`, we assume that the leaf of the NamedTuple tree we've reached contains something that can be indexed with it.
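The recursion just described can be illustrated with a small base-Julia sketch. Here a VarName is stood in for by a tuple of Symbols (so `@varname(y.z)` becomes `(:y, :z)`), and plain NamedTuples play the role of the nested VarNamedTuples; this only sketches the descent, it is not the actual implementation:

```julia
# Minimal sketch of the getindex recursion: peel off one PropertyLens
# (here: one Symbol) per level until the path is exhausted.
getpath(node, path::Tuple{}) = node
getpath(node, path::Tuple) = getpath(node[first(path)], Base.tail(path))

vnt = (; x=1, y=(; z=2))  # the nested-NamedTuple tree from the example above
getpath(vnt, (:y, :z))    # descends vnt -> vnt.y -> vnt.y.z
```

A trailing `IndexLens` would then simply index into whatever the reached leaf holds, e.g. `getpath(vnt, path)[i]`.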
+
+Setting values in a VNT is equally simple if there are no `IndexLenses`: for `setindex!!(vnt::VNT, value::Any, vn::VarName)` one simply finds the leaf of the `vnt` tree corresponding to `vn` and sets its value to `value`.
+
+The tricky part is what to do when setting values with `IndexLenses`. There are three possible situations. Say one calls `setindex!!(vnt, 3.0, @varname(a.b[3]))`.
+
+ 1. If `getindex(vnt, @varname(a.b))` is already a vector of length at least 3, this is easy: just set the third element.
+ 2. If `getindex(vnt, @varname(a.b))` is a vector of length less than 3, what should we do? Do we error? Do we extend that vector?
+ 3. If `getindex(vnt, @varname(a.b))` isn't even set, what do we do? Say for instance that `vnt` is currently empty. We should set `vnt` to be something like `(; a=(; b=x))`, where `x` is such that `x[3] = 3.0`, but what exactly should `x` be? Is it a dictionary? A vector of length 3? If the latter, what are `x[2]` and `x[1]`? Or should this `setindex!!` call simply error?
+
+A note at this point: VarNamedTuples must always use `setindex!!`, the `!!` version that may or may not operate in place. The NamedTuples can't be modified in place, but the values at the leaves may be. Always using a `!!` function makes type stability easier, and makes structures like the type unstable old VarInfo with Metadata unnecessary: any value can be set into any VarNamedTuple. The type parameters of the VNT will simply expand as necessary.
+
+To solve the problem of points 2. and 3. above I propose expanding the definition of VNT a bit. This will also help make VNT more flexible, which may help performance or allow more use cases. The modification is this:
+
+Contrary to what I said above, let's say that VNT isn't just nested NamedTuples with some values at the leaves. Let's say it also has a field called `make_leaf`.
`make_leaf(value, lens)` is a function that takes any value, and a lens that is either `identity` or an `IndexLens`, and returns the value wrapped in some suitable struct that can be stored in a leaf of the NamedTuple tree. The values should always be such that `make_leaf(value, lens)[lens] == value`.
+
+Our earlier example of `VarNamedTuple(@varname(x) => 1, @varname(y.z) => 2; make_leaf=f)` would be stored as a tree like
+
+```
+          --NT--
+       x /      \ y
+f(1, identity)   NT
+                  \ z
+              f(2, identity)
+```
+
+The first draft of VNT above, which did not include `make_leaf`, is equivalent to the trivial choice `make_leaf(value, lens) = lens === identity ? value : error("Don't know how to deal with IndexLenses")`. The problems 2. and 3. above are "solved" by making it `make_leaf`'s problem to figure out what to do. For instance, `make_leaf` can always return a `Dict` that maps lenses to values. This is probably slow, but works for any lens. Or it can initialise a vector type that can grow as needed when indexed into.
+
+The idea would be to use `make_leaf` to try out different ways of implementing a VarInfo, find a good default, and, if necessary, leave the option for power users to customise behaviour. The first ones to implement would be
+
+  - `make_leaf` that returns a Metadata object. This would be a direct replacement for the type stable VarInfo that uses Metadata, except now with more nested levels of NamedTuple.
+  - `make_leaf` that returns an `OrderedDict`. This would be a direct replacement for SimpleVarInfo with OrderedDict.
+
+You may ask: have we simply gone from too many VarInfo types to too many `make_leaf` functions? Yes, we have. But hopefully we have gained something in the process:
+
+  - The leaf types can be simpler. They do not need to deal with VarNames any more, they only need to deal with `identity` lenses and `IndexLenses`.
+  - All AbstractVarInfos are as type stable as their leaf types allow.
There is no more notion of an untyped VarInfo being converted to a typed one.
+  - Type stability is maintained even with nested `PropertyLenses` like `@varname(a.b)`, which happens a lot with submodels.
+  - Many functions that are currently implemented individually for each AbstractVarInfo type would now have a single implementation for the VarNamedTuple-based AbstractVarInfo type, reducing code duplication. I would also hope to get rid of most of the generated functions in `varinfo.jl`.
+
+My guess is that the eventual One AbstractVarInfo To Rule Them All would have a `make_leaf` function that stores the raw values when the lens is an `identity`, and uses a flexible Vector, a lot like VarNamedVector, when the lens is an IndexLens. However, I could be wrong on that being the best option. Implementing and benchmarking is the only way to know.
+
+I think the two big questions are:
+
+  - Will we run into some big, unanticipated blockers when we start to implement this?
+  - Will the nesting of NamedTuples cause performance regressions, e.g. if the compiler either chokes or gives up?
+
+I'll try to derisk these early on in this PR.
+
+## Questions / issues
+
+  - People might really need IndexLenses in the middle of VarNames. The one place this comes up is submodels within a loop. I'm still inclined to keep designing without allowing for that, for now, but should keep in mind that that needs to be relaxed eventually. If it makes it easier, we can require that users explicitly tell us the size of any arrays for which this is done.
+  - When storing values for nested NamedTuples, the actual variable may be a struct. Do we need to be able to reconstruct the struct from the NamedTuple? If so, how do we do that?
+  - Do `Colon` indices cause any extra trouble for the leaf nodes?
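As a concrete illustration of the `make_leaf` idea discussed above, here is a base-Julia sketch of the Dict-returning strategy. Lenses are stood in for by plain index tuples, with `()` playing the role of `identity`; the real implementation would dispatch on actual lens types, so the names and signatures here are assumptions for illustration only:

```julia
# Dict-based make_leaf: identity lenses store the raw value, IndexLenses
# wrap it in a Dict keyed by the index tuple. Slow, but works for any index.
make_leaf_sketch(value, ::Tuple{}) = value
make_leaf_sketch(value, lens::Tuple) = Dict(lens => value)

leaf = make_leaf_sketch(3.0, (3,))  # as if setting @varname(a.b[3])
leaf[(3,)]                          # satisfies make_leaf(v, lens)[lens] == v
leaf[(17,)] = -1.0                  # any later index can be added, no resizing
```

The `Dict` leaf trades speed for total generality: any index, set in any order, is accepted without resizing, which is exactly the behaviour needed to sidestep problems 2. and 3. above.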
diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl new file mode 100644 index 000000000..448ae4636 --- /dev/null +++ b/src/varnamedtuple.jl @@ -0,0 +1,310 @@ +# TODO(mhauru) This module should probably be moved to AbstractPPL. +module VarNamedTuples + +using AbstractPPL +using BangBang +using Accessors +using DynamicPPL: _compose_no_identity + +export VarNamedTuple + +"""The factor by which we increase the dimensions of PartialArrays when resizing them.""" +const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 + +_has_colon(::IndexLens{T}) where {T} = any(x <: Colon for x in T.parameters) + +function _is_multiindex(::IndexLens{T}) where {T} + return any(x <: UnitRange || x <: Colon for x in T.parameters) +end + +struct VarNamedTuple{T<:Function,Names,Values} + data::NamedTuple{Names,Values} + make_leaf::T +end + +struct IndexDict{T<:Function,Keys,Values} + data::Dict{Keys,Values} + make_leaf::T +end + +struct PartialArray{T<:Function,ElType,numdims} + data::Array{ElType,numdims} + mask::Array{Bool,numdims} + make_leaf::T +end + +function PartialArray(eltype, num_dims, make_leaf) + dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) + data = Array{eltype,num_dims}(undef, dims) + mask = fill(false, dims) + return PartialArray(data, mask, make_leaf) +end + +_length_needed(i::Integer) = i +_length_needed(r::UnitRange) = last(r) +_length_needed(::Colon) = 0 + +"""Take the minimum size that a dimension of a PartialArray needs to be, and return the size +we choose it to be. This size will be the smallest possible power of +PARTIAL_ARRAY_DIM_GROWTH_FACTOR. Growing PartialArrays in big jumps like this helps reduce +data copying, as resizes aren't needed as often. 
+""" +function _partial_array_dim_size(min_dim) + factor = PARTIAL_ARRAY_DIM_GROWTH_FACTOR + return factor^(Int(ceil(log(factor, min_dim)))) +end + +function _resize_partialarray(iarr::PartialArray, inds) + min_sizes = ntuple(i -> max(size(iarr.data, i), _length_needed(inds[i])), length(inds)) + new_sizes = map(_partial_array_dim_size, min_sizes) + # Generic multidimensional Arrays can not be resized, so we need to make a new one. + # See https://github.com/JuliaLang/julia/issues/37900 + new_data = Array{eltype(iarr.data),ndims(iarr.data)}(undef, new_sizes) + new_mask = fill(false, new_sizes) + # Note that we have to use CartesianIndices instead of eachindex, because the latter + # may use a linear index that does not match between the old and the new arrays. + for i in CartesianIndices(iarr.data) + mask_val = iarr.mask[i] + @inbounds new_mask[i] = mask_val + if mask_val + @inbounds new_data[i] = iarr.data[i] + end + end + return PartialArray(new_data, new_mask, iarr.make_leaf) +end + +# The below implements the same functionality as above, but more performantly for 1D arrays. +function _resize_partialarray(iarr::PartialArray{T,Eltype,1}, (ind,)) where {T,Eltype} + # Resize arrays to accommodate new indices. + old_size = size(iarr.data, 1) + min_size = max(old_size, _length_needed(ind)) + new_size = _partial_array_dim_size(min_size) + resize!(iarr.data, new_size) + resize!(iarr.mask, new_size) + @inbounds iarr.mask[(old_size + 1):new_size] .= false + return iarr +end + +function BangBang.setindex!!(iarr::PartialArray, value, optic::IndexLens) + if _has_colon(optic) + # TODO(mhauru) This could be implemented by getting size information from `value`. + # However, the corresponding getindex is more fundamentally ill-defined. + throw(ArgumentError("Indexing with colons is not supported")) + end + inds = optic.indices + if length(inds) != ndims(iarr.data) + throw(ArgumentError("Invalid index $(inds)")) + end + iarr = if checkbounds(Bool, iarr.mask, inds...) 
+ iarr + else + _resize_partialarray(iarr, inds) + end + new_data = setindex!!(iarr.data, value, inds...) + if _is_multiindex(optic) + iarr.mask[inds...] .= true + else + iarr.mask[inds...] = true + end + return PartialArray(new_data, iarr.mask, iarr.make_leaf) +end + +function Base.getindex(iarr::PartialArray, optic::IndexLens) + if _has_colon(optic) + throw(ArgumentError("Indexing with colons is not supported")) + end + inds = optic.indices + if length(inds) != ndims(iarr.data) + throw(ArgumentError("Invalid index $(inds)")) + end + if !haskey(iarr, optic) + throw(BoundsError(iarr, inds)) + end + return getindex(iarr.data, inds...) +end + +function Base.haskey(iarr::PartialArray, optic::IndexLens) + if _has_colon(optic) + throw(ArgumentError("Indexing with colons is not supported")) + end + inds = optic.indices + return checkbounds(Bool, iarr.mask, inds...) && + all(@inbounds(getindex(iarr.mask, inds...))) +end + +function make_leaf_array(value, ::PropertyLens{S}) where {S} + return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_array) +end +make_leaf_array(value, ::typeof(identity)) = value +function make_leaf_array(value, optic::ComposedFunction) + sub = make_leaf_array(value, optic.outer) + return make_leaf_array(sub, optic.inner) +end + +function make_leaf_array(value, optic::IndexLens{T}) where {T} + inds = optic.indices + num_inds = length(inds) + # Check if any of the indices are ranges or colons. If yes, value needs to be an + # AbstractArray. Otherwise it needs to be an individual value. + et = _is_multiindex(optic) ? 
eltype(value) : typeof(value) + iarr = PartialArray(et, num_inds, make_leaf_array) + return setindex!!(iarr, value, optic) +end + +function make_leaf_dict(value, ::PropertyLens{S}) where {S} + return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_dict) +end +make_leaf_dict(value, ::typeof(identity)) = value +function make_leaf_dict(value, optic::ComposedFunction) + sub = make_leaf_dict(value, optic.outer) + return make_leaf_dict(sub, optic.inner) +end +function make_leaf_dict(value, optic::IndexLens) + return IndexDict(Dict(optic.indices => value), make_leaf_dict) +end + +VarNamedTuple() = VarNamedTuple((;), make_leaf_array) + +function Base.show(io::IO, vnt::VarNamedTuple) + print(io, "(") + for (i, (name, value)) in enumerate(pairs(vnt.data)) + if i > 1 + print(io, ", ") + end + print(io, name, " -> ") + print(io, value) + end + return print(io, ")") +end + +function Base.show(io::IO, id::IndexDict) + return print(io, id.data) +end + +Base.getindex(vnt::VarNamedTuple, name::Symbol) = vnt.data[name] + +function varname_to_lens(name::VarName{S}) where {S} + return _compose_no_identity(getoptic(name), PropertyLens{S}()) +end + +function Base.getindex(vnt::VarNamedTuple, name::VarName) + return getindex(vnt, varname_to_lens(name)) +end +function Base.getindex( + x::Union{VarNamedTuple,IndexDict,PartialArray}, optic::ComposedFunction +) + subdata = getindex(x, optic.inner) + return getindex(subdata, optic.outer) +end +function Base.getindex(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} + return getindex(vnt.data, S) +end +function Base.getindex(id::IndexDict, optic::IndexLens) + return getindex(id.data, optic.indices) +end + +function Base.haskey(vnt::VarNamedTuple, name::VarName) + return haskey(vnt, varname_to_lens(name)) +end + +Base.haskey(vnt::VarNamedTuple, ::typeof(identity)) = true + +function Base.haskey(vnt::VarNamedTuple, optic::ComposedFunction) + return haskey(vnt, optic.inner) && haskey(getindex(vnt, optic.inner), optic.outer) +end + 
+Base.haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S) +Base.haskey(id::IndexDict, optic::IndexLens) = haskey(id.data, optic.indices) +Base.haskey(::VarNamedTuple, ::IndexLens) = false +Base.haskey(::IndexDict, ::PropertyLens) = false + +# TODO(mhauru) This is type piracy. +Base.getindex(arr::AbstractArray, optic::IndexLens) = getindex(arr, optic.indices...) + +# TODO(mhauru) This is type piracy. +function BangBang.setindex!!(arr::AbstractArray, value, optic::IndexLens) + return BangBang.setindex!!(arr, value, optic.indices...) +end + +function BangBang.setindex!!(vnt::VarNamedTuple, value, name::VarName) + return BangBang.setindex!!(vnt, value, varname_to_lens(name)) +end + +function BangBang.setindex!!( + vnt::Union{VarNamedTuple,IndexDict,PartialArray}, value, optic::ComposedFunction +) + sub = if haskey(vnt, optic.inner) + BangBang.setindex!!(getindex(vnt, optic.inner), value, optic.outer) + else + vnt.make_leaf(value, optic.outer) + end + return BangBang.setindex!!(vnt, sub, optic.inner) +end + +function BangBang.setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} + # I would like this to just read + # return VarNamedTuple(BangBang.setindex!!(vnt.data, value, S), vnt.make_leaf) + # but that seems to be type unstable. Why? Shouldn't it obviously be the same as the + # below? 
+    return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((value,))), vnt.make_leaf)
+end
+
+function BangBang.setindex!!(id::IndexDict, value, optic::IndexLens)
+    return IndexDict(setindex!!(id.data, value, optic.indices), id.make_leaf)
+end
+
+function apply(func, vnt::VarNamedTuple, name::VarName)
+    # VarName has no `name` field; get the lead Symbol with getsym.
+    if !haskey(vnt.data, AbstractPPL.getsym(name))
+        throw(KeyError(repr(name)))
+    end
+    subdata = getindex(vnt, name)
+    new_subdata = func(subdata)
+    return BangBang.setindex!!(vnt, new_subdata, name)
+end
+
+function Base.map(func, vnt::VarNamedTuple)
+    new_data = NamedTuple{keys(vnt.data)}(map(func, values(vnt.data)))
+    return VarNamedTuple(new_data, vnt.make_leaf)
+end
+
+function Base.keys(vnt::VarNamedTuple)
+    result = ()
+    for sym in keys(vnt.data)
+        subdata = vnt.data[sym]
+        if subdata isa VarNamedTuple
+            subkeys = keys(subdata)
+            result = (
+                (AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys)..., result...
+            )
+        else
+            result = (VarName{sym}(), result...)
+        end
+    end
+    return result
+end
+
+function Base.haskey(vnt::VarNamedTuple, name::VarName{S,Optic}) where {S,Optic}
+    if !haskey(vnt.data, S)
+        return false
+    end
+    subdata = vnt.data[S]
+    return if Optic === typeof(identity)
+        true
+    elseif Optic <: IndexLens
+        try
+            AbstractPPL.getoptic(name)(subdata)
+            true
+        catch e
+            if e isa BoundsError || e isa KeyError
+                false
+            else
+                rethrow(e)
+            end
+        end
+    else
+        haskey(subdata, AbstractPPL.unprefix(name, VarName{S}()))
+    end
+end
+
+end
diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl
new file mode 100644
index 000000000..85b824ffc
--- /dev/null
+++ b/test/varnamedtuple.jl
@@ -0,0 +1,89 @@
+module VarNamedTupleTests
+
+using Test: @inferred, @test, @test_throws, @testset
+using DynamicPPL: @varname, VarNamedTuple
+using BangBang: setindex!!
+ +@testset "Basic sets and gets" begin + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 32.0, @varname(a))) + @test @inferred(getindex(vnt, @varname(a))) == 32.0 + + vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b))) + @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + @test @inferred(getindex(vnt, @varname(b[2]))) == 2 + + vnt = @inferred(setindex!!(vnt, 64.0, @varname(a))) + @test @inferred(getindex(vnt, @varname(a))) == 64.0 + @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + + vnt = @inferred(setindex!!(vnt, 15, @varname(b[2]))) + @test @inferred(getindex(vnt, @varname(b))) == [1, 15, 3] + @test @inferred(getindex(vnt, @varname(b[2]))) == 15 + + vnt = @inferred(setindex!!(vnt, [10], @varname(c.x.y))) + @test @inferred(getindex(vnt, @varname(c.x.y))) == [10] + + vnt = @inferred(setindex!!(vnt, 11, @varname(c.x.y[1]))) + @test @inferred(getindex(vnt, @varname(c.x.y))) == [11] + @test @inferred(getindex(vnt, @varname(c.x.y[1]))) == 11 + + vnt = @inferred(setindex!!(vnt, -1.0, @varname(d[4]))) + @test @inferred(getindex(vnt, @varname(d[4]))) == -1.0 + + vnt = @inferred(setindex!!(vnt, -2.0, @varname(d[4]))) + @test @inferred(getindex(vnt, @varname(d[4]))) == -2.0 + + # These can't be @inferred because `d` now has an abstract element type. Note that this + # does not ruin type stability for other varnames that don't involve `d`. 
+ vnt = setindex!!(vnt, "a", @varname(d[5])) + @test getindex(vnt, @varname(d[5])) == "a" + + vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 + + vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 + + vec = fill(1.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[1:4]))) + @test @inferred(getindex(vnt, @varname(j[1:4]))) == vec + @test @inferred(getindex(vnt, @varname(j[2]))) == vec[2] + @test haskey(vnt, @varname(j[4])) + @test !haskey(vnt, @varname(j[5])) + @test_throws BoundsError getindex(vnt, @varname(j[5])) + + vec = fill(2.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[2:5]))) + @test @inferred(getindex(vnt, @varname(j[1]))) == 1.0 + @test @inferred(getindex(vnt, @varname(j[2:5]))) == vec + @test haskey(vnt, @varname(j[5])) + + arr = fill(2.0, (4, 2)) + vn = @varname(k.l[2:5, 3, 1:2, 2]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + + # Not enough, or too many, indices. + @test_throws "Invalid index" setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3])) + @test_throws "Invalid index" setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3, 4, 5])) + + arr = fill(3.0, (3, 3)) + vn = @varname(k.l[1, 1:3, 1:3, 1]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[1, 1:2, 1:2, 1]))) == fill(3.0, 2, 2) + # A subset of the elements set previously. 
+ @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + @test !haskey(vnt, @varname(k.l[2, 3, 3, 2])) + + vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[2]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3]))) + @test @inferred(getindex(vnt, @varname(m[2:3]))) == [1.0, 1.0] + @test !haskey(vnt, @varname(m[1])) +end + +end From 0c7825bd9b80494459393ea6b7349885d8c2e29c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 20 Nov 2025 11:58:22 +0000 Subject: [PATCH 2/4] Add comparisons and merge --- src/varnamedtuple.jl | 179 +++++++++++++++++++++++++++++++++++++----- test/varnamedtuple.jl | 120 ++++++++++++++++++++++++++++ 2 files changed, 278 insertions(+), 21 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 448ae4636..006e8f0d5 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -4,16 +4,18 @@ module VarNamedTuples using AbstractPPL using BangBang using Accessors -using DynamicPPL: _compose_no_identity +using ..DynamicPPL: _compose_no_identity export VarNamedTuple """The factor by which we increase the dimensions of PartialArrays when resizing them.""" const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 -_has_colon(::IndexLens{T}) where {T} = any(x <: Colon for x in T.parameters) +const INDEX_TYPES = Union{Integer,UnitRange,Colon} -function _is_multiindex(::IndexLens{T}) where {T} +_has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters) + +function _is_multiindex(::T) where {T<:Tuple} return any(x <: UnitRange || x <: Colon for x in T.parameters) end @@ -22,6 +24,12 @@ struct VarNamedTuple{T<:Function,Names,Values} make_leaf::T end +# TODO(mhauru) Since I define this, should I also define `isequal` and `hash`? Same for +# PartialArrays. 
+function Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) + return vnt1.make_leaf === vnt2.make_leaf && vnt1.data == vnt2.data +end + struct IndexDict{T<:Function,Keys,Values} data::Dict{Keys,Values} make_leaf::T @@ -33,13 +41,44 @@ struct PartialArray{T<:Function,ElType,numdims} make_leaf::T end -function PartialArray(eltype, num_dims, make_leaf) +function PartialArray(eltype, num_dims, make_leaf=make_leaf_array) dims = ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) data = Array{eltype,num_dims}(undef, dims) mask = fill(false, dims) return PartialArray(data, mask, make_leaf) end +Base.ndims(iarr::PartialArray) = ndims(iarr.data) + +# We deliberately don't define Base.size for PartialArray, because it is ill-defined. +# The size of the .data field is an implementation detail. +_internal_size(iarr::PartialArray, args...) = size(iarr.data, args...) + +function Base.copy(pa::PartialArray) + return PartialArray(copy(pa.data), copy(pa.mask), pa.make_leaf) +end + +function Base.:(==)(pa1::PartialArray, pa2::PartialArray) + if (pa1.make_leaf !== pa2.make_leaf) || (ndims(pa1) != ndims(pa2)) + return false + end + size1 = _internal_size(pa1) + size2 = _internal_size(pa2) + # TODO(mhauru) This could be optimised, but not sure it's worth it. + merge_size = ntuple(i -> max(size1[i], size2[i]), ndims(pa1)) + for i in CartesianIndices(merge_size) + m1 = checkbounds(Bool, pa1.mask, Tuple(i)...) ? pa1.mask[i] : false + m2 = checkbounds(Bool, pa2.mask, Tuple(i)...) ? 
pa2.mask[i] : false + if m1 != m2 + return false + end + if m1 && (pa1.data[i] != pa2.data[i]) + return false + end + end + return true +end + _length_needed(i::Integer) = i _length_needed(r::UnitRange) = last(r) _length_needed(::Colon) = 0 @@ -55,11 +94,13 @@ function _partial_array_dim_size(min_dim) end function _resize_partialarray(iarr::PartialArray, inds) - min_sizes = ntuple(i -> max(size(iarr.data, i), _length_needed(inds[i])), length(inds)) + min_sizes = ntuple( + i -> max(_internal_size(iarr, i), _length_needed(inds[i])), length(inds) + ) new_sizes = map(_partial_array_dim_size, min_sizes) # Generic multidimensional Arrays can not be resized, so we need to make a new one. # See https://github.com/JuliaLang/julia/issues/37900 - new_data = Array{eltype(iarr.data),ndims(iarr.data)}(undef, new_sizes) + new_data = Array{eltype(iarr.data),ndims(iarr)}(undef, new_sizes) new_mask = fill(false, new_sizes) # Note that we have to use CartesianIndices instead of eachindex, because the latter # may use a linear index that does not match between the old and the new arrays. @@ -76,7 +117,7 @@ end # The below implements the same functionality as above, but more performantly for 1D arrays. function _resize_partialarray(iarr::PartialArray{T,Eltype,1}, (ind,)) where {T,Eltype} # Resize arrays to accommodate new indices. - old_size = size(iarr.data, 1) + old_size = _internal_size(iarr, 1) min_size = max(old_size, _length_needed(ind)) new_size = _partial_array_dim_size(min_size) resize!(iarr.data, new_size) @@ -85,14 +126,19 @@ function _resize_partialarray(iarr::PartialArray{T,Eltype,1}, (ind,)) where {T,E return iarr end -function BangBang.setindex!!(iarr::PartialArray, value, optic::IndexLens) - if _has_colon(optic) +function BangBang.setindex!!(pa::PartialArray, value, optic::IndexLens) + return BangBang.setindex!!(pa, value, optic.indices...) +end +Base.getindex(pa::PartialArray, optic::IndexLens) = Base.getindex(pa, optic.indices...) 
+Base.haskey(pa::PartialArray, optic::IndexLens) = Base.haskey(pa, optic.indices)
+
+function BangBang.setindex!!(iarr::PartialArray, value, inds::Vararg{INDEX_TYPES})
+    if _has_colon(inds)
+        # TODO(mhauru) This could be implemented by getting size information from `value`.
+        # However, the corresponding getindex is more fundamentally ill-defined.
+        throw(ArgumentError("Indexing with colons is not supported"))
+    end
-    inds = optic.indices
-    if length(inds) != ndims(iarr.data)
+    if length(inds) != ndims(iarr)
         throw(ArgumentError("Invalid index $(inds)"))
     end
     iarr = if checkbounds(Bool, iarr.mask, inds...)
@@ -101,7 +147,7 @@ function BangBang.setindex!!(iarr::PartialArray, value, optic::IndexLens)
         _resize_partialarray(iarr, inds)
     end
     new_data = setindex!!(iarr.data, value, inds...)
-    if _is_multiindex(optic)
+    if _is_multiindex(inds)
         iarr.mask[inds...] .= true
     else
         iarr.mask[inds...] = true
@@ -109,29 +155,105 @@ function BangBang.setindex!!(iarr::PartialArray, value, optic::IndexLens)
     return PartialArray(new_data, iarr.mask, iarr.make_leaf)
 end
 
-function Base.getindex(iarr::PartialArray, optic::IndexLens)
-    if _has_colon(optic)
+function Base.getindex(iarr::PartialArray, inds::Vararg{INDEX_TYPES})
+    if _has_colon(inds)
         throw(ArgumentError("Indexing with colons is not supported"))
     end
-    inds = optic.indices
-    if length(inds) != ndims(iarr.data)
+    if length(inds) != ndims(iarr)
         throw(ArgumentError("Invalid index $(inds)"))
     end
-    if !haskey(iarr, optic)
+    if !haskey(iarr, inds)
         throw(BoundsError(iarr, inds))
     end
     return getindex(iarr.data, inds...)
 end
 
-function Base.haskey(iarr::PartialArray, optic::IndexLens)
-    if _has_colon(optic)
+function Base.haskey(iarr::PartialArray, inds)
+    if _has_colon(inds)
         throw(ArgumentError("Indexing with colons is not supported"))
     end
-    inds = optic.indices
     return checkbounds(Bool, iarr.mask, inds...) &&
            all(@inbounds(getindex(iarr.mask, inds...)))
 end
 
+Base.merge(x1::PartialArray, x2::PartialArray) = _merge_recursive(x1, x2)
+Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2)
+_merge_recursive(_, x2) = x2
+
+function _merge_element_recursive(x1::PartialArray, x2::PartialArray, ind::CartesianIndex)
+    m1 = x1.mask[ind]
+    m2 = x2.mask[ind]
+    return if m1 && m2
+        _merge_recursive(x1.data[ind], x2.data[ind])
+    elseif m2
+        x2.data[ind]
+    else
+        x1.data[ind]
+    end
+end
+
+# TODO(mhauru) Would this benefit from a specialised method for 1D PartialArrays?
+function _merge_recursive(pa1::PartialArray, pa2::PartialArray)
+    if ndims(pa1) != ndims(pa2)
+        throw(
+            ArgumentError("Cannot merge PartialArrays with different number of dimensions")
+        )
+    end
+    if pa1.make_leaf !== pa2.make_leaf
+        throw(
+            ArgumentError("Cannot merge PartialArrays with different make_leaf functions")
+        )
+    end
+    num_dims = ndims(pa1)
+    merge_size = ntuple(i -> max(_internal_size(pa1, i), _internal_size(pa2, i)), num_dims)
+    result = if merge_size == _internal_size(pa2)
+        # Either pa2 is strictly bigger than pa1, or they are equal in size.
+        result = copy(pa2)
+        for i in CartesianIndices(pa1.data)
+            @inbounds if pa1.mask[i]
+                result = setindex!!(
+                    result, _merge_element_recursive(pa1, result, i), Tuple(i)...
+                )
+            end
+        end
+        result
+    else
+        if merge_size == _internal_size(pa1)
+            # pa1 is bigger than pa2
+            result = copy(pa1)
+            for i in CartesianIndices(pa2.data)
+                @inbounds if pa2.mask[i]
+                    result = setindex!!(
+                        result, _merge_element_recursive(result, pa2, i), Tuple(i)...
+                    )
+                end
+            end
+            result
+        else
+            # Neither is strictly bigger than the other.
+            et = promote_type(eltype(pa1), eltype(pa2))
+            new_data = Array{et,num_dims}(undef, merge_size)
+            new_mask = fill(false, merge_size)
+            result = PartialArray(new_data, new_mask, pa2.make_leaf)
+            for i in CartesianIndices(pa2.data)
+                @inbounds if pa2.mask[i]
+                    result.mask[i] = true
+                    result.data[i] = pa2.data[i]
+                end
+            end
+            for i in CartesianIndices(pa1.data)
+                @inbounds if pa1.mask[i]
+                    result = setindex!!(
+                        result, _merge_element_recursive(pa1, result, i), Tuple(i)...
+                    )
+                end
+            end
+            result
+        end
+    end
+    return result
+end
+
 function make_leaf_array(value, ::PropertyLens{S}) where {S}
     return VarNamedTuple(NamedTuple{(S,)}((value,)), make_leaf_array)
 end
@@ -146,7 +268,7 @@ function make_leaf_array(value, optic::IndexLens{T}) where {T}
     num_inds = length(inds)
     # Check if any of the indices are ranges or colons. If yes, value needs to be an
     # AbstractArray. Otherwise it needs to be an individual value.
-    et = _is_multiindex(optic) ? eltype(value) : typeof(value)
+    et = _is_multiindex(optic.indices) ? eltype(value) : typeof(value)
     iarr = PartialArray(et, num_inds, make_leaf_array)
     return setindex!!(iarr, value, optic)
 end
@@ -307,4 +429,19 @@ function Base.haskey(vnt::VarNamedTuple, name::VarName{S,Optic}) where {S,Optic}
     end
 end
 
+# TODO(mhauru) Check the performance of this, and make it into a generated function if
+# necessary.
+function _merge_recursive(vnt1::VarNamedTuple, vnt2::VarNamedTuple)
+    result_data = vnt1.data
+    for k in keys(vnt2.data)
+        val = if haskey(result_data, k)
+            _merge_recursive(result_data[k], vnt2.data[k])
+        else
+            vnt2.data[k]
+        end
+        Accessors.@reset result_data[k] = val
+    end
+    return VarNamedTuple(result_data, vnt2.make_leaf)
+end
+
 end
diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl
index 85b824ffc..f9864e7be 100644
--- a/test/varnamedtuple.jl
+++ b/test/varnamedtuple.jl
@@ -86,4 +86,124 @@ using BangBang: setindex!!
     @test !haskey(vnt, @varname(m[1]))
 end
 
+@testset "equality" begin
+    vnt1 = VarNamedTuple()
+    vnt2 = VarNamedTuple()
+    @test vnt1 == vnt2
+
+    vnt1 = setindex!!(vnt1, 1.0, @varname(a))
+    @test vnt1 != vnt2
+
+    vnt2 = setindex!!(vnt2, 1.0, @varname(a))
+    @test vnt1 == vnt2
+
+    vnt1 = setindex!!(vnt1, [1, 2], @varname(b))
+    vnt2 = setindex!!(vnt2, [1, 2], @varname(b))
+    @test vnt1 == vnt2
+
+    vnt2 = setindex!!(vnt2, [1, 3], @varname(b))
+    @test vnt1 != vnt2
+    vnt2 = setindex!!(vnt2, [1, 2], @varname(b))
+
+    # Try with index lenses too
+    vnt1 = setindex!!(vnt1, 2, @varname(c[2]))
+    vnt2 = setindex!!(vnt2, 2, @varname(c[2]))
+    @test vnt1 == vnt2
+
+    vnt2 = setindex!!(vnt2, 3, @varname(c[2]))
+    @test vnt1 != vnt2
+    vnt2 = setindex!!(vnt2, 2, @varname(c[2]))
+
+    vnt1 = setindex!!(vnt1, ["a", "b"], @varname(d.e[1:2]))
+    vnt2 = setindex!!(vnt2, ["a", "b"], @varname(d.e[1:2]))
+    @test vnt1 == vnt2
+
+    vnt2 = setindex!!(vnt2, :b, @varname(d.e[2]))
+    @test vnt1 != vnt2
+end
+
+@testset "merge" begin
+    vnt1 = VarNamedTuple()
+    vnt2 = VarNamedTuple()
+    expected_merge = VarNamedTuple()
+    # TODO(mhauru) Wrap this merge in @inferred, likewise other merges where it makes sense.
+    @test merge(vnt1, vnt2) == expected_merge
+
+    vnt1 = setindex!!(vnt1, 1.0, @varname(a))
+    vnt2 = setindex!!(vnt2, 2.0, @varname(b))
+    vnt1 = setindex!!(vnt1, 1, @varname(c))
+    vnt2 = setindex!!(vnt2, 2, @varname(c))
+    expected_merge = setindex!!(expected_merge, 1.0, @varname(a))
+    expected_merge = setindex!!(expected_merge, 2, @varname(c))
+    expected_merge = setindex!!(expected_merge, 2.0, @varname(b))
+    @test merge(vnt1, vnt2) == expected_merge
+
+    vnt1 = VarNamedTuple()
+    vnt2 = VarNamedTuple()
+    expected_merge = VarNamedTuple()
+    vnt1 = setindex!!(vnt1, [1], @varname(d.a))
+    vnt2 = setindex!!(vnt2, [2, 2], @varname(d.b))
+    vnt1 = setindex!!(vnt1, [1], @varname(d.c))
+    vnt2 = setindex!!(vnt2, [2, 2], @varname(d.c))
+    expected_merge = setindex!!(expected_merge, [1], @varname(d.a))
+    expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.c))
+    expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.b))
+    @test merge(vnt1, vnt2) == expected_merge
+
+    vnt1 = setindex!!(vnt1, 1, @varname(e.a[1]))
+    vnt2 = setindex!!(vnt2, 2, @varname(e.a[2]))
+    expected_merge = setindex!!(expected_merge, 1, @varname(e.a[1]))
+    expected_merge = setindex!!(expected_merge, 2, @varname(e.a[2]))
+    vnt1 = setindex!!(vnt1, 1, @varname(e.a[3]))
+    vnt2 = setindex!!(vnt2, 2, @varname(e.a[3]))
+    expected_merge = setindex!!(expected_merge, 2, @varname(e.a[3]))
+    @test merge(vnt1, vnt2) == expected_merge
+
+    vnt1 = setindex!!(vnt1, fill(1, 4), @varname(e.a[7:10]))
+    vnt2 = setindex!!(vnt2, fill(2, 4), @varname(e.a[8:11]))
+    expected_merge = setindex!!(expected_merge, 1, @varname(e.a[7]))
+    expected_merge = setindex!!(expected_merge, fill(2, 4), @varname(e.a[8:11]))
+    @test merge(vnt1, vnt2) == expected_merge
+
+    vnt1 = setindex!!(vnt1, ["1", "1"], @varname(f.a[1].b.c[2, 2].d[1, 3:4]))
+    vnt2 = setindex!!(vnt2, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4]))
+    expected_merge = setindex!!(
+        expected_merge, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])
+    )
+    vnt1 = setindex!!(vnt1, :1, @varname(f.a[1].b.c[3, 2].d[1, 1]))
+    vnt2 = setindex!!(vnt2, :2, @varname(f.a[1].b.c[4, 2].d[1, 1]))
+    expected_merge = setindex!!(expected_merge, :1, @varname(f.a[1].b.c[3, 2].d[1, 1]))
+    expected_merge = setindex!!(expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1]))
+    @test merge(vnt1, vnt2) == expected_merge
+
+    # PartialArrays with different sizes.
+    vnt1 = VarNamedTuple()
+    vnt2 = VarNamedTuple()
+    vnt1 = setindex!!(vnt1, 1, @varname(a[1]))
+    vnt1 = setindex!!(vnt1, 1, @varname(a[1025]))
+    vnt2 = setindex!!(vnt2, 2, @varname(a[1]))
+    vnt2 = setindex!!(vnt2, 2, @varname(a[2]))
+    expected_merge_12 = VarNamedTuple()
+    expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[1025]))
+    expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[1]))
+    expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[2]))
+    @test merge(vnt1, vnt2) == expected_merge_12
+    expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1]))
+    @test merge(vnt2, vnt1) == expected_merge_21
+
+    vnt1 = VarNamedTuple()
+    vnt2 = VarNamedTuple()
+    vnt1 = setindex!!(vnt1, 1, @varname(a[1, 1]))
+    vnt1 = setindex!!(vnt1, 1, @varname(a[1025, 1]))
+    vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1]))
+    vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1025]))
+    expected_merge_12 = VarNamedTuple()
+    expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1]))
+    expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[1025, 1]))
+    expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1025]))
+    @test merge(vnt1, vnt2) == expected_merge_12
+    expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1]))
+    @test merge(vnt2, vnt1) == expected_merge_21
+end
+
 end

From 15d5a8a97795de35390706e858cf60a48cb17b76 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Thu, 20 Nov 2025 12:09:39 +0000
Subject: [PATCH 3/4] Start using VNT in FastLDF

---
 src/DynamicPPL.jl     |  2 ++
 src/contexts/init.jl  | 16 +++------
 src/fasteval.jl       | 81 +++++++++++++------------------------------
 test/fasteval.jl      |  8 ++---
 test/varnamedtuple.jl | 12 +++---
 5 files changed, 39 insertions(+), 80 deletions(-)

diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index e9b902363..5f32a8b66 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -178,6 +178,8 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
 
 # Necessary forward declarations
 include("utils.jl")
+include("varnamedtuple.jl")
+using .VarNamedTuples: VarNamedTuple
 include("contexts.jl")
 include("contexts/default.jl")
 include("contexts/init.jl")
diff --git a/src/contexts/init.jl b/src/contexts/init.jl
index a79969a13..a0ad92fe3 100644
--- a/src/contexts/init.jl
+++ b/src/contexts/init.jl
@@ -215,8 +215,7 @@ end
 
 """
     VectorWithRanges(
-        iden_varname_ranges::NamedTuple,
-        varname_ranges::Dict{VarName,RangeAndLinked},
+        varname_ranges::VarNamedTuple,
         vect::AbstractVector{<:Real},
     )
 
@@ -231,20 +230,13 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict.
 It would be nice to improve the NamedTuple and Dict approach. See, e.g.
 https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
 """
-struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}}
-    # This NamedTuple stores the ranges for identity VarNames
-    iden_varname_ranges::N
-    # This Dict stores the ranges for all other VarNames
-    varname_ranges::Dict{VarName,RangeAndLinked}
+struct VectorWithRanges{VNT<:VarNamedTuple,T<:AbstractVector{<:Real}}
+    # Ranges for all VarNames
+    varname_ranges::VNT
     # The full parameter vector which we index into to get variable values
     vect::T
 end
 
-function _get_range_and_linked(
-    vr::VectorWithRanges, ::VarName{sym,typeof(identity)}
-) where {sym}
-    return vr.iden_varname_ranges[sym]
-end
 function _get_range_and_linked(vr::VectorWithRanges, vn::VarName)
     return vr.varname_ranges[vn]
 end
diff --git a/src/fasteval.jl b/src/fasteval.jl
index 4f402f4a8..b82180dca 100644
--- a/src/fasteval.jl
+++ b/src/fasteval.jl
@@ -13,6 +13,7 @@ using DynamicPPL:
     RangeAndLinked,
     VectorWithRanges,
     Metadata,
+    VarNamedTuple,
     VarNamedVector,
     default_accumulators,
     float_type_with_fallback,
@@ -140,14 +141,13 @@ struct FastLDF{
     M<:Model,
     AD<:Union{ADTypes.AbstractADType,Nothing},
     F<:Function,
-    N<:NamedTuple,
+    VNT<:VarNamedTuple,
     ADP<:Union{Nothing,DI.GradientPrep},
 }
     model::M
     adtype::AD
     _getlogdensity::F
-    _iden_varname_ranges::N
-    _varname_ranges::Dict{VarName,RangeAndLinked}
+    _varname_ranges::VNT
     _adprep::ADP
     _dim::Int
 
@@ -159,7 +159,7 @@ struct FastLDF{
     )
         # Figure out which variable corresponds to which index, and
         # which variables are linked.
-        all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo)
+        all_ranges = get_ranges_and_linked(varinfo)
         x = [val for val in varinfo[:]]
        dim = length(x)
        # Do AD prep if needed
@@ -169,19 +169,17 @@ struct FastLDF{
             # Make backend-specific tweaks to the adtype
             adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
             DI.prepare_gradient(
-                FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
-                adtype,
-                x,
+                FastLogDensityAt(model, getlogdensity, all_ranges), adtype, x
             )
         end
         return new{
             typeof(model),
             typeof(adtype),
             typeof(getlogdensity),
-            typeof(all_iden_ranges),
+            typeof(all_ranges),
             typeof(prep),
         }(
-            model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim
+            model, adtype, getlogdensity, all_ranges, prep, dim
         )
     end
 end
@@ -206,18 +204,15 @@ end
 fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
 fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
 
-struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
+struct FastLogDensityAt{M<:Model,F<:Function,VNT<:VarNamedTuple}
     model::M
     getlogdensity::F
-    iden_varname_ranges::N
-    varname_ranges::Dict{VarName,RangeAndLinked}
+    varname_ranges::VNT
 end
 function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
     ctx = InitContext(
         Random.default_rng(),
-        InitFromParams(
-            VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
-        ),
+        InitFromParams(VectorWithRanges(f.varname_ranges, params), nothing),
     )
     model = DynamicPPL.setleafcontext(f.model, ctx)
     accs = fast_ldf_accs(f.getlogdensity)
@@ -242,20 +237,14 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
 end
 
 function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
-    return FastLogDensityAt(
-        fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
-    )(
-        params
-    )
+    return FastLogDensityAt(fldf.model, fldf._getlogdensity, fldf._varname_ranges)(params)
 end
 
 function LogDensityProblems.logdensity_and_gradient(
     fldf::FastLDF, params::AbstractVector{<:Real}
 )
     return DI.value_and_gradient(
-        FastLogDensityAt(
-            fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
-        ),
+        FastLogDensityAt(fldf.model, fldf._getlogdensity, fldf._varname_ranges),
         fldf._adprep,
         fldf.adtype,
         params,
@@ -291,62 +280,42 @@ end
 
 Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter
 representation, along with whether each variable is linked or unlinked.
 
-This function should return a tuple containing:
-
-- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked`
-- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`.
+This function returns a VarNamedTuple mapping all VarNames to their corresponding
+`RangeAndLinked`.
 """
 function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms}
-    all_iden_ranges = NamedTuple()
-    all_ranges = Dict{VarName,RangeAndLinked}()
+    all_ranges = VarNamedTuple()
     offset = 1
     for sym in syms
         md = varinfo.metadata[sym]
-        this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset)
-        all_iden_ranges = merge(all_iden_ranges, this_md_iden)
+        this_md_others, offset = get_ranges_and_linked_metadata(md, offset)
         all_ranges = merge(all_ranges, this_md_others)
     end
-    return all_iden_ranges, all_ranges
+    return all_ranges
 end
 
 function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}})
-    all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1)
-    return all_iden, all_others
+    all_ranges, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1)
    return all_ranges
 end
 
 function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int)
-    all_iden_ranges = NamedTuple()
-    all_ranges = Dict{VarName,RangeAndLinked}()
+    all_ranges = VarNamedTuple()
     offset = start_offset
     for (vn, idx) in md.idcs
         is_linked = md.is_transformed[idx]
         range = md.ranges[idx] .+ (start_offset - 1)
-        if AbstractPPL.getoptic(vn) === identity
-            all_iden_ranges = merge(
-                all_iden_ranges,
-                NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)),
-            )
-        else
-            all_ranges[vn] = RangeAndLinked(range, is_linked)
-        end
+        all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn)
         offset += length(range)
     end
-    return all_iden_ranges, all_ranges, offset
+    return all_ranges, offset
 end
 
 function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int)
-    all_iden_ranges = NamedTuple()
-    all_ranges = Dict{VarName,RangeAndLinked}()
+    all_ranges = VarNamedTuple()
     offset = start_offset
     for (vn, idx) in vnv.varname_to_index
         is_linked = vnv.is_unconstrained[idx]
         range = vnv.ranges[idx] .+ (start_offset - 1)
-        if AbstractPPL.getoptic(vn) === identity
-            all_iden_ranges = merge(
-                all_iden_ranges,
-                NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)),
-            )
-        else
-            all_ranges[vn] = RangeAndLinked(range, is_linked)
-        end
+        all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn)
         offset += length(range)
     end
-    return all_iden_ranges, all_ranges, offset
+    return all_ranges, offset
 end
diff --git a/test/fasteval.jl b/test/fasteval.jl
index db2333711..2ad50ed26 100644
--- a/test/fasteval.jl
+++ b/test/fasteval.jl
@@ -36,17 +36,13 @@ end
         else
             unlinked_vi
         end
-        nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi)
+        ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi)
         params = [x for x in vi[:]]
         # Iterate over all variables
         for vn in keys(vi)
             # Check that `getindex_internal` returns the same thing as using the ranges
             # directly
-            range_with_linked = if AbstractPPL.getoptic(vn) === identity
-                nt_ranges[AbstractPPL.getsym(vn)]
-            else
-                dict_ranges[vn]
-            end
+            range_with_linked = ranges[vn]
             @test params[range_with_linked.range] == DynamicPPL.getindex_internal(vi, vn)
 
             # Check that the link status is correct
diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl
index f9864e7be..99f528175 100644
--- a/test/varnamedtuple.jl
+++ b/test/varnamedtuple.jl
@@ -180,11 +180,11 @@ end
     vnt1 = VarNamedTuple()
     vnt2 = VarNamedTuple()
     vnt1 = setindex!!(vnt1, 1, @varname(a[1]))
-    vnt1 = setindex!!(vnt1, 1, @varname(a[1025]))
+    vnt1 = setindex!!(vnt1, 1, @varname(a[257]))
     vnt2 = setindex!!(vnt2, 2, @varname(a[1]))
     vnt2 = setindex!!(vnt2, 2, @varname(a[2]))
     expected_merge_12 = VarNamedTuple()
-    expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[1025]))
+    expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257]))
     expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[1]))
     expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[2]))
     @test merge(vnt1, vnt2) == expected_merge_12
@@ -194,13 +194,13 @@ end
     vnt1 = VarNamedTuple()
     vnt2 = VarNamedTuple()
     vnt1 = setindex!!(vnt1, 1, @varname(a[1, 1]))
-    vnt1 = setindex!!(vnt1, 1, @varname(a[1025, 1]))
+    vnt1 = setindex!!(vnt1, 1, @varname(a[257, 1]))
     vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1]))
-    vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1025]))
+    vnt2 = setindex!!(vnt2, :2, @varname(a[1, 257]))
     expected_merge_12 = VarNamedTuple()
     expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1]))
-    expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[1025, 1]))
-    expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1025]))
+    expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257, 1]))
+    expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 257]))
     @test merge(vnt1, vnt2) == expected_merge_12
     expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1]))
     @test merge(vnt2, vnt1) == expected_merge_21

From 871eb9fd1216f392460462d4c84d8a38ca89da05 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Thu, 20 Nov 2025 12:41:55 +0000
Subject: [PATCH 4/4] Move _compose_no_identity to utils.jl

---
 src/utils.jl          | 16 ++++++++++++++++
 src/varnamedvector.jl | 16 ----------------
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/src/utils.jl b/src/utils.jl
index 75fb805dc..fe2879182 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -949,3 +949,19 @@ end
 Return `typeof(x)` stripped of its type parameters.
 """
 basetypeof(x::T) where {T} = Base.typename(T).wrapper
+
+# TODO(mhauru) Might add another specialisation to _compose_no_identity, where if
+# ReshapeTransforms are composed with each other or with an UnwrapSingletonTransform, only
+# the latter one would be kept.
+"""
+    _compose_no_identity(f, g)
+
+Like `f ∘ g`, but if `f` or `g` is `identity` it is omitted.
+
+This helps avoid trivial cases of `ComposedFunction` that would cause unnecessary type
+conflicts.
+"""
+_compose_no_identity(f, g) = f ∘ g
+_compose_no_identity(::typeof(identity), g) = g
+_compose_no_identity(f, ::typeof(identity)) = f
+_compose_no_identity(::typeof(identity), ::typeof(identity)) = identity
diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl
index 17b851d1d..e5d2f2c2e 100644
--- a/src/varnamedvector.jl
+++ b/src/varnamedvector.jl
@@ -1355,22 +1355,6 @@ function nextrange(vnv::VarNamedVector, x)
     return (offset + 1):(offset + length(x))
 end
 
-# TODO(mhauru) Might add another specialisation to _compose_no_identity, where if
-# ReshapeTransforms are composed with each other or with a an UnwrapSingeltonTransform, only
-# the latter one would be kept.
-"""
-    _compose_no_identity(f, g)
-
-Like `f ∘ g`, but if `f` or `g` is `identity` it is omitted.
-
-This helps avoid trivial cases of `ComposedFunction` that would cause unnecessary type
-conflicts.
-"""
-_compose_no_identity(f, g) = f ∘ g
-_compose_no_identity(::typeof(identity), g) = g
-_compose_no_identity(f, ::typeof(identity)) = f
-_compose_no_identity(::typeof(identity), ::typeof(identity)) = identity
-
 """
     shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int)
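The `_merge_recursive` methods in the varnamedtuple.jl diff above implement a right-biased, recursive merge: keys only in one side are kept, and on a clash the merge recurses into both values, with the right-hand side winning at the leaves. The same semantics can be sketched on plain NamedTuples; this is an illustrative standalone sketch, not DynamicPPL code, and `merge_recursive` is a hypothetical name.

```julia
# Right-biased recursive merge on plain NamedTuples, mirroring the semantics
# of `_merge_recursive` in the patch (which operates on VarNamedTuple data).
merge_recursive(_, x2) = x2  # at leaves, the right-hand value wins
function merge_recursive(nt1::NamedTuple, nt2::NamedTuple)
    result = nt1
    for k in keys(nt2)
        # On a key clash, recurse into both values; otherwise take nt2's value.
        val = haskey(result, k) ? merge_recursive(result[k], nt2[k]) : nt2[k]
        result = merge(result, NamedTuple{(k,)}((val,)))
    end
    return result
end

nt1 = (a = 1.0, c = (x = 1, y = 2))
nt2 = (b = 2.0, c = (y = 3,))
merged = merge_recursive(nt1, nt2)
# merged.c combines both sides, with nt2 winning on the clash at :y
```

Unlike `Base.merge`, which would replace `nt1.c` wholesale with `nt2.c`, the recursive version keeps `c.x` from the left while taking `c.y` from the right, which is the behaviour the `PartialArray` mask-by-mask merge generalises to array indices.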
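The `get_ranges_and_linked_metadata` rewrite in the fasteval.jl diff keeps the same running-offset bookkeeping as before: each variable's vectorised value occupies a contiguous range of the flat parameter vector, and the offset advances by the variable's length. A minimal sketch of that bookkeeping, with illustrative names and a plain Dict standing in for the VarNamedTuple:

```julia
# Sketch of the running-offset logic: given (name => length) pairs in order,
# assign each name a contiguous range of the flattened parameter vector.
# `assign_ranges` is a hypothetical helper, not part of DynamicPPL.
function assign_ranges(lengths::Vector{Pair{Symbol,Int}})
    ranges = Dict{Symbol,UnitRange{Int}}()
    offset = 1
    for (name, len) in lengths
        ranges[name] = offset:(offset + len - 1)
        offset += len  # next variable starts right after this one
    end
    return ranges
end

# E.g. a model with scalar `mu` and `sigma` and a length-3 `beta`:
ranges = assign_ranges([:mu => 1, :sigma => 1, :beta => 3])
```

The real code additionally records a per-variable linked flag in `RangeAndLinked` and stores the result in a VarNamedTuple keyed by `VarName` rather than by `Symbol`.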
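The `_compose_no_identity` methods moved into utils.jl by the last patch drop `identity` from compositions so that trivial `ComposedFunction` wrappers never appear. The definitions below are copied from the patch to make the illustration self-contained:

```julia
# Compose two functions, but omit `identity` so the result is the bare
# function rather than a `ComposedFunction` wrapper (definitions as in the
# patch; the final method resolves the ambiguity when both are `identity`).
_compose_no_identity(f, g) = f ∘ g
_compose_no_identity(::typeof(identity), g) = g
_compose_no_identity(f, ::typeof(identity)) = f
_compose_no_identity(::typeof(identity), ::typeof(identity)) = identity

# `exp ∘ identity` would be a ComposedFunction{typeof(exp),typeof(identity)};
# with the specialisations the result is just `exp`, so types stay simple.
```

This matters for type stability: `f ∘ identity` and `f` are behaviourally identical but have different types, and avoiding the wrapper prevents spurious type conflicts when transforms are stored in containers.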