Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Accessors"
uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
authors = ["Takafumi Arakaki <[email protected]>", "Jan Weidner <[email protected]> and contributors"]
version = "0.1.21"
version = "0.1.22"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
149 changes: 124 additions & 25 deletions src/getsetall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

Extract all parts of `obj` that are selected by `optic`.
Returns a flat `Tuple` of values, or an `AbstractVector` if the selected parts contain arrays.
This function is experimental and we might change the precise output container in future.

This function is experimental and we might change the precise output container in the future.

See also [`setall`](@ref).


```jldoctest
Expand All @@ -20,6 +23,34 @@ julia> getall(obj, @optic _ |> Elements() |> last)
"""
function getall end

"""
setall(obj, optic, values)

Replace a part of `obj` that is selected by `optic` with `values`.
The `values` collection should have the same number of elements as selected by `optic`.

This function is experimental and might change in the future.

See also [`getall`](@ref), [`set`](@ref). The former is dual to `setall`:

```jldoctest
julia> using Accessors

julia> obj = (a=1, b=(2, 3));

julia> optic = @optic _ |> Elements() |> last;

julia> getall(obj, optic)
(1, 3)

julia> setall(obj, optic, (4, 5))
(a = 4, b = (2, 5))
```
"""
function setall end

# implementations for individual noncomposite optics

getall(obj::Union{Tuple, AbstractVector}, ::Elements) = obj
getall(obj::Union{NamedTuple}, ::Elements) = values(obj)
getall(obj::AbstractArray, ::Elements) = vec(obj)
Expand All @@ -29,17 +60,63 @@ getall(obj, ::Properties) = getproperties(obj) |> values
getall(obj, o::If) = o.modify_condition(obj) ? (obj,) : ()
getall(obj, f) = (f(obj),)

function setall(obj, ::Properties, vs)
names = propertynames(obj)
setproperties(obj, NamedTuple{names}(NTuple{length(names)}(vs)))
end
setall(obj::NamedTuple{NS}, ::Elements, vs) where {NS} = NamedTuple{NS}(NTuple{length(NS)}(vs))
setall(obj::NTuple{N, Any}, ::Elements, vs) where {N} = (@assert length(vs) == N; NTuple{N}(vs))
setall(obj::AbstractArray, ::Elements, vs::AbstractArray) = (@assert length(obj) == length(vs); reshape(vs, size(obj)))
setall(obj::AbstractArray, ::Elements, vs) = setall(obj, Elements(), collect(vs))
setall(obj, o::If, vs) = error("Not supported")
setall(obj, o, vs) = set(obj, o, only(vs))


# implementations for composite optics

# A straightforward recursive approach doesn't actually infer,
# see https://github.com/JuliaObjects/Accessors.jl/pull/64 and https://github.com/JuliaObjects/Accessors.jl/pull/68.
# Instead, we need to generate separate functions for each recursion level.

# A recursive implementation of getall doesn't actually infer,
# see https://github.com/JuliaObjects/Accessors.jl/pull/64.
# Instead, we need to generate unrolled code explicitly.
function getall(obj, optic::ComposedFunction)
N = length(decompose(optic))
_GetAll{N}()(obj, optic)
_getall(obj, optic, Val(N))
end

function setall(obj, optic::ComposedFunction, vs)
N = length(decompose(optic))
vss = to_nested_shape(vs, Val(getall_lengths(obj, optic, Val(N))), Val(N))
_setall(obj, optic, vss, Val(N))
end


# _getall: the actual workhorse for getall
_getall(_, _, ::Val{N}) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/64.")
_getall(obj, optic, ::Val{1}) = getall(obj, optic)
for i in 2:10
@eval function _getall(obj, optic, ::Val{$i})
_reduce_concat(
map(getall(obj, optic.inner)) do obj
_getall(obj, optic.outer, Val($(i-1)))
end
)
end
end

# _setall: the actual workhorse for setall
# takes values as a nested tuple with proper leaf lengths, prepared in setall above
_setall(_, _, _, ::Val{N}) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/68.")
_setall(obj, optic, vs, ::Val{1}) = setall(obj, optic, vs)
for i in 2:10
@eval function _setall(obj, optic, vs, ::Val{$i})
setall(obj, optic.inner, map(getall(obj, optic.inner), vs) do obj, vss
_setall(obj, optic.outer, vss, Val($(i - 1)))
end)
end
end

struct _GetAll{N} end
(::_GetAll{N})(_) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/64.")

# helper functions

_concat(a::Tuple, b::Tuple) = (a..., b...)
_concat(a::Tuple, b::AbstractVector) = vcat(collect(a), b)
Expand All @@ -51,26 +128,48 @@ _reduce_concat(xs::AbstractVector) = reduce(append!, xs; init=eltype(eltype(xs))
_reduce_concat(xs::Tuple{AbstractVector, Vararg{AbstractVector}}) = reduce(vcat, xs)
_reduce_concat(xs::AbstractVector{<:AbstractVector}) = reduce(vcat, xs)

function _generate_getall(N::Int)
syms = [Symbol(:f_, i) for i in 1:N]

expr = :( getall(obj, $(syms[end])) )
for s in syms[1:end - 1] |> reverse
expr = :(
_reduce_concat(
map(getall(obj, $(s))) do obj
$expr
end
)
)
end
_staticlength(::NTuple{N, <:Any}) where {N} = Val(N)
_staticlength(x::AbstractVector) = length(x)

:(function (::_GetAll{$N})(obj, optic)
($(syms...),) = deopcompose(optic)
$expr
end)
getall_lengths(obj, optic, ::Val{1}) = _staticlength(getall(obj, optic))
for i in 2:10
@eval function getall_lengths(obj, optic, ::Val{$i})
# convert to Tuple: vectors cannot be put into Val
map(getall(obj, optic.inner) |> Tuple) do o
getall_lengths(o, optic.outer, Val($(i - 1)))
end
end
end

_val(N::Int) = N
_val(::Val{N}) where {N} = N

nestedsum(ls::Int) = ls
nestedsum(ls::Val) = ls
nestedsum(ls::Tuple) = sum(_val ∘ nestedsum, ls)

# to_nested_shape() definition uses both @eval and @generated
#
# @eval is needed because the code for different recursion depths should be different for inference,
# not the same method with different parameters.
#
# @generated is used to unpack target lengths from the second argument at compile time to make to_nested_shape() as cheap as possible.
#
# Note: to_nested_shape() only operates on plain Julia types and won't be affected by user lens definition, unlike setall for example.
# That's why it's safe to make it @generated.
to_nested_shape(vs, ::Val{LS}, ::Val{1}) where {LS} = (@assert length(vs) == _val(LS); vs)
for i in 2:10
eval(_generate_getall(i))
@eval @generated function to_nested_shape(vs, ls::Val{LS}, ::Val{$i}) where {LS}
vi = 1
subs = map(LS) do lss
n = nestedsum(lss)
elems = map(vi:vi+_val(n)-1) do j
:( vs[$j] )
end
res = :( to_nested_shape(($(elems...),), $(Val(lss)), $(Val($(i - 1)))) )
vi += _val(n)
res
end
:( ($(subs...),) )
end
end
2 changes: 1 addition & 1 deletion src/optics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export @optic
export PropertyLens, IndexLens
export set, modify, delete, insert, getall
export set, modify, delete, insert, getall, setall
export ∘, opcompose, var"⨟"
export Elements, Recursive, If, Properties
export setproperties
Expand Down
16 changes: 16 additions & 0 deletions src/testing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,19 @@ function test_modify_law(f, lens, obj)
obj_setfget = set(obj, lens, val)
@test obj_modify == obj_setfget
end

function test_getsetall_laws(optic, obj, vals1, vals2; cmp=(==))

# setall ⨟ getall
vals = getall(obj, optic)
@test cmp(setall(obj, optic, vals), obj)

# getall ⨟ setall
obj1 = setall(obj, optic, vals1)
@test cmp(collect(getall(obj1, optic)), collect(vals1))

# setall idempotent
obj12 = setall(obj1, optic, vals2)
obj2 = setall(obj12, optic, vals2)
@test obj12 == obj2
end
43 changes: 43 additions & 0 deletions test/test_getsetall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,49 @@ if VERSION >= v"1.6" # for ComposedFunction
obj = ([1, 2], [:a, :b])
@test [1, 2, :a, :b] == @inferred getall(obj, @optic _ |> Elements() |> Elements())
end

@testset "setall" begin
for o in [Elements(), Properties()]
@test (a=2, b=3) === @inferred setall((a=1, b="2"), o, (2, 3))
@test (a=2, b="3") === @inferred setall((a=1, b="2"), o, (2, "3"))
@test (a=2, b=3) === @inferred setall((a=1, b="2"), o, [2, 3])
end
@test (2, 3) === @inferred setall((1, "2"), Elements(), (2, 3))
@test (2, "3") === @inferred setall((1, "2"), Elements(), (2, "3"))
@test (2, 3) === @inferred setall((1, "2"), Elements(), [2, 3])
@test [2, 3] == @inferred setall([1, "2"], Elements(), (2, 3))
@test [2, "3"] == @inferred setall([1, "2"], Elements(), (2, "3"))
@test [2, 3] == @inferred setall([1, "2"], Elements(), [2, 3])

obj = (a=1, b=2.0, c='3')
@test (a="aa", b=2.0, c='3') === @inferred setall(obj, @optic(_.a), ("aa",))
@test (a=9, b=19.0, c='4') === @inferred setall(obj, @optic(_ |> Elements() |> _ + 1), (10, 20.0, '5'))

obj = (a=1, b=((c=3, d=4), (c=5, d=6)))
@test (a=1, b=(:x, :y)) === @inferred setall(obj, @optic(_.b |> Elements()), (:x, :y))
@test (a=1, b=((c=:x, d=4), (c=:y, d=6))) === @inferred setall(obj, @optic(_.b |> Elements() |> _.c), (:x, :y))
@test (a=1, b=((c=:x, d="y"), (c=:z, d=10))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties()), (:x, "y", :z, 10))
@test (a=1, b=((c=-3., d=-4.), (c=-5., d=-6.))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties() |> _ * 3), (-9, -12, -15, -18))
@test (a=1, b=((c=-3., d=-4.), (c=-5., d=-6.))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties() |> _ * 3), [-9, -12, -15, -18])

obj = ([1, 2], 3:5, (6,))
@test obj == setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
@test ([2, 3], 4:6, (7,)) == setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)
# can this infer?..
@test_broken obj == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
@test_broken ([2, 3], 4:6, (7,)) == @inferred setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)
end

@testset "getall/setall consistency" begin
for (optic, obj, vals1, vals2) in [
(Elements(), (1, "2"), (2, 3), (4, 5)),
(Properties(), (a=1, b="2"), (2, 3), (4, 5)),
(@optic(_.b |> Elements() |> Properties() |> _ * 3), (a=1, b=((c=3, d=4), (c=5, d=6))), 1:4, (-9, -12, -15, -18)),
]
Accessors.test_getsetall_laws(optic, obj, vals1, vals2)
end
end

end

end