Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
## 0.12.0

### VarName constructors

Removed the constructors `VarName(vn, optic)` (this wasn't deprecated, but was dangerous as it would silently discard the existing optic in `vn`), and `VarName(vn, ::Tuple)` (which was deprecated).

Usage of `VarName(vn, optic)` can be directly replaced with `VarName{getsym(vn)}(optic)`.

### Optic normalisation

In the inner constructor of a VarName, its optic is now normalised to ensure that the associativity of ComposedFunction is always the same, and that compositions with identity are removed.
This helps to prevent subtle bugs where VarNames with semantically equal optics are not considered equal.

## 0.11.0

Added the `prefix(vn::VarName, vn_prefix::VarName)` and `unprefix(vn::VarName, vn_prefix::VarName)` functions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
keywords = ["probablistic programming"]
license = "MIT"
desc = "Common interfaces for probabilistic programming"
version = "0.11.0"
version = "0.12.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
1 change: 0 additions & 1 deletion src/AbstractPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,5 @@ include("varname.jl")
include("abstractmodeltrace.jl")
include("abstractprobprog.jl")
include("evaluate.jl")
include("deprecations.jl")

end # module
2 changes: 0 additions & 2 deletions src/deprecations.jl

This file was deleted.

140 changes: 76 additions & 64 deletions src/varname.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using Accessors
using Accessors: ComposedOptic, PropertyLens, IndexLens, DynamicIndexLens
using Accessors: PropertyLens, IndexLens, DynamicIndexLens
using JSON: JSON

const ALLOWED_OPTICS = Union{typeof(identity),PropertyLens,IndexLens,ComposedOptic}
const ALLOWED_OPTICS = Union{typeof(identity),PropertyLens,IndexLens,ComposedFunction}

"""
VarName{sym}(optic=identity)
Expand Down Expand Up @@ -31,10 +31,11 @@
x[:, 1][2]
```
"""
struct VarName{sym,T}
struct VarName{sym,T<:ALLOWED_OPTICS}
optic::T

function VarName{sym}(optic=identity) where {sym}
optic = normalise(optic)
if !is_static_optic(typeof(optic))
throw(
ArgumentError(
Expand All @@ -53,42 +54,68 @@
one or a composition of `DynamicIndexLens`; and undefined otherwise.
"""
is_static_optic(::Type{<:Union{typeof(identity),PropertyLens,IndexLens}}) = true
function is_static_optic(::Type{ComposedOptic{LO,LI}}) where {LO,LI}
function is_static_optic(::Type{ComposedFunction{LO,LI}}) where {LO,LI}
return is_static_optic(LO) && is_static_optic(LI)
end
is_static_optic(::Type{<:DynamicIndexLens}) = false

# A bit of backwards compatibility.
VarName{sym}(indexing::Tuple) where {sym} = VarName{sym}(tupleindex2optic(indexing))

"""
VarName(vn::VarName, optic)
VarName(vn::VarName, indexing::Tuple)
normalise(optic)

Return a copy of `vn` with a new index `optic`/`indexing`.
Enforce that compositions of optics are always nested in the same way, in that
a ComposedFunction never has a ComposedFunction as its inner lens. Thus, for
example,

```jldoctest; setup=:(using Accessors)
julia> VarName(@varname(x[1][2:3]), Accessors.IndexLens((2,)))
x[2]
julia> op1 = ((@o _.c) ∘ (@o _.b)) ∘ (@o _.a)
(@o _.a.b.c)

julia> VarName(@varname(x[1][2:3]), ((2,),))
x[2]
julia> op2 = (@o _.c) ∘ ((@o _.b) ∘ (@o _.a))
(@o _.c) ∘ ((@o _.a.b))

julia> VarName(@varname(x[1][2:3]))
x
julia> op1 == op2
false

julia> AbstractPPL.normalise(op1) == AbstractPPL.normalise(op2) == @o _.a.b.c
true
```
"""
VarName(vn::VarName, optic=identity) = VarName{getsym(vn)}(optic)

function VarName(vn::VarName, indexing::Tuple)
return VarName{getsym(vn)}(tupleindex2optic(indexing))
end
This function also removes redundant `identity` optics from ComposedFunctions:

```jldoctest; setup=:(using Accessors)
julia> op3 = ((@o _.b) ∘ identity) ∘ (@o _.a)
(@o identity(_.a).b)

tupleindex2optic(indexing::Tuple{}) = identity
tupleindex2optic(indexing::Tuple{<:Tuple}) = IndexLens(first(indexing)) # TODO: rest?
function tupleindex2optic(indexing::Tuple)
return IndexLens(first(indexing)) ∘ tupleindex2optic(indexing[2:end])
julia> op4 = (@o _.b) ∘ (identity ∘ (@o _.a))
(@o _.b) ∘ ((@o identity(_.a)))

julia> AbstractPPL.normalise(op3) == AbstractPPL.normalise(op4) == @o _.a.b
true
```
"""
function normalise(o::ComposedFunction{Outer,<:ComposedFunction}) where {Outer}
# `o` is currently (outer ∘ (inner_outer ∘ inner_inner)).
# We want to change this to:
# o = (outer ∘ inner_outer) ∘ inner_inner
inner_inner = o.inner.inner
inner_outer = o.inner.outer
# Recursively call normalise because inner_inner could itself be a
# ComposedFunction
return normalise((o.outer ∘ inner_outer) ∘ inner_inner)
end
function normalise(o::ComposedFunction{Outer,typeof(identity)} where {Outer})
# strip outer identity
return normalise(o.outer)
end
function normalise(o::ComposedFunction{typeof(identity),Inner} where {Inner})
# strip inner identity
return normalise(o.inner)
end
normalise(o::ComposedFunction) = normalise(o.outer) ∘ o.inner
normalise(o::ALLOWED_OPTICS) = o
# These two methods are needed to avoid method ambiguity.
normalise(o::ComposedFunction{typeof(identity),<:ComposedFunction}) = normalise(o.inner)
normalise(::ComposedFunction{typeof(identity),typeof(identity)}) = identity

"""
getsym(vn::VarName)
Expand All @@ -105,7 +132,7 @@
:y
```
"""
getsym(vn::VarName{sym}) where {sym} = sym
getsym(::VarName{sym}) where {sym} = sym

"""
getoptic(vn::VarName)
Expand Down Expand Up @@ -154,15 +181,8 @@
end

# Allow compositions with optic.
function Base.:∘(optic::ALLOWED_OPTICS, vn::VarName{sym,<:ALLOWED_OPTICS}) where {sym}
vn_optic = getoptic(vn)
if vn_optic == identity
return VarName{sym}(optic)
elseif optic == identity
return vn
else
return VarName{sym}(optic ∘ vn_optic)
end
function Base.:∘(optic::ALLOWED_OPTICS, vn::VarName{sym}) where {sym}
return VarName{sym}(optic ∘ getoptic(vn))
end

Base.hash(vn::VarName, h::UInt) = hash((getsym(vn), getoptic(vn)), h)
Expand Down Expand Up @@ -299,17 +319,17 @@
subsumes(::typeof(identity), ::ALLOWED_OPTICS) = true
subsumes(::ALLOWED_OPTICS, ::typeof(identity)) = false

function subsumes(t::ComposedOptic, u::ComposedOptic)
function subsumes(t::ComposedFunction, u::ComposedFunction)
return subsumes(t.outer, u.outer) && subsumes(t.inner, u.inner)
end

# If `t` is still a composed lens, then there is no way it can subsume `u` since `u` is a
# leaf of the "lens-tree".
subsumes(t::ComposedOptic, u::PropertyLens) = false
subsumes(t::ComposedFunction, u::PropertyLens) = false
# Here we need to check if `u.inner` (i.e. the next lens to be applied from `u`) is
# subsumed by `t`, since this would mean that the rest of the composition is also subsumed
# by `t`.
subsumes(t::PropertyLens, u::ComposedOptic) = subsumes(t, u.inner)
subsumes(t::PropertyLens, u::ComposedFunction) = subsumes(t, u.inner)

# For `PropertyLens` either they have the same `name` and thus they are indeed the same.
subsumes(t::PropertyLens{name}, u::PropertyLens{name}) where {name} = true
Expand All @@ -321,8 +341,8 @@
# FIXME: Does not correctly handle cases such as `subsumes(x, x[:])`
# (but neither did old implementation).
function subsumes(
t::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}},
u::Union{IndexLens,ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens}},
t::Union{IndexLens,ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}},
u::Union{IndexLens,ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens}},
)
return subsumes_indices(t, u)
end
Expand Down Expand Up @@ -415,7 +435,7 @@
"""
combine_indices(optic::ALLOWED_OPTICS) = (), optic
combine_indices(optic::IndexLens) = (optic.indices,), nothing
function combine_indices(optic::ComposedOptic{<:ALLOWED_OPTICS,<:IndexLens})
function combine_indices(optic::ComposedFunction{<:ALLOWED_OPTICS,<:IndexLens})
indices, next = combine_indices(optic.outer)
return (optic.inner.indices, indices...), next
end
Expand Down Expand Up @@ -505,9 +525,9 @@
function concretize(I::IndexLens, x)
return IndexLens(reconcretize_index.(I.indices, to_indices(x, I.indices)))
end
function concretize(I::ComposedOptic, x)
function concretize(I::ComposedFunction, x)
x_inner = I.inner(x) # TODO: get view here
return ComposedOptic(concretize(I.outer, x_inner), concretize(I.inner, x))
return ComposedFunction(concretize(I.outer, x_inner), concretize(I.inner, x))
end

"""
Expand All @@ -533,7 +553,7 @@
ConcretizedSlice(Base.OneTo(100))
```
"""
concretize(vn::VarName, x) = VarName(vn, concretize(getoptic(vn), x))
concretize(vn::VarName{sym}, x) where {sym} = VarName{sym}(concretize(getoptic(vn), x))

"""
@varname(expr, concretize=false)
Expand Down Expand Up @@ -872,7 +892,7 @@
return Dict("type" => "property", "field" => String(sym))
end
optic_to_dict(i::IndexLens) = Dict("type" => "index", "indices" => index_to_dict(i.indices))
function optic_to_dict(c::ComposedOptic)
function optic_to_dict(c::ComposedFunction)
return Dict(
"type" => "composed",
"outer" => optic_to_dict(c.outer),
Expand Down Expand Up @@ -1036,32 +1056,34 @@
function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym}
return VarName{sym}()
end
function optic_to_vn(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer}
return optic_to_vn(o.outer)
end
function optic_to_vn(
o::Base.ComposedFunction{Outer,Accessors.PropertyLens{sym}}
) where {Outer,sym}
return VarName{sym}(o.outer)
end
optic_to_vn(o::Base.ComposedFunction) = optic_to_vn(normalise(o))

Check warning on line 1064 in src/varname.jl

View check run for this annotation

Codecov / codecov/patch

src/varname.jl#L1064

Added line #L1064 was not covered by tests
function optic_to_vn(@nospecialize(o))
msg = "optic_to_vn: could not convert optic `$o` to a VarName"
throw(ArgumentError(msg))
end

unprefix_optic(o, ::typeof(identity)) = o # Base case
function unprefix_optic(optic, optic_prefix)
# Technically `unprefix_optic` only receives optics that were part of
# VarNames, so the optics should already be normalised (in the inner
# constructor of the VarName). However I guess it doesn't hurt to do it
# again to be safe.
optic = normalise(optic)
optic_prefix = normalise(optic_prefix)
# strip one layer of the optic and check for equality
inner = _inner(_strip_identity(optic))
inner_prefix = _inner(_strip_identity(optic_prefix))
inner = _inner(optic)
inner_prefix = _inner(optic_prefix)
if inner != inner_prefix
msg = "could not remove prefix $(optic_prefix) from optic $(optic)"
throw(ArgumentError(msg))
end
# recurse
return unprefix_optic(
_outer(_strip_identity(optic)), _outer(_strip_identity(optic_prefix))
)
return unprefix_optic(_outer(optic), _outer(optic_prefix))
end

"""
Expand Down Expand Up @@ -1115,16 +1137,6 @@
function prefix(vn::VarName{sym_vn}, prefix::VarName{sym_prefix}) where {sym_vn,sym_prefix}
optic_vn = getoptic(vn)
optic_prefix = getoptic(prefix)
# Special case `identity` to avoid having ComposedFunctions with identity
if optic_vn == identity
new_inner_optic_vn = PropertyLens{sym_vn}()
else
new_inner_optic_vn = optic_vn ∘ PropertyLens{sym_vn}()
end
if optic_prefix == identity
new_optic_vn = new_inner_optic_vn
else
new_optic_vn = new_inner_optic_vn ∘ optic_prefix
end
new_optic_vn = optic_vn ∘ PropertyLens{sym_vn}() ∘ optic_prefix
return VarName{sym_prefix}(new_optic_vn)
end
4 changes: 0 additions & 4 deletions test/deprecations.jl

This file was deleted.

9 changes: 0 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
# Activate test environment on older Julia versions
if VERSION < v"1.2"
using Pkg: Pkg
Pkg.activate(@__DIR__)
Pkg.develop(Pkg.PackageSpec(; path=dirname(@__DIR__)))
Pkg.instantiate()
end

using AbstractPPL
using Documenter
using Test
Expand All @@ -14,7 +6,6 @@ const GROUP = get(ENV, "GROUP", "All")

@testset "AbstractPPL.jl" begin
if GROUP == "All" || GROUP == "Tests"
include("deprecations.jl")
include("varname.jl")
include("abstractprobprog.jl")
end
Expand Down
Loading
Loading