diff --git a/src/getsetall.jl b/src/getsetall.jl index 2b78b989..987e74f2 100644 --- a/src/getsetall.jl +++ b/src/getsetall.jl @@ -84,12 +84,38 @@ setall(obj, o, vs) = OpticStyle(o) == SetBased() ? set(obj, o, only(vs)) : error # implementations for composite optics +# care should be taken for recursion to infer properly +# see https://github.com/FluxML/Functors.jl/pull/61 for the var"#self#" approach and its discussion +function getall(obj, optic::ComposedFunction) + recurse(o, opts) = _walk_getall(var"#self#", o, opts.outer) + _walk_getall(recurse, obj, optic) +end + +_walk_getall(recurse, obj, optics) = optics isa ComposedFunction ? _getall(recurse, obj, optics) : getall(obj, optics) +_getall(recurse, obj, optics) = _map1(recurse, getall(obj, optics.inner), optics) |> _reduce_concat +# any way to infer this without @generated? +@generated function _map1(f, t::NTuple{N,Any}, val) where {N} + :( Base.Cartesian.@ntuple $N i -> f(t[i], val) ) +end +@inline function _map1(f, t, val) + f.(t, Ref(val)) +end + +function getall(obj, or::Recursive) + recurse(o) = _walk_getall_rec(var"#self#", o, or) + _walk_getall_rec(recurse, obj, or) +end +_walk_getall_rec(recurse, obj, or::Recursive) = + if or.descent_condition(obj) + _getall_rec(recurse, obj, or.optic) + else + (obj,) + end +_getall_rec(recurse, obj, optic) = map(recurse, getall(obj, optic)) |> _reduce_concat + # A straightforward recursive approach doesn't actually infer, # see https://github.com/JuliaObjects/Accessors.jl/pull/64 and https://github.com/JuliaObjects/Accessors.jl/pull/68. # Instead, we need to generate separate functions for each recursion level. - -getall(obj, optic::ComposedFunction) = _getall(obj, decompose(optic)) - function setall(obj, optic::ComposedFunction, vs) optics = decompose(optic) N = length(optics) @@ -97,19 +123,6 @@ function setall(obj, optic::ComposedFunction, vs) _setall(obj, optics, vss) end - -# _getall: the actual workhorse for getall -_getall(obj, optics::Tuple{Any}) = getall(obj, only(optics)) -for N in [2:10; :(<: Any)] - @eval function _getall(obj, optics::NTuple{$N,Any}) - _reduce_concat( - map(getall(obj, last(optics))) do obj - _getall(obj, Base.front(optics)) - end - ) - end -end - # _setall: the actual workhorse for setall # takes values as a nested tuple with proper leaf lengths, prepared in setall above _setall(obj, optics::Tuple{Any}, vs) = setall(obj, only(optics), vs) diff --git a/test/test_getsetall.jl b/test/test_getsetall.jl index 2d453aad..14b1e679 100644 --- a/test/test_getsetall.jl +++ b/test/test_getsetall.jl @@ -34,7 +34,11 @@ if VERSION >= v"1.6" # for ComposedFunction @test (1, 2, 3, 4, 5, 6) === @inferred getall(obj, @optic _ |> Elements() |> Elements() |> Elements() |> Elements()) @test (2, 5, 10, 17, 26, 37) === @inferred getall(obj, @optic _ |> Elements() |> Elements() |> Elements() |> Elements() |> _[1]^2 + 1) # maximal supported composition length of 10 optics: - @test (2, 5, 10, 17, 26, 37) === @inferred getall(obj, @optic _ |> _[:] |> Elements() |> Elements() |> _[:] |> Elements() |> Elements() |> _[1]^2 + 1 |> only) + # @test (2, 5, 10, 17, 26, 37) === @inferred getall(obj, @optic _ |> _[:] |> Elements() |> Elements() |> _[:] |> Elements() |> Elements() |> _[1]^2 + 1 |> only) + + @test (1, 2, 3, 4, 5, 6) == @inferred getall(obj, Recursive(x->!(x isa Number), Properties())) + @test (3, 4, 5, 6) == @inferred getall(obj, Recursive(x->!(x isa Number), Properties()) ∘ @optic(_[1].b)) + # trickier types for Elements(): obj = (a=("ab", "c"), b=([1 2; 3 4],), c=(SVector(1), SVector(2, 3)))