diff --git a/Project.toml b/Project.toml index c5effd03..607ae62b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Accessors" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" authors = ["Takafumi Arakaki ", "Jan Weidner and contributors"] -version = "0.1.21" +version = "0.1.22" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/getsetall.jl b/src/getsetall.jl index c372cb51..486e3522 100644 --- a/src/getsetall.jl +++ b/src/getsetall.jl @@ -3,7 +3,10 @@ Extract all parts of `obj` that are selected by `optic`. Returns a flat `Tuple` of values, or an `AbstractVector` if the selected parts contain arrays. -This function is experimental and we might change the precise output container in future. + +This function is experimental and we might change the precise output container in the future. + +See also [`setall`](@ref). ```jldoctest @@ -20,6 +23,34 @@ julia> getall(obj, @optic _ |> Elements() |> last) """ function getall end +""" + setall(obj, optic, values) + +Replace a part of `obj` that is selected by `optic` with `values`. +The `values` collection should have the same number of elements as selected by `optic`. + +This function is experimental and might change in the future. + +See also [`getall`](@ref), [`set`](@ref). The former is dual to `setall`: + +```jldoctest +julia> using Accessors + +julia> obj = (a=1, b=(2, 3)); + +julia> optic = @optic _ |> Elements() |> last; + +julia> getall(obj, optic) +(1, 3) + +julia> setall(obj, optic, (4, 5)) +(a = 4, b = (2, 5)) +``` +""" +function setall end + +# implementations for individual noncomposite optics + getall(obj::Union{Tuple, AbstractVector}, ::Elements) = obj getall(obj::Union{NamedTuple}, ::Elements) = values(obj) getall(obj::AbstractArray, ::Elements) = vec(obj) @@ -29,17 +60,63 @@ getall(obj, ::Properties) = getproperties(obj) |> values getall(obj, o::If) = o.modify_condition(obj) ? (obj,) : () getall(obj, f) = (f(obj),) +function setall(obj, ::Properties, vs) + names = propertynames(obj) + setproperties(obj, NamedTuple{names}(NTuple{length(names)}(vs))) +end +setall(obj::NamedTuple{NS}, ::Elements, vs) where {NS} = NamedTuple{NS}(NTuple{length(NS)}(vs)) +setall(obj::NTuple{N, Any}, ::Elements, vs) where {N} = (@assert length(vs) == N; NTuple{N}(vs)) +setall(obj::AbstractArray, ::Elements, vs::AbstractArray) = (@assert length(obj) == length(vs); reshape(vs, size(obj))) +setall(obj::AbstractArray, ::Elements, vs) = setall(obj, Elements(), collect(vs)) +setall(obj, o::If, vs) = error("Not supported") +setall(obj, o, vs) = set(obj, o, only(vs)) + + +# implementations for composite optics + +# A straightforward recursive approach doesn't actually infer, +# see https://github.com/JuliaObjects/Accessors.jl/pull/64 and https://github.com/JuliaObjects/Accessors.jl/pull/68. +# Instead, we need to generate separate functions for each recursion level. -# A recursive implementation of getall doesn't actually infer, -# see https://github.com/JuliaObjects/Accessors.jl/pull/64. -# Instead, we need to generate unrolled code explicitly. function getall(obj, optic::ComposedFunction) N = length(decompose(optic)) - _GetAll{N}()(obj, optic) + _getall(obj, optic, Val(N)) +end + +function setall(obj, optic::ComposedFunction, vs) + N = length(decompose(optic)) + vss = to_nested_shape(vs, Val(getall_lengths(obj, optic, Val(N))), Val(N)) + _setall(obj, optic, vss, Val(N)) +end + + +# _getall: the actual workhorse for getall +_getall(_, _, ::Val{N}) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/64.") +_getall(obj, optic, ::Val{1}) = getall(obj, optic) +for i in 2:10 + @eval function _getall(obj, optic, ::Val{$i}) + _reduce_concat( + map(getall(obj, optic.inner)) do obj + _getall(obj, optic.outer, Val($(i-1))) + end + ) + end +end + +# _setall: the actual workhorse for setall +# takes values as a nested tuple with proper leaf lengths, prepared in setall above +_setall(_, _, _, ::Val{N}) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/68.") +_setall(obj, optic, vs, ::Val{1}) = setall(obj, optic, vs) +for i in 2:10 + @eval function _setall(obj, optic, vs, ::Val{$i}) + setall(obj, optic.inner, map(getall(obj, optic.inner), vs) do obj, vss + _setall(obj, optic.outer, vss, Val($(i - 1))) + end) + end end -struct _GetAll{N} end -(::_GetAll{N})(_) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/64.") + +# helper functions _concat(a::Tuple, b::Tuple) = (a..., b...) _concat(a::Tuple, b::AbstractVector) = vcat(collect(a), b) @@ -51,26 +128,48 @@ _reduce_concat(xs::AbstractVector) = reduce(append!, xs; init=eltype(eltype(xs)) _reduce_concat(xs::Tuple{AbstractVector, Vararg{AbstractVector}}) = reduce(vcat, xs) _reduce_concat(xs::AbstractVector{<:AbstractVector}) = reduce(vcat, xs) -function _generate_getall(N::Int) - syms = [Symbol(:f_, i) for i in 1:N] - - expr = :( getall(obj, $(syms[end])) ) - for s in syms[1:end - 1] |> reverse - expr = :( - _reduce_concat( - map(getall(obj, $(s))) do obj - $expr - end - ) - ) - end +_staticlength(::NTuple{N, <:Any}) where {N} = Val(N) +_staticlength(x::AbstractVector) = length(x) - :(function (::_GetAll{$N})(obj, optic) - ($(syms...),) = deopcompose(optic) - $expr - end) +getall_lengths(obj, optic, ::Val{1}) = _staticlength(getall(obj, optic)) +for i in 2:10 + @eval function getall_lengths(obj, optic, ::Val{$i}) + # convert to Tuple: vectors cannot be put into Val + map(getall(obj, optic.inner) |> Tuple) do o + getall_lengths(o, optic.outer, Val($(i - 1))) + end + end end +_val(N::Int) = N +_val(::Val{N}) where {N} = N + +nestedsum(ls::Int) = ls +nestedsum(ls::Val) = ls +nestedsum(ls::Tuple) = sum(_val ∘ nestedsum, ls) + +# to_nested_shape() definition uses both @eval and @generated +# +# @eval is needed because the code for different recursion depths should be different for inference, +# not the same method with different parameters. +# +# @generated is used to unpack target lengths from the second argument at compile time to make to_nested_shape() as cheap as possible. +# +# Note: to_nested_shape() only operates on plain Julia types and won't be affected by user lens definition, unlike setall for example. +# That's why it's safe to make it @generated. +to_nested_shape(vs, ::Val{LS}, ::Val{1}) where {LS} = (@assert length(vs) == _val(LS); vs) for i in 2:10 - eval(_generate_getall(i)) + @eval @generated function to_nested_shape(vs, ls::Val{LS}, ::Val{$i}) where {LS} + vi = 1 + subs = map(LS) do lss + n = nestedsum(lss) + elems = map(vi:vi+_val(n)-1) do j + :( vs[$j] ) + end + res = :( to_nested_shape(($(elems...),), $(Val(lss)), $(Val($(i - 1)))) ) + vi += _val(n) + res + end + :( ($(subs...),) ) + end end diff --git a/src/optics.jl b/src/optics.jl index 782c7505..d0e6e0ed 100644 --- a/src/optics.jl +++ b/src/optics.jl @@ -1,6 +1,6 @@ export @optic export PropertyLens, IndexLens -export set, modify, delete, insert, getall +export set, modify, delete, insert, getall, setall export ∘, opcompose, var"⨟" export Elements, Recursive, If, Properties export setproperties diff --git a/src/testing.jl b/src/testing.jl index 6281390e..c74847d4 100644 --- a/src/testing.jl +++ b/src/testing.jl @@ -22,3 +22,19 @@ function test_modify_law(f, lens, obj) obj_setfget = set(obj, lens, val) @test obj_modify == obj_setfget end + +function test_getsetall_laws(optic, obj, vals1, vals2; cmp=(==)) + + # setall ⨟ getall + vals = getall(obj, optic) + @test cmp(setall(obj, optic, vals), obj) + + # getall ⨟ setall + obj1 = setall(obj, optic, vals1) + @test cmp(collect(getall(obj1, optic)), collect(vals1)) + + # setall idempotent + obj12 = setall(obj1, optic, vals2) + obj2 = setall(obj12, optic, vals2) + @test obj12 == obj2 +end diff --git a/test/test_getsetall.jl b/test/test_getsetall.jl index 4291d155..baa6b888 100644 --- a/test/test_getsetall.jl +++ b/test/test_getsetall.jl @@ -64,6 +64,49 @@ if VERSION >= v"1.6" # for ComposedFunction obj = ([1, 2], [:a, :b]) @test [1, 2, :a, :b] == @inferred getall(obj, @optic _ |> Elements() |> Elements()) end + +@testset "setall" begin + for o in [Elements(), Properties()] + @test (a=2, b=3) === @inferred setall((a=1, b="2"), o, (2, 3)) + @test (a=2, b="3") === @inferred setall((a=1, b="2"), o, (2, "3")) + @test (a=2, b=3) === @inferred setall((a=1, b="2"), o, [2, 3]) + end + @test (2, 3) === @inferred setall((1, "2"), Elements(), (2, 3)) + @test (2, "3") === @inferred setall((1, "2"), Elements(), (2, "3")) + @test (2, 3) === @inferred setall((1, "2"), Elements(), [2, 3]) + @test [2, 3] == @inferred setall([1, "2"], Elements(), (2, 3)) + @test [2, "3"] == @inferred setall([1, "2"], Elements(), (2, "3")) + @test [2, 3] == @inferred setall([1, "2"], Elements(), [2, 3]) + + obj = (a=1, b=2.0, c='3') + @test (a="aa", b=2.0, c='3') === @inferred setall(obj, @optic(_.a), ("aa",)) + @test (a=9, b=19.0, c='4') === @inferred setall(obj, @optic(_ |> Elements() |> _ + 1), (10, 20.0, '5')) + + obj = (a=1, b=((c=3, d=4), (c=5, d=6))) + @test (a=1, b=(:x, :y)) === @inferred setall(obj, @optic(_.b |> Elements()), (:x, :y)) + @test (a=1, b=((c=:x, d=4), (c=:y, d=6))) === @inferred setall(obj, @optic(_.b |> Elements() |> _.c), (:x, :y)) + @test (a=1, b=((c=:x, d="y"), (c=:z, d=10))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties()), (:x, "y", :z, 10)) + @test (a=1, b=((c=-3., d=-4.), (c=-5., d=-6.))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties() |> _ * 3), (-9, -12, -15, -18)) + @test (a=1, b=((c=-3., d=-4.), (c=-5., d=-6.))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties() |> _ * 3), [-9, -12, -15, -18]) + + obj = ([1, 2], 3:5, (6,)) + @test obj == setall(obj, @optic(_ |> Elements() |> Elements()), 1:6) + @test ([2, 3], 4:6, (7,)) == setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6) + # can this infer?.. + @test_broken obj == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), 1:6) + @test_broken ([2, 3], 4:6, (7,)) == @inferred setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6) +end + +@testset "getall/setall consistency" begin + for (optic, obj, vals1, vals2) in [ + (Elements(), (1, "2"), (2, 3), (4, 5)), + (Properties(), (a=1, b="2"), (2, 3), (4, 5)), + (@optic(_.b |> Elements() |> Properties() |> _ * 3), (a=1, b=((c=3, d=4), (c=5, d=6))), 1:4, (-9, -12, -15, -18)), + ] + Accessors.test_getsetall_laws(optic, obj, vals1, vals2) + end +end + end end