Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Accessors"
uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
authors = ["Takafumi Arakaki <[email protected]>", "Jan Weidner <[email protected]> and contributors"]
version = "0.1.21"
version = "0.1.23"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
118 changes: 94 additions & 24 deletions src/getsetall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ julia> getall(obj, @optic _ |> Elements() |> last)
"""
function getall end

# implementations for individual noncomposite optics

getall(obj::Union{Tuple, AbstractVector}, ::Elements) = obj
getall(obj::Union{NamedTuple}, ::Elements) = values(obj)
getall(obj::AbstractArray, ::Elements) = vec(obj)
Expand All @@ -29,17 +31,63 @@ getall(obj, ::Properties) = getproperties(obj) |> values
getall(obj, o::If) = o.modify_condition(obj) ? (obj,) : ()
getall(obj, f) = (f(obj),)

function setall(obj, ::Properties, vs)
names = propertynames(obj)
setproperties(obj, NamedTuple{names}(NTuple{length(names)}(vs)))
end
setall(obj::NamedTuple{NS}, ::Elements, vs) where {NS} = NamedTuple{NS}(NTuple{length(NS)}(vs))
setall(obj::NTuple{N, Any}, ::Elements, vs) where {N} = (@assert length(vs) == N; NTuple{N}(vs))
setall(obj::AbstractArray, ::Elements, vs::AbstractArray) = (@assert length(obj) == length(vs); reshape(vs, size(obj)))
setall(obj::AbstractArray, ::Elements, vs) = setall(obj, Elements(), collect(vs))
setall(obj, o::If, vs) = error("Not supported")
setall(obj, o, vs) = set(obj, o, only(vs))


# implementations for composite optics

# A straightforward recursive approach doesn't actually infer,
# see https://github.com/JuliaObjects/Accessors.jl/pull/64 and https://github.com/JuliaObjects/Accessors.jl/pull/68.
# Instead, we need to generate separate functions for each recursion level.

# A recursive implementation of getall doesn't actually infer,
# see https://github.com/JuliaObjects/Accessors.jl/pull/64.
# Instead, we need to generate unrolled code explicitly.
function getall(obj, optic::ComposedFunction)
N = length(decompose(optic))
_GetAll{N}()(obj, optic)
_getall(obj, optic, Val(N))
end

function setall(obj, optic::ComposedFunction, vs)
N = length(decompose(optic))
vss = to_nested_shape(vs, Val(getall_lengths(obj, optic, Val(N))), Val(N))
_setall(obj, optic, vss, Val(N))
end


# _getall: the actual workhorse for getall
_getall(_, _, ::Val{N}) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/64.")
_getall(obj, optic, ::Val{1}) = getall(obj, optic)
for i in 2:10
@eval function _getall(obj, optic, ::Val{$i})
_reduce_concat(
map(getall(obj, optic.inner)) do obj
_getall(obj, optic.outer, Val($(i-1)))
end
)
end
end

struct _GetAll{N} end
(::_GetAll{N})(_) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/64.")
# _setall: the actual workhorse for setall
# takes values as a nested tuple with proper leaf lengths, prepared in setall above
_setall(_, _, _, ::Val{N}) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/68.")
_setall(obj, optic, vs, ::Val{1}) = setall(obj, optic, vs)
for i in 2:10
@eval function _setall(obj, optic, vs, ::Val{$i})
setall(obj, optic.inner, map(getall(obj, optic.inner), vs) do obj, vss
_setall(obj, optic.outer, vss, Val($(i - 1)))
end)
end
end


# helper functions

_concat(a::Tuple, b::Tuple) = (a..., b...)
_concat(a::Tuple, b::AbstractVector) = vcat(collect(a), b)
Expand All @@ -51,26 +99,48 @@ _reduce_concat(xs::AbstractVector) = reduce(append!, xs; init=eltype(eltype(xs))
_reduce_concat(xs::Tuple{AbstractVector, Vararg{AbstractVector}}) = reduce(vcat, xs)
_reduce_concat(xs::AbstractVector{<:AbstractVector}) = reduce(vcat, xs)

function _generate_getall(N::Int)
syms = [Symbol(:f_, i) for i in 1:N]

expr = :( getall(obj, $(syms[end])) )
for s in syms[1:end - 1] |> reverse
expr = :(
_reduce_concat(
map(getall(obj, $(s))) do obj
$expr
end
)
)
end
_staticlength(::NTuple{N, <:Any}) where {N} = Val(N)
_staticlength(x::AbstractVector) = length(x)

:(function (::_GetAll{$N})(obj, optic)
($(syms...),) = deopcompose(optic)
$expr
end)
getall_lengths(obj, optic, ::Val{1}) = _staticlength(getall(obj, optic))
for i in 2:10
@eval function getall_lengths(obj, optic, ::Val{$i})
# convert to Tuple: vectors cannot be put into Val
map(getall(obj, optic.inner) |> Tuple) do o
getall_lengths(o, optic.outer, Val($(i - 1)))
end
end
end

_val(N::Int) = N
_val(::Val{N}) where {N} = N

nestedsum(ls::Int) = ls
nestedsum(ls::Val) = ls
nestedsum(ls::Tuple) = sum(_val ∘ nestedsum, ls)

# to_nested_shape() definition uses both @eval and @generated
#
# @eval is needed because the code for different recursion depths should be different for inference,
# not the same method with different parameters.
#
# @generated is used to unpack target lengths from the second argument at compile time to make to_nested_shape() as cheap as possible.
#
# Note: to_nested_shape() only operates on plain Julia types and won't be affected by user lens definition, unlike setall for example.
# That's why it's safe to make it @generated.
to_nested_shape(vs, ::Val{LS}, ::Val{1}) where {LS} = (@assert length(vs) == _val(LS); vs)
for i in 2:10
eval(_generate_getall(i))
@eval @generated function to_nested_shape(vs, ls::Val{LS}, ::Val{$i}) where {LS}
vi = 1
subs = map(LS) do lss
n = nestedsum(lss)
elems = map(vi:vi+_val(n)-1) do j
:( vs[$j] )
end
res = :( to_nested_shape(($(elems...),), $(Val(lss)), $(Val($(i - 1)))) )
vi += _val(n)
res
end
:( ($(subs...),) )
end
end
2 changes: 1 addition & 1 deletion src/optics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export @optic
export PropertyLens, IndexLens
export set, modify, delete, insert, getall
export set, modify, delete, insert, getall, setall
export ∘, opcompose, var"⨟"
export Elements, Recursive, If, Properties
export setproperties
Expand Down
16 changes: 16 additions & 0 deletions src/testing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,19 @@ function test_modify_law(f, lens, obj)
obj_setfget = set(obj, lens, val)
@test obj_modify == obj_setfget
end

function test_getsetall_laws(optic, obj, vals1, vals2; cmp=(==))

# setall ⨟ getall
vals = getall(obj, optic)
@test cmp(setall(obj, optic, vals), obj)

# getall ⨟ setall
obj1 = setall(obj, optic, vals1)
@test cmp(collect(getall(obj1, optic)), collect(vals1))

# setall idempotent
obj12 = setall(obj1, optic, vals2)
obj2 = setall(obj12, optic, vals2)
@test obj12 == obj2
end
43 changes: 43 additions & 0 deletions test/test_getsetall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,49 @@ if VERSION >= v"1.6" # for ComposedFunction
obj = ([1, 2], [:a, :b])
@test [1, 2, :a, :b] == @inferred getall(obj, @optic _ |> Elements() |> Elements())
end

@testset "setall" begin
for o in [Elements(), Properties()]
@test (a=2, b=3) === @inferred setall((a=1, b="2"), o, (2, 3))
@test (a=2, b="3") === @inferred setall((a=1, b="2"), o, (2, "3"))
@test (a=2, b=3) === @inferred setall((a=1, b="2"), o, [2, 3])
end
@test (2, 3) === @inferred setall((1, "2"), Elements(), (2, 3))
@test (2, "3") === @inferred setall((1, "2"), Elements(), (2, "3"))
@test (2, 3) === @inferred setall((1, "2"), Elements(), [2, 3])
@test [2, 3] == @inferred setall([1, "2"], Elements(), (2, 3))
@test [2, "3"] == @inferred setall([1, "2"], Elements(), (2, "3"))
@test [2, 3] == @inferred setall([1, "2"], Elements(), [2, 3])

obj = (a=1, b=2.0, c='3')
@test (a="aa", b=2.0, c='3') === @inferred setall(obj, @optic(_.a), ("aa",))
@test (a=9, b=19.0, c='4') === @inferred setall(obj, @optic(_ |> Elements() |> _ + 1), (10, 20.0, '5'))

obj = (a=1, b=((c=3, d=4), (c=5, d=6)))
@test (a=1, b=(:x, :y)) === @inferred setall(obj, @optic(_.b |> Elements()), (:x, :y))
@test (a=1, b=((c=:x, d=4), (c=:y, d=6))) === @inferred setall(obj, @optic(_.b |> Elements() |> _.c), (:x, :y))
@test (a=1, b=((c=:x, d="y"), (c=:z, d=10))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties()), (:x, "y", :z, 10))
@test (a=1, b=((c=-3., d=-4.), (c=-5., d=-6.))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties() |> _ * 3), (-9, -12, -15, -18))
@test (a=1, b=((c=-3., d=-4.), (c=-5., d=-6.))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties() |> _ * 3), [-9, -12, -15, -18])

obj = ([1, 2], 3:5, (6,))
@test obj == setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
@test ([2, 3], 4:6, (7,)) == setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)
# can this infer?..
@test_broken obj == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
@test_broken ([2, 3], 4:6, (7,)) == @inferred setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)
end

@testset "getall/setall consistency" begin
for (optic, obj, vals1, vals2) in [
(Elements(), (1, "2"), (2, 3), (4, 5)),
(Properties(), (a=1, b="2"), (2, 3), (4, 5)),
(@optic(_.b |> Elements() |> Properties() |> _ * 3), (a=1, b=((c=3, d=4), (c=5, d=6))), 1:4, (-9, -12, -15, -18)),
]
Accessors.test_getsetall_laws(optic, obj, vals1, vals2)
end
end

end

end