Skip to content

Commit a8890b6

Browse files
committed
Implement prefix / unprefix
1 parent 700a70b commit a8890b6

File tree

3 files changed

+168
-1
lines changed

3 files changed

+168
-1
lines changed

src/AbstractPPL.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ export VarName,
1414
index_to_dict,
1515
dict_to_index,
1616
varname_to_string,
17-
string_to_varname
17+
string_to_varname,
18+
prefix,
19+
unprefix
1820

1921
# Abstract model functions
2022
export AbstractProbabilisticProgram,

src/varname.jl

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,8 @@ function vsym(expr::Expr)
766766
end
767767
end
768768

769+
### Serialisation to JSON / string
770+
769771
# String constants for each index type that we support serialisation /
770772
# deserialisation of
771773
const _BASE_INTEGER_TYPE = "Base.Integer"
@@ -936,3 +938,152 @@ Convert a string representation of a `VarName` back to a `VarName`. The string
936938
should have been generated by `varname_to_string`.
937939
"""
938940
string_to_varname(str::AbstractString) = dict_to_varname(JSON.parse(str))
941+
942+
### Prefixing and unprefixing
943+
944+
"""
945+
_strip_identity(optic)
946+
947+
Remove an inner layer of the identity lens from a composed optic.
948+
"""
949+
_strip_identity(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} = o.outer
950+
_strip_identity(o::Base.ComposedFunction) = o
951+
_strip_identity(o::Accessors.PropertyLens) = o
952+
_strip_identity(o::Accessors.IndexLens) = o
953+
_strip_identity(o::typeof(identity)) = o
954+
955+
"""
956+
_inner(optic)
957+
958+
Get the innermost (non-identity) layer of an optic.
959+
960+
```jldoctest; setup=:(using Accessors)
961+
julia> AbstractPPL._inner(Accessors.@o _.a.b.c)
962+
(@o _.a)
963+
964+
julia> AbstractPPL._inner(Accessors.@o _[1][2][3])
965+
(@o _[1])
966+
967+
julia> AbstractPPL._inner(Accessors.@o _)
968+
identity (generic function with 1 method)
969+
```
970+
"""
971+
_inner(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner
972+
function _inner(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer}
973+
return _strip_identity(o.outer)
974+
end
975+
_inner(o::Accessors.PropertyLens) = o
976+
_inner(o::Accessors.IndexLens) = o
977+
_inner(o::typeof(identity)) = o
978+
979+
"""
980+
_outer(optic)
981+
982+
Get the outer layer of an optic.
983+
984+
```jldoctest; setup=:(using Accessors)
985+
julia> AbstractPPL._outer(Accessors.@o _.a.b.c)
986+
(@o _.b.c)
987+
988+
julia> AbstractPPL._outer(Accessors.@o _[1][2][3])
989+
(@o _[2][3])
990+
991+
julia> AbstractPPL._outer(Accessors.@o _.a)
992+
identity (generic function with 1 method)
993+
994+
julia> AbstractPPL._outer(Accessors.@o _[1])
995+
identity (generic function with 1 method)
996+
997+
julia> AbstractPPL._outer(Accessors.@o _)
998+
identity (generic function with 1 method)
999+
```
1000+
"""
1001+
_outer(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = _strip_identity(o.outer)
1002+
_outer(::Accessors.PropertyLens) = identity
1003+
_outer(::Accessors.IndexLens) = identity
1004+
_outer(::typeof(identity)) = identity
1005+
1006+
"""
1007+
optic_to_vn(optic)
1008+
1009+
Convert an Accessors optic to a VarName. This is best explained through
1010+
examples.
1011+
1012+
```jldoctest; setup=:(using Accessors)
1013+
julia> AbstractPPL.optic_to_vn(Accessors.@o _.a)
1014+
a
1015+
1016+
julia> AbstractPPL.optic_to_vn(Accessors.@o _.a.b)
1017+
a.b
1018+
1019+
julia> AbstractPPL.optic_to_vn(Accessors.@o _.a[1])
1020+
a[1]
1021+
```
1022+
1023+
The outermost layer of the optic (technically, what Accessors.jl calls the
1024+
'innermost') must be a `PropertyLens`, or else it will fail. This is because a
1025+
VarName needs to have a symbol.
1026+
1027+
```jldoctest; setup=:(using Accessors)
1028+
julia> AbstractPPL.optic_to_vn(Accessors.@o _[1])
1029+
ERROR: ArgumentError: optic_to_vn: could not convert optic `(@o _[1])` to a VarName
1030+
[...]
1031+
```
1032+
"""
1033+
function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym}
1034+
return VarName{sym}()
1035+
end
1036+
function optic_to_vn(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer}
1037+
return optic_to_vn(o.outer)
1038+
end
1039+
function optic_to_vn(
1040+
o::Base.ComposedFunction{Outer,Accessors.PropertyLens{sym}}
1041+
) where {Outer,sym}
1042+
return VarName{sym}(o.outer)
1043+
end
1044+
function optic_to_vn(@nospecialize(o))
1045+
msg = "optic_to_vn: could not convert optic `$o` to a VarName"
1046+
throw(ArgumentError(msg))
1047+
end
1048+
1049+
unprefix_optic(o, ::typeof(identity)) = o # Base case
1050+
function unprefix_optic(optic, optic_prefix)
1051+
# strip one layer of the optic and check for equality
1052+
inner = _inner(optic)
1053+
inner_prefix = _inner(optic_prefix)
1054+
if inner != inner_prefix
1055+
msg = "could not remove prefix $(optic_prefix) from optic $(optic)"
1056+
throw(ArgumentError(msg))
1057+
end
1058+
# recurse
1059+
return unprefix_optic(_outer(optic), _outer(optic_prefix))
1060+
end
1061+
1062+
function unprefix(
1063+
vn::VarName{sym_vn}, prefix::VarName{sym_prefix}
1064+
) where {sym_vn,sym_prefix}
1065+
if sym_vn != sym_prefix
1066+
msg = "could not remove prefix $(prefix) from VarName $(vn)"
1067+
throw(ArgumentError(msg))
1068+
end
1069+
optic_vn = getoptic(vn)
1070+
optic_prefix = getoptic(prefix)
1071+
return optic_to_vn(unprefix_optic(optic_vn, optic_prefix))
1072+
end
1073+
1074+
function prefix(vn::VarName{sym_vn}, prefix::VarName{sym_prefix}) where {sym_vn,sym_prefix}
1075+
optic_vn = getoptic(vn)
1076+
optic_prefix = getoptic(prefix)
1077+
# Special case `identity` to avoid having ComposedFunctions with identity
1078+
if optic_vn == identity
1079+
new_inner_optic_vn = PropertyLens{sym_vn}()
1080+
else
1081+
new_inner_optic_vn = optic_vn PropertyLens{sym_vn}()
1082+
end
1083+
if optic_prefix == identity
1084+
new_optic_vn = new_inner_optic_vn
1085+
else
1086+
new_optic_vn = new_inner_optic_vn optic_prefix
1087+
end
1088+
return VarName{sym_prefix}(new_optic_vn)
1089+
end

test/varname.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,4 +233,18 @@ end
233233
# Serialisation should now work
234234
@test string_to_varname(varname_to_string(vn)) == vn
235235
end
236+
237+
@testset "prefix and unprefix" begin
238+
@test prefix(@varname(y), @varname(x)) == @varname(x.y)
239+
@test prefix(@varname(y), @varname(x[1])) == @varname(x[1].y)
240+
@test prefix(@varname(y), @varname(x.a)) == @varname(x.a.y)
241+
@test prefix(@varname(y[1]), @varname(x)) == @varname(x.y[1])
242+
@test prefix(@varname(y.a), @varname(x)) == @varname(x.y.a)
243+
244+
@test unprefix(@varname(x.y[1]), @varname(x)) == @varname(y[1])
245+
@test unprefix(@varname(x[1].y), @varname(x[1])) == @varname(y)
246+
@test unprefix(@varname(x.a.y), @varname(x.a)) == @varname(y)
247+
@test unprefix(@varname(x.y.a), @varname(x)) == @varname(y.a)
248+
@test_throws ArgumentError unprefix(@varname(x.y.a), @varname(n))
249+
end
236250
end

0 commit comments

Comments
 (0)