diff --git a/HISTORY.md b/HISTORY.md new file mode 100644 index 0000000..65a678f --- /dev/null +++ b/HISTORY.md @@ -0,0 +1,16 @@ +## 0.12.0 + +### VarName constructors + +Removed the constructors `VarName(vn, optic)` (this wasn't deprecated, but was dangerous as it would silently discard the existing optic in `vn`), and `VarName(vn, ::Tuple)` (which was deprecated). + +Usage of `VarName(vn, optic)` can be directly replaced with `VarName{getsym(vn)}(optic)`. + +### Optic normalisation + +In the inner constructor of a VarName, its optic is now normalised to ensure that the associativity of ComposedFunction is always the same, and that compositions with identity are removed. +This helps to prevent subtle bugs where VarNames with semantically equal optics are not considered equal. + +## 0.11.0 + +Added the `prefix(vn::VarName, vn_prefix::VarName)` and `unprefix(vn::VarName, vn_prefix::VarName)` functions. diff --git a/Project.toml b/Project.toml index 42f84f2..ecccd56 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.11.0" +version = "0.12.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index d0df706..40da231 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -29,6 +29,5 @@ include("varname.jl") include("abstractmodeltrace.jl") include("abstractprobprog.jl") include("evaluate.jl") -include("deprecations.jl") end # module diff --git a/src/deprecations.jl b/src/deprecations.jl deleted file mode 100644 index 24901bb..0000000 --- a/src/deprecations.jl +++ /dev/null @@ -1,2 +0,0 @@ -@deprecate VarName(sym::Symbol) VarName{sym}() -@deprecate VarName(sym::Symbol, indexing::Tuple) VarName{sym}(indexing) diff --git a/src/varname.jl b/src/varname.jl index 38d558c..83f6f6f 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -1,8 +1,9 @@ using Accessors -using Accessors: ComposedOptic, PropertyLens, IndexLens, DynamicIndexLens +using Accessors: PropertyLens, IndexLens, DynamicIndexLens using JSON: JSON -const ALLOWED_OPTICS = Union{typeof(identity),PropertyLens,IndexLens,ComposedOptic} +# nb. ComposedFunction is the same as Accessors.ComposedOptic +const ALLOWED_OPTICS = Union{typeof(identity),PropertyLens,IndexLens,ComposedFunction} """ VarName{sym}(optic=identity) @@ -31,10 +32,11 @@ julia> @varname x[:, 1][1+1] x[:, 1][2] ``` """ -struct VarName{sym,T} +struct VarName{sym,T<:ALLOWED_OPTICS} optic::T function VarName{sym}(optic=identity) where {sym} + optic = normalise(optic) if !is_static_optic(typeof(optic)) throw( ArgumentError( @@ -53,42 +55,68 @@ Return `true` if `l` is one or a composition of `identity`, `PropertyLens`, and one or a composition of `DynamicIndexLens`; and undefined otherwise. """ is_static_optic(::Type{<:Union{typeof(identity),PropertyLens,IndexLens}}) = true -function is_static_optic(::Type{ComposedOptic{LO,LI}}) where {LO,LI} +function is_static_optic(::Type{ComposedFunction{LO,LI}}) where {LO,LI} return is_static_optic(LO) && is_static_optic(LI) end is_static_optic(::Type{<:DynamicIndexLens}) = false -# A bit of backwards compatibility. -VarName{sym}(indexing::Tuple) where {sym} = VarName{sym}(tupleindex2optic(indexing)) - """ - VarName(vn::VarName, optic) - VarName(vn::VarName, indexing::Tuple) + normalise(optic) -Return a copy of `vn` with a new index `optic`/`indexing`. +Enforce that compositions of optics are always nested in the same way, in that +a ComposedFunction never has a ComposedFunction as its inner lens. Thus, for +example, ```jldoctest; setup=:(using Accessors) -julia> VarName(@varname(x[1][2:3]), Accessors.IndexLens((2,))) -x[2] +julia> op1 = ((@o _.c) ∘ (@o _.b)) ∘ (@o _.a) +(@o _.a.b.c) -julia> VarName(@varname(x[1][2:3]), ((2,),)) -x[2] +julia> op2 = (@o _.c) ∘ ((@o _.b) ∘ (@o _.a)) +(@o _.c) ∘ ((@o _.a.b)) -julia> VarName(@varname(x[1][2:3])) -x +julia> op1 == op2 +false + +julia> AbstractPPL.normalise(op1) == AbstractPPL.normalise(op2) == @o _.a.b.c +true ``` -""" -VarName(vn::VarName, optic=identity) = VarName{getsym(vn)}(optic) -function VarName(vn::VarName, indexing::Tuple) - return VarName{getsym(vn)}(tupleindex2optic(indexing)) -end +This function also removes redundant `identity` optics from ComposedFunctions: + +```jldoctest; setup=:(using Accessors) +julia> op3 = ((@o _.b) ∘ identity) ∘ (@o _.a) +(@o identity(_.a).b) -tupleindex2optic(indexing::Tuple{}) = identity -tupleindex2optic(indexing::Tuple{<:Tuple}) = IndexLens(first(indexing)) # TODO: rest? -function tupleindex2optic(indexing::Tuple) - return IndexLens(first(indexing)) ∘ tupleindex2optic(indexing[2:end]) +julia> op4 = (@o _.b) ∘ (identity ∘ (@o _.a)) +(@o _.b) ∘ ((@o identity(_.a))) + +julia> AbstractPPL.normalise(op3) == AbstractPPL.normalise(op4) == @o _.a.b +true +``` +""" +function normalise(o::ComposedFunction{Outer,<:ComposedFunction}) where {Outer} + # `o` is currently (outer ∘ (inner_outer ∘ inner_inner)). + # We want to change this to: + # o = (outer ∘ inner_outer) ∘ inner_inner + inner_inner = o.inner.inner + inner_outer = o.inner.outer + # Recursively call normalise because inner_inner could itself be a + # ComposedFunction + return normalise((o.outer ∘ inner_outer) ∘ inner_inner) +end +function normalise(o::ComposedFunction{Outer,typeof(identity)} where {Outer}) + # strip outer identity + return normalise(o.outer) +end +function normalise(o::ComposedFunction{typeof(identity),Inner} where {Inner}) + # strip inner identity + return normalise(o.inner) end +normalise(o::ComposedFunction) = normalise(o.outer) ∘ o.inner +normalise(o::ALLOWED_OPTICS) = o +# These two methods are needed to avoid method ambiguity. +normalise(o::ComposedFunction{typeof(identity),<:ComposedFunction}) = normalise(o.inner) +normalise(::ComposedFunction{typeof(identity),typeof(identity)}) = identity """ getsym(vn::VarName) @@ -105,7 +133,7 @@ julia> getsym(@varname(y)) :y ``` """ -getsym(vn::VarName{sym}) where {sym} = sym +getsym(::VarName{sym}) where {sym} = sym """ getoptic(vn::VarName) @@ -154,15 +182,8 @@ function Accessors.set(obj, vn::VarName{sym}, value) where {sym} end # Allow compositions with optic. -function Base.:∘(optic::ALLOWED_OPTICS, vn::VarName{sym,<:ALLOWED_OPTICS}) where {sym} - vn_optic = getoptic(vn) - if vn_optic == identity - return VarName{sym}(optic) - elseif optic == identity - return vn - else - return VarName{sym}(optic ∘ vn_optic) - end +function Base.:∘(optic::ALLOWED_OPTICS, vn::VarName{sym}) where {sym} + return VarName{sym}(optic ∘ getoptic(vn)) end Base.hash(vn::VarName, h::UInt) = hash((getsym(vn), getoptic(vn)), h) @@ -299,17 +320,17 @@ subsumes(::typeof(identity), ::typeof(identity)) = true subsumes(::typeof(identity), ::ALLOWED_OPTICS) = true subsumes(::ALLOWED_OPTICS, ::typeof(identity)) = false -function subsumes(t::ComposedOptic, u::ComposedOptic) +function subsumes(t::ComposedFunction, u::ComposedFunction) return subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner) end # If `t` is still a composed lens, then there is no way it can subsume `u` since `u` is a # leaf of the "lens-tree". -subsumes(t::ComposedOptic, u::PropertyLens) = false +subsumes(t::ComposedFunction, u::PropertyLens) = false # Here we need to check if `u.inner` (i.e. the next lens to be applied from `u`) is # subsumed by `t`, since this would mean that the rest of the composition is also subsumed # by `t`. -subsumes(t::PropertyLens, u::ComposedOptic) = subsumes(t, u.inner) +subsumes(t::PropertyLens, u::ComposedFunction) = subsumes(t, u.inner) # For `PropertyLens` either they have the same `name` and thus they are indeed the same. subsumes(t::PropertyLens{name}, u::PropertyLens{name}) where {name} = true @@ -321,8 +342,8 @@ subsumes(t::PropertyLens, u::PropertyLens) = false # FIXME: Does not correctly handle cases such as `subsumes(x, x[:])` # (but neither did old implementation). function subsumes( - t::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}}, - u::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}}, + t::Union{IndexLens,ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}}, + u::Union{IndexLens,ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}}, ) return subsumes_indices(t, u) end @@ -415,7 +436,7 @@ The result is compatible with [`subsumes_indices`](@ref) for `Tuple` input. """ combine_indices(optic::ALLOWED_OPTICS) = (), optic combine_indices(optic::IndexLens) = (optic.indices,), nothing -function combine_indices(optic::ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}) +function combine_indices(optic::ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}) indices, next = combine_indices(optic.outer) return (optic.inner.indices, indices...), next end @@ -505,9 +526,9 @@ concretize(I::DynamicIndexLens, x) = concretize(IndexLens(I.f(x)), x) function concretize(I::IndexLens, x) return IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices))) end -function concretize(I::ComposedOptic, x) +function concretize(I::ComposedFunction, x) x_inner = I.inner(x) # TODO: get view here - return ComposedOptic(concretize(I.outer, x_inner), concretize(I.inner, x)) + return ComposedFunction(concretize(I.outer, x_inner), concretize(I.inner, x)) end """ @@ -533,7 +554,7 @@ julia> # The underlying value is concretized, though: ConcretizedSlice(Base.OneTo(100)) ``` """ -concretize(vn::VarName, x) = VarName(vn, concretize(getoptic(vn), x)) +concretize(vn::VarName{sym}, x) where {sym} = VarName{sym}(concretize(getoptic(vn), x)) """ @varname(expr, concretize=false) @@ -872,7 +893,7 @@ function optic_to_dict(::PropertyLens{sym}) where {sym} return Dict("type" => "property", "field" => String(sym)) end optic_to_dict(i::IndexLens) = Dict("type" => "index", "indices" => index_to_dict(i.indices)) -function optic_to_dict(c::ComposedOptic) +function optic_to_dict(c::ComposedFunction) return Dict( "type" => "composed", "outer" => optic_to_dict(c.outer), @@ -1036,14 +1057,12 @@ ERROR: ArgumentError: optic_to_vn: could not convert optic `(@o _[1])` to a VarN function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym} return VarName{sym}() end -function optic_to_vn(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} - return optic_to_vn(o.outer) -end function optic_to_vn( o::Base.ComposedFunction{Outer,Accessors.PropertyLens{sym}} ) where {Outer,sym} return VarName{sym}(o.outer) end +optic_to_vn(o::Base.ComposedFunction) = optic_to_vn(normalise(o)) function optic_to_vn(@nospecialize(o)) msg = "optic_to_vn: could not convert optic `$o` to a VarName" throw(ArgumentError(msg)) @@ -1051,17 +1070,21 @@ end unprefix_optic(o, ::typeof(identity)) = o # Base case function unprefix_optic(optic, optic_prefix) + # Technically `unprefix_optic` only receives optics that were part of + # VarNames, so the optics should already be normalised (in the inner + # constructor of the VarName). However I guess it doesn't hurt to do it + # again to be safe. + optic = normalise(optic) + optic_prefix = normalise(optic_prefix) # strip one layer of the optic and check for equality - inner = _inner(_strip_identity(optic)) - inner_prefix = _inner(_strip_identity(optic_prefix)) + inner = _inner(optic) + inner_prefix = _inner(optic_prefix) if inner != inner_prefix msg = "could not remove prefix $(optic_prefix) from optic $(optic)" throw(ArgumentError(msg)) end # recurse - return unprefix_optic( - _outer(_strip_identity(optic)), _outer(_strip_identity(optic_prefix)) - ) + return unprefix_optic(_outer(optic), _outer(optic_prefix)) end """ @@ -1115,16 +1138,6 @@ y[1].x.a function prefix(vn::VarName{sym_vn}, prefix::VarName{sym_prefix}) where {sym_vn,sym_prefix} optic_vn = getoptic(vn) optic_prefix = getoptic(prefix) - # Special case `identity` to avoid having ComposedFunctions with identity - if optic_vn == identity - new_inner_optic_vn = PropertyLens{sym_vn}() - else - new_inner_optic_vn = optic_vn ∘ PropertyLens{sym_vn}() - end - if optic_prefix == identity - new_optic_vn = new_inner_optic_vn - else - new_optic_vn = new_inner_optic_vn ∘ optic_prefix - end + new_optic_vn = optic_vn ∘ PropertyLens{sym_vn}() ∘ optic_prefix return VarName{sym_prefix}(new_optic_vn) end diff --git a/test/deprecations.jl b/test/deprecations.jl deleted file mode 100644 index dffbe4d..0000000 --- a/test/deprecations.jl +++ /dev/null @@ -1,4 +0,0 @@ -@testset "deprecations.jl" begin - @test (@test_deprecated VarName(:x)) == VarName{:x}() - @test (@test_deprecated VarName(:x, ((1,), (:, 2)))) == VarName{:x}(((1,), (:, 2))) -end diff --git a/test/runtests.jl b/test/runtests.jl index a11a829..8be65eb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,11 +1,3 @@ -# Activate test environment on older Julia versions -if VERSION < v"1.2" - using Pkg: Pkg - Pkg.activate(@__DIR__) - Pkg.develop(Pkg.PackageSpec(; path=dirname(@__DIR__))) - Pkg.instantiate() -end - using AbstractPPL using Documenter using Test @@ -14,7 +6,6 @@ const GROUP = get(ENV, "GROUP", "All") @testset "AbstractPPL.jl" begin if GROUP == "All" || GROUP == "Tests" - include("deprecations.jl") include("varname.jl") include("abstractprobprog.jl") end diff --git a/test/varname.jl b/test/varname.jl index 3fb2733..40260da 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -56,6 +56,24 @@ end @test !inspace(@varname(z), space) end + @testset "optic normalisation" begin + # Push the limits a bit with four optics, one of which is identity, and + # we'll parenthesise them in every possible way. (Some of these are + # going to be equal even before normalisation, but we should test that + # `normalise` works regardless of how Base or Accessors.jl define + # associativity.) + op1 = ((@o _.c) ∘ (@o _.b)) ∘ identity ∘ (@o _.a) + op2 = (@o _.c) ∘ ((@o _.b) ∘ identity) ∘ (@o _.a) + op3 = (@o _.c) ∘ (@o _.b) ∘ (identity ∘ (@o _.a)) + op4 = ((@o _.c) ∘ (@o _.b) ∘ identity) ∘ (@o _.a) + op5 = (@o _.c) ∘ ((@o _.b) ∘ identity ∘ (@o _.a)) + op6 = (@o _.c) ∘ (@o _.b) ∘ identity ∘ (@o _.a) + for op in (op1, op2, op3, op4, op5, op6) + @test AbstractPPL.normalise(op) == (@o _.c) ∘ (@o _.b) ∘ (@o _.a) + end + # Prefix and unprefix also provide further testing for normalisation. + end + @testset "construction & concretization" begin i = 1:10 j = 2:2:5 @@ -235,17 +253,45 @@ end end @testset "prefix and unprefix" begin - @test prefix(@varname(y), @varname(x)) == @varname(x.y) - @test prefix(@varname(y), @varname(x[1])) == @varname(x[1].y) - @test prefix(@varname(y), @varname(x.a)) == @varname(x.a.y) - @test prefix(@varname(y[1]), @varname(x)) == @varname(x.y[1]) - @test prefix(@varname(y.a), @varname(x)) == @varname(x.y.a) - - @test unprefix(@varname(x.y[1]), @varname(x)) == @varname(y[1]) - @test unprefix(@varname(x[1].y), @varname(x[1])) == @varname(y) - @test unprefix(@varname(x.a.y), @varname(x.a)) == @varname(y) - @test unprefix(@varname(x.y.a), @varname(x)) == @varname(y.a) - @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(n)) - @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(x[1])) + @testset "basic cases" begin + @test prefix(@varname(y), @varname(x)) == @varname(x.y) + @test prefix(@varname(y), @varname(x[1])) == @varname(x[1].y) + @test prefix(@varname(y), @varname(x.a)) == @varname(x.a.y) + @test prefix(@varname(y[1]), @varname(x)) == @varname(x.y[1]) + @test prefix(@varname(y.a), @varname(x)) == @varname(x.y.a) + + @test unprefix(@varname(x.y[1]), @varname(x)) == @varname(y[1]) + @test unprefix(@varname(x[1].y), @varname(x[1])) == @varname(y) + @test unprefix(@varname(x.a.y), @varname(x.a)) == @varname(y) + @test unprefix(@varname(x.y.a), @varname(x)) == @varname(y.a) + @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(n)) + @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(x[1])) + end + + @testset "round-trip" begin + # These seem similar to the ones above, but in the past they used + # to error because of issues with un-normalised ComposedFunction + # optics. We explicitly test round-trip (un)prefixing here to make + # sure that there aren't any regressions. + # This tuple is probably overkill, but the tests are super fast + # anyway. + vns = ( + @varname(p), + @varname(q), + @varname(r[1]), + @varname(s.a), + @varname(t[1].a), + @varname(u[1].a.b), + @varname(v.a[1][2].b.c.d[3]) + ) + for vn1 in vns + for vn2 in vns + prefixed = prefix(vn1, vn2) + @test subsumes(vn2, prefixed) + unprefixed = unprefix(prefixed, vn2) + @test unprefixed == vn1 + end + end + end end end