From 7c681fa0d2ecace39bce15a2d271f52578036a41 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 5 Jul 2025 20:58:06 +0100 Subject: [PATCH 1/4] Normalise optics when constructing VarName; remove deprecated methods --- HISTORY.md | 14 +++++ Project.toml | 2 +- src/AbstractPPL.jl | 1 - src/deprecations.jl | 2 - src/varname.jl | 140 +++++++++++++++++++++++-------------------- test/deprecations.jl | 4 -- test/runtests.jl | 9 --- test/varname.jl | 69 +++++++++++++++++---- 8 files changed, 148 insertions(+), 93 deletions(-) create mode 100644 HISTORY.md delete mode 100644 src/deprecations.jl delete mode 100644 test/deprecations.jl diff --git a/HISTORY.md b/HISTORY.md new file mode 100644 index 0000000..a91b63b --- /dev/null +++ b/HISTORY.md @@ -0,0 +1,14 @@ +## 0.12.0 + +### VarName constructors + +Removed the constructors `VarName(vn, optic)` (this wasn't deprecated, but was dangerous as it would silently discard the existing optic in `vn`), and `VarName(vn, ::Tuple)` (which was deprecated). + +### Optic normalisation + +In the inner constructor of a VarName, its optic is now normalised to ensure that the associativity of ComposedFunction is always the same, and that compositions with identity are removed. +This helps to prevent subtle bugs where VarNames with semantically equal optics are not considered equal. + +## 0.11.0 + +Added the `prefix(vn::VarName, vn_prefix::VarName)` and `unprefix(vn::VarName, vn_prefix::VarName)` functions. diff --git a/Project.toml b/Project.toml index 42f84f2..ecccd56 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.11.0" +version = "0.12.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index d0df706..40da231 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -29,6 +29,5 @@ include("varname.jl") include("abstractmodeltrace.jl") include("abstractprobprog.jl") include("evaluate.jl") -include("deprecations.jl") end # module diff --git a/src/deprecations.jl b/src/deprecations.jl deleted file mode 100644 index 24901bb..0000000 --- a/src/deprecations.jl +++ /dev/null @@ -1,2 +0,0 @@ -@deprecate VarName(sym::Symbol) VarName{sym}() -@deprecate VarName(sym::Symbol, indexing::Tuple) VarName{sym}(indexing) diff --git a/src/varname.jl b/src/varname.jl index 38d558c..9c5e1e3 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -1,8 +1,8 @@ using Accessors -using Accessors: ComposedOptic, PropertyLens, IndexLens, DynamicIndexLens +using Accessors: PropertyLens, IndexLens, DynamicIndexLens using JSON: JSON -const ALLOWED_OPTICS = Union{typeof(identity),PropertyLens,IndexLens,ComposedOptic} +const ALLOWED_OPTICS = Union{typeof(identity),PropertyLens,IndexLens,ComposedFunction} """ VarName{sym}(optic=identity) @@ -31,10 +31,11 @@ julia> @varname x[:, 1][1+1] x[:, 1][2] ``` """ -struct VarName{sym,T} +struct VarName{sym,T<:ALLOWED_OPTICS} optic::T function VarName{sym}(optic=identity) where {sym} + optic = normalise(optic) if !is_static_optic(typeof(optic)) throw( ArgumentError( @@ -53,42 +54,68 @@ Return `true` if `l` is one or a composition of `identity`, `PropertyLens`, and one or a composition of `DynamicIndexLens`; and undefined otherwise. """ is_static_optic(::Type{<:Union{typeof(identity),PropertyLens,IndexLens}}) = true -function is_static_optic(::Type{ComposedOptic{LO,LI}}) where {LO,LI} +function is_static_optic(::Type{ComposedFunction{LO,LI}}) where {LO,LI} return is_static_optic(LO) && is_static_optic(LI) end is_static_optic(::Type{<:DynamicIndexLens}) = false -# A bit of backwards compatibility. -VarName{sym}(indexing::Tuple) where {sym} = VarName{sym}(tupleindex2optic(indexing)) - """ - VarName(vn::VarName, optic) - VarName(vn::VarName, indexing::Tuple) + normalise(optic) -Return a copy of `vn` with a new index `optic`/`indexing`. +Enforce that compositions of optics are always nested in the same way, in that +a ComposedFunction never has a ComposedFunction as its inner lens. Thus, for +example, ```jldoctest; setup=:(using Accessors) -julia> VarName(@varname(x[1][2:3]), Accessors.IndexLens((2,))) -x[2] +julia> op1 = ((@o _.c) ∘ (@o _.b)) ∘ (@o _.a) +(@o _.a.b.c) -julia> VarName(@varname(x[1][2:3]), ((2,),)) -x[2] +julia> op2 = (@o _.c) ∘ ((@o _.b) ∘ (@o _.a)) +(@o _.c) ∘ ((@o _.a.b)) -julia> VarName(@varname(x[1][2:3])) -x +julia> op1 == op2 +false + +julia> AbstractPPL.normalise(op1) == AbstractPPL.normalise(op2) == @o _.a.b.c +true ``` -""" -VarName(vn::VarName, optic=identity) = VarName{getsym(vn)}(optic) -function VarName(vn::VarName, indexing::Tuple) - return VarName{getsym(vn)}(tupleindex2optic(indexing)) -end +This function also removes redundant `identity` optics from ComposedFunctions: + +```jldoctest; setup=:(using Accessors) +julia> op3 = ((@o _.b) ∘ identity) ∘ (@o _.a) +(@o identity(_.a).b) -tupleindex2optic(indexing::Tuple{}) = identity -tupleindex2optic(indexing::Tuple{<:Tuple}) = IndexLens(first(indexing)) # TODO: rest? -function tupleindex2optic(indexing::Tuple) - return IndexLens(first(indexing)) ∘ tupleindex2optic(indexing[2:end]) +julia> op4 = (@o _.b) ∘ (identity ∘ (@o _.a)) +(@o _.b) ∘ ((@o identity(_.a))) + +julia> AbstractPPL.normalise(op3) == AbstractPPL.normalise(op4) == @o _.a.b +true +``` +""" +function normalise(o::ComposedFunction{Outer,<:ComposedFunction}) where {Outer} + # `o` is currently (outer ∘ (inner_outer ∘ inner_inner)). + # We want to change this to: + # o = (outer ∘ inner_outer) ∘ inner_inner + inner_inner = o.inner.inner + inner_outer = o.inner.outer + # Recursively call normalise because inner_inner could itself be a + # ComposedFunction + return normalise((o.outer ∘ inner_outer) ∘ inner_inner) +end +function normalise(o::ComposedFunction{Outer,typeof(identity)} where {Outer}) + # strip outer identity + return normalise(o.outer) +end +function normalise(o::ComposedFunction{typeof(identity),Inner} where {Inner}) + # strip inner identity + return normalise(o.inner) end +normalise(o::ComposedFunction) = normalise(o.outer) ∘ o.inner +normalise(o::ALLOWED_OPTICS) = o +# These two methods are needed to avoid method ambiguity. +normalise(o::ComposedFunction{typeof(identity),<:ComposedFunction}) = normalise(o.inner) +normalise(::ComposedFunction{typeof(identity),typeof(identity)}) = identity """ getsym(vn::VarName) @@ -105,7 +132,7 @@ julia> getsym(@varname(y)) :y ``` """ -getsym(vn::VarName{sym}) where {sym} = sym +getsym(::VarName{sym}) where {sym} = sym """ getoptic(vn::VarName) @@ -154,15 +181,8 @@ function Accessors.set(obj, vn::VarName{sym}, value) where {sym} end # Allow compositions with optic. -function Base.:∘(optic::ALLOWED_OPTICS, vn::VarName{sym,<:ALLOWED_OPTICS}) where {sym} - vn_optic = getoptic(vn) - if vn_optic == identity - return VarName{sym}(optic) - elseif optic == identity - return vn - else - return VarName{sym}(optic ∘ vn_optic) - end +function Base.:∘(optic::ALLOWED_OPTICS, vn::VarName{sym}) where {sym} + return VarName{sym}(optic ∘ getoptic(vn)) end Base.hash(vn::VarName, h::UInt) = hash((getsym(vn), getoptic(vn)), h) @@ -299,17 +319,17 @@ subsumes(::typeof(identity), ::typeof(identity)) = true subsumes(::typeof(identity), ::ALLOWED_OPTICS) = true subsumes(::ALLOWED_OPTICS, ::typeof(identity)) = false -function subsumes(t::ComposedOptic, u::ComposedOptic) +function subsumes(t::ComposedFunction, u::ComposedFunction) return subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner) end # If `t` is still a composed lens, then there is no way it can subsume `u` since `u` is a # leaf of the "lens-tree". -subsumes(t::ComposedOptic, u::PropertyLens) = false +subsumes(t::ComposedFunction, u::PropertyLens) = false # Here we need to check if `u.inner` (i.e. the next lens to be applied from `u`) is # subsumed by `t`, since this would mean that the rest of the composition is also subsumed # by `t`. -subsumes(t::PropertyLens, u::ComposedOptic) = subsumes(t, u.inner) +subsumes(t::PropertyLens, u::ComposedFunction) = subsumes(t, u.inner) # For `PropertyLens` either they have the same `name` and thus they are indeed the same. subsumes(t::PropertyLens{name}, u::PropertyLens{name}) where {name} = true @@ -321,8 +341,8 @@ subsumes(t::PropertyLens, u::PropertyLens) = false # FIXME: Does not correctly handle cases such as `subsumes(x, x[:])` # (but neither did old implementation). function subsumes( - t::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}}, - u::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}}, + t::Union{IndexLens,ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}}, + u::Union{IndexLens,ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}}, ) return subsumes_indices(t, u) end @@ -415,7 +435,7 @@ The result is compatible with [`subsumes_indices`](@ref) for `Tuple` input. """ combine_indices(optic::ALLOWED_OPTICS) = (), optic combine_indices(optic::IndexLens) = (optic.indices,), nothing -function combine_indices(optic::ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}) +function combine_indices(optic::ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}) indices, next = combine_indices(optic.outer) return (optic.inner.indices, indices...), next end @@ -505,9 +525,9 @@ concretize(I::DynamicIndexLens, x) = concretize(IndexLens(I.f(x)), x) function concretize(I::IndexLens, x) return IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices))) end -function concretize(I::ComposedOptic, x) +function concretize(I::ComposedFunction, x) x_inner = I.inner(x) # TODO: get view here - return ComposedOptic(concretize(I.outer, x_inner), concretize(I.inner, x)) + return ComposedFunction(concretize(I.outer, x_inner), concretize(I.inner, x)) end """ @@ -533,7 +553,7 @@ julia> # The underlying value is concretized, though: ConcretizedSlice(Base.OneTo(100)) ``` """ -concretize(vn::VarName, x) = VarName(vn, concretize(getoptic(vn), x)) +concretize(vn::VarName{sym}, x) where {sym} = VarName{sym}(concretize(getoptic(vn), x)) """ @varname(expr, concretize=false) @@ -872,7 +892,7 @@ function optic_to_dict(::PropertyLens{sym}) where {sym} return Dict("type" => "property", "field" => String(sym)) end optic_to_dict(i::IndexLens) = Dict("type" => "index", "indices" => index_to_dict(i.indices)) -function optic_to_dict(c::ComposedOptic) +function optic_to_dict(c::ComposedFunction) return Dict( "type" => "composed", "outer" => optic_to_dict(c.outer), @@ -1036,14 +1056,12 @@ ERROR: ArgumentError: optic_to_vn: could not convert optic `(@o _[1])` to a VarN function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym} return VarName{sym}() end -function optic_to_vn(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} - return optic_to_vn(o.outer) -end function optic_to_vn( o::Base.ComposedFunction{Outer,Accessors.PropertyLens{sym}} ) where {Outer,sym} return VarName{sym}(o.outer) end +optic_to_vn(o::Base.ComposedFunction) = optic_to_vn(normalise(o)) function optic_to_vn(@nospecialize(o)) msg = "optic_to_vn: could not convert optic `$o` to a VarName" throw(ArgumentError(msg)) @@ -1051,17 +1069,21 @@ end unprefix_optic(o, ::typeof(identity)) = o # Base case function unprefix_optic(optic, optic_prefix) + # Technically `unprefix_optic` only receives optics that were part of + # VarNames, so the optics should already be normalised (in the inner + # constructor of the VarName). However I guess it doesn't hurt to do it + # again to be safe. + optic = normalise(optic) + optic_prefix = normalise(optic_prefix) # strip one layer of the optic and check for equality - inner = _inner(_strip_identity(optic)) - inner_prefix = _inner(_strip_identity(optic_prefix)) + inner = _inner(optic) + inner_prefix = _inner(optic_prefix) if inner != inner_prefix msg = "could not remove prefix $(optic_prefix) from optic $(optic)" throw(ArgumentError(msg)) end # recurse - return unprefix_optic( - _outer(_strip_identity(optic)), _outer(_strip_identity(optic_prefix)) - ) + return unprefix_optic(_outer(optic), _outer(optic_prefix)) end """ @@ -1115,16 +1137,6 @@ y[1].x.a function prefix(vn::VarName{sym_vn}, prefix::VarName{sym_prefix}) where {sym_vn,sym_prefix} optic_vn = getoptic(vn) optic_prefix = getoptic(prefix) - # Special case `identity` to avoid having ComposedFunctions with identity - if optic_vn == identity - new_inner_optic_vn = PropertyLens{sym_vn}() - else - new_inner_optic_vn = optic_vn ∘ PropertyLens{sym_vn}() - end - if optic_prefix == identity - new_optic_vn = new_inner_optic_vn - else - new_optic_vn = new_inner_optic_vn ∘ optic_prefix - end + new_optic_vn = optic_vn ∘ PropertyLens{sym_vn}() ∘ optic_prefix return VarName{sym_prefix}(new_optic_vn) end diff --git a/test/deprecations.jl b/test/deprecations.jl deleted file mode 100644 index dffbe4d..0000000 --- a/test/deprecations.jl +++ /dev/null @@ -1,4 +0,0 @@ -@testset "deprecations.jl" begin - @test (@test_deprecated VarName(:x)) == VarName{:x}() - @test (@test_deprecated VarName(:x, ((1,), (:, 2)))) == VarName{:x}(((1,), (:, 2))) -end diff --git a/test/runtests.jl b/test/runtests.jl index a11a829..8be65eb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,11 +1,3 @@ -# Activate test environment on older Julia versions -if VERSION < v"1.2" - using Pkg: Pkg - Pkg.activate(@__DIR__) - Pkg.develop(Pkg.PackageSpec(; path=dirname(@__DIR__))) - Pkg.instantiate() -end - using AbstractPPL using Documenter using Test @@ -14,7 +6,6 @@ const GROUP = get(ENV, "GROUP", "All") @testset "AbstractPPL.jl" begin if GROUP == "All" || GROUP == "Tests" - include("deprecations.jl") include("varname.jl") include("abstractprobprog.jl") end diff --git a/test/varname.jl b/test/varname.jl index 3fb2733..e074173 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -56,6 +56,24 @@ end @test !inspace(@varname(z), space) end + @testset "optic normalisation" begin + # Push the limits a bit with four optics, one of which is identity, and + # we'll parenthesise them in every possible way. (Some of these are + # going to be equal even before normalisation, but we should test that + # `normalise` works regardless of how Base or Accessors.jl define + # associativity.) + op1 = ((@o _.c) ∘ (@o _.b)) ∘ identity ∘ (@o _.a) + op2 = (@o _.c) ∘ ((@o _.b) ∘ identity) ∘ (@o _.a) + op3 = (@o _.c) ∘ (@o _.b) ∘ (identity ∘ (@o _.a)) + op4 = ((@o _.c) ∘ (@o _.b) ∘ identity) ∘ (@o _.a) + op5 = (@o _.c) ∘ ((@o _.b) ∘ identity ∘ (@o _.a)) + op6 = (@o _.c) ∘ (@o _.b) ∘ identity ∘ (@o _.a) + for op in (op1, op2, op3, op4, op5, op6) + @test AbstractPPL.normalise(op) == (@o _.c) ∘ (@o _.b) ∘ (@o _.a) + end + # Prefix and unprefix also provide further testing for normalisation. + end + @testset "construction & concretization" begin i = 1:10 j = 2:2:5 @@ -235,17 +253,44 @@ end end @testset "prefix and unprefix" begin - @test prefix(@varname(y), @varname(x)) == @varname(x.y) - @test prefix(@varname(y), @varname(x[1])) == @varname(x[1].y) - @test prefix(@varname(y), @varname(x.a)) == @varname(x.a.y) - @test prefix(@varname(y[1]), @varname(x)) == @varname(x.y[1]) - @test prefix(@varname(y.a), @varname(x)) == @varname(x.y.a) - - @test unprefix(@varname(x.y[1]), @varname(x)) == @varname(y[1]) - @test unprefix(@varname(x[1].y), @varname(x[1])) == @varname(y) - @test unprefix(@varname(x.a.y), @varname(x.a)) == @varname(y) - @test unprefix(@varname(x.y.a), @varname(x)) == @varname(y.a) - @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(n)) - @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(x[1])) + @testset "basic cases" begin + @test prefix(@varname(y), @varname(x)) == @varname(x.y) + @test prefix(@varname(y), @varname(x[1])) == @varname(x[1].y) + @test prefix(@varname(y), @varname(x.a)) == @varname(x.a.y) + @test prefix(@varname(y[1]), @varname(x)) == @varname(x.y[1]) + @test prefix(@varname(y.a), @varname(x)) == @varname(x.y.a) + + @test unprefix(@varname(x.y[1]), @varname(x)) == @varname(y[1]) + @test unprefix(@varname(x[1].y), @varname(x[1])) == @varname(y) + @test unprefix(@varname(x.a.y), @varname(x.a)) == @varname(y) + @test unprefix(@varname(x.y.a), @varname(x)) == @varname(y.a) + @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(n)) + @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(x[1])) + end + + @testset "round-trip" begin + # These seem similar to the ones above, but in the past they used + # to error because of issues with un-normalised ComposedFunction + # optics. We explicitly test round-trip (un)prefixing here to make + # sure that there aren't any regressions. + # This tuple is probably overkill, but the tests are super fast + # anyway. + vns = ( + @varname(p), + @varname(q), + @varname(r[1]), + @varname(s.a), + @varname(t[1].a), + @varname(u[1].a.b), + @varname(v.a[1][2].b.c.d[3]) + ) + for vn1 in vns + for vn2 in vns + prefixed = prefix(vn1, vn2) + unprefixed = unprefix(prefixed, vn2) + @test unprefixed == vn1 + end + end + end end end From de95bc5994845892659dc542fdf322fe2e41a6e4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 5 Jul 2025 18:08:44 +0100 Subject: [PATCH 2/4] Add another test --- test/varname.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/varname.jl b/test/varname.jl index e074173..40260da 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -287,6 +287,7 @@ end for vn1 in vns for vn2 in vns prefixed = prefix(vn1, vn2) + @test subsumes(vn2, prefixed) unprefixed = unprefix(prefixed, vn2) @test unprefixed == vn1 end From 7bf3162f8741b78d66ccd28782e08cdc724d9035 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 5 Jul 2025 18:10:49 +0100 Subject: [PATCH 3/4] Improve changelog --- HISTORY.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index a91b63b..65a678f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,8 @@ Removed the constructors `VarName(vn, optic)` (this wasn't deprecated, but was dangerous as it would silently discard the existing optic in `vn`), and `VarName(vn, ::Tuple)` (which was deprecated). +Usage of `VarName(vn, optic)` can be directly replaced with `VarName{getsym(vn)}(optic)`. + ### Optic normalisation In the inner constructor of a VarName, its optic is now normalised to ensure that the associativity of ComposedFunction is always the same, and that compositions with identity are removed. From 7bc83d864216fb7afa5bdbb5e9a15f4f4602dc65 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 6 Jul 2025 11:55:39 +0100 Subject: [PATCH 4/4] Add clarifying comment --- src/varname.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/varname.jl b/src/varname.jl index 9c5e1e3..83f6f6f 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -2,6 +2,7 @@ using Accessors using Accessors: PropertyLens, IndexLens, DynamicIndexLens using JSON: JSON +# nb. ComposedFunction is the same as Accessors.ComposedOptic const ALLOWED_OPTICS = Union{typeof(identity),PropertyLens,IndexLens,ComposedFunction} """