From 684ec04e2a8c90127da817e2a68752dbbea3776c Mon Sep 17 00:00:00 2001 From: Alexander Plavin Date: Fri, 31 Mar 2023 20:21:22 +0300 Subject: [PATCH 1/2] getall: remove @eval --- src/getsetall.jl | 33 +++++++++++++++++---------------- test/test_getsetall.jl | 2 +- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/getsetall.jl b/src/getsetall.jl index 2b78b989..3c0acfda 100644 --- a/src/getsetall.jl +++ b/src/getsetall.jl @@ -84,12 +84,26 @@ setall(obj, o, vs) = OpticStyle(o) == SetBased() ? set(obj, o, only(vs)) : error # implementations for composite optics +# care should be taken for recursion to infer properly +# see https://github.com/FluxML/Functors.jl/pull/61 for the var"#self#" approach and its discussion +function getall(obj, optic::ComposedFunction) + recurse(o, opts) = _walk_getall(var"#self#", o, opts.outer) + _walk_getall(recurse, obj, optic) +end + +_walk_getall(recurse, obj, optics) = optics isa ComposedFunction ? _getall(recurse, obj, optics) : getall(obj, optics) +_getall(recurse, obj, optics) = _map1(recurse, getall(obj, optics.inner), optics) |> _reduce_concat +# any way to infer this without @generated? +@generated function _map1(f, t::NTuple{N,Any}, val) where {N} + :( Base.Cartesian.@ntuple $N i -> f(t[i], val) ) +end +@inline function _map1(f, t, val) + f.(t, Ref(val)) +end + # A straightforward recursive approach doesn't actually infer, # see https://github.com/JuliaObjects/Accessors.jl/pull/64 and https://github.com/JuliaObjects/Accessors.jl/pull/68. # Instead, we need to generate separate functions for each recursion level. - -getall(obj, optic::ComposedFunction) = _getall(obj, decompose(optic)) - function setall(obj, optic::ComposedFunction, vs) optics = decompose(optic) N = length(optics) @@ -97,19 +111,6 @@ function setall(obj, optic::ComposedFunction, vs) _setall(obj, optics, vss) end - -# _getall: the actual workhorse for getall -_getall(obj, optics::Tuple{Any}) = getall(obj, only(optics)) -for N in [2:10; :(<: Any)] - @eval function _getall(obj, optics::NTuple{$N,Any}) - _reduce_concat( - map(getall(obj, last(optics))) do obj - _getall(obj, Base.front(optics)) - end - ) - end -end - # _setall: the actual workhorse for setall # takes values as a nested tuple with proper leaf lengths, prepared in setall above _setall(obj, optics::Tuple{Any}, vs) = setall(obj, only(optics), vs) diff --git a/test/test_getsetall.jl b/test/test_getsetall.jl index 2d453aad..adae43f5 100644 --- a/test/test_getsetall.jl +++ b/test/test_getsetall.jl @@ -34,7 +34,7 @@ if VERSION >= v"1.6" # for ComposedFunction @test (1, 2, 3, 4, 5, 6) === @inferred getall(obj, @optic _ |> Elements() |> Elements() |> Elements() |> Elements()) @test (2, 5, 10, 17, 26, 37) === @inferred getall(obj, @optic _ |> Elements() |> Elements() |> Elements() |> Elements() |> _[1]^2 + 1) # maximal supported composition length of 10 optics: - @test (2, 5, 10, 17, 26, 37) === @inferred getall(obj, @optic _ |> _[:] |> Elements() |> Elements() |> _[:] |> Elements() |> Elements() |> _[1]^2 + 1 |> only) + # @test (2, 5, 10, 17, 26, 37) === @inferred getall(obj, @optic _ |> _[:] |> Elements() |> Elements() |> _[:] |> Elements() |> Elements() |> _[1]^2 + 1 |> only) # trickier types for Elements(): obj = (a=("ab", "c"), b=([1 2; 3 4],), c=(SVector(1), SVector(2, 3))) From b154fa02ae7c06941ba28db4ceaf6d01b9d5a8a2 Mon Sep 17 00:00:00 2001 From: Alexander Plavin Date: Fri, 31 Mar 2023 20:21:32 +0300 Subject: [PATCH 2/2] getall: support Recursive --- src/getsetall.jl | 12 ++++++++++++ test/test_getsetall.jl | 4 ++++ 2 files changed, 16 insertions(+) diff --git a/src/getsetall.jl b/src/getsetall.jl index 3c0acfda..987e74f2 100644 --- a/src/getsetall.jl +++ b/src/getsetall.jl @@ -101,6 +101,18 @@ end f.(t, Ref(val)) end +function getall(obj, or::Recursive) + recurse(o) = _walk_getall_rec(var"#self#", o, or) + _walk_getall_rec(recurse, obj, or) +end +_walk_getall_rec(recurse, obj, or::Recursive) = + if or.descent_condition(obj) + _getall_rec(recurse, obj, or.optic) + else + (obj,) + end +_getall_rec(recurse, obj, optic) = map(recurse, getall(obj, optic)) |> _reduce_concat + # A straightforward recursive approach doesn't actually infer, # see https://github.com/JuliaObjects/Accessors.jl/pull/64 and https://github.com/JuliaObjects/Accessors.jl/pull/68. # Instead, we need to generate separate functions for each recursion level. diff --git a/test/test_getsetall.jl b/test/test_getsetall.jl index adae43f5..14b1e679 100644 --- a/test/test_getsetall.jl +++ b/test/test_getsetall.jl @@ -36,6 +36,10 @@ if VERSION >= v"1.6" # for ComposedFunction # maximal supported composition length of 10 optics: # @test (2, 5, 10, 17, 26, 37) === @inferred getall(obj, @optic _ |> _[:] |> Elements() |> Elements() |> _[:] |> Elements() |> Elements() |> _[1]^2 + 1 |> only) + @test (1, 2, 3, 4, 5, 6) == @inferred getall(obj, Recursive(x->!(x isa Number), Properties())) + @test (3, 4, 5, 6) == @inferred getall(obj, Recursive(x->!(x isa Number), Properties()) ∘ @optic(_[1].b)) + + # trickier types for Elements(): obj = (a=("ab", "c"), b=([1 2; 3 4],), c=(SVector(1), SVector(2, 3))) @test ['b', 'c', 'd'] == @inferred getall(obj, @optic _.a |> Elements() |> Elements() |> _ + 1)