Skip to content

Commit a82ed28

Browse files
authored
add setall() (#68)
1 parent 1139a77 commit a82ed28

File tree

5 files changed

+185
-27
lines changed

5 files changed

+185
-27
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Accessors"
22
uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
33
authors = ["Takafumi Arakaki <[email protected]>", "Jan Weidner <[email protected]> and contributors"]
4-
version = "0.1.21"
4+
version = "0.1.22"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/getsetall.jl

Lines changed: 124 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
44
Extract all parts of `obj` that are selected by `optic`.
55
Returns a flat `Tuple` of values, or an `AbstractVector` if the selected parts contain arrays.
6-
This function is experimental and we might change the precise output container in future.
6+
7+
This function is experimental and we might change the precise output container in the future.
8+
9+
See also [`setall`](@ref).
710
811
912
```jldoctest
@@ -20,6 +23,34 @@ julia> getall(obj, @optic _ |> Elements() |> last)
2023
"""
2124
function getall end
2225

26+
"""
27+
setall(obj, optic, values)
28+
29+
Replace a part of `obj` that is selected by `optic` with `values`.
30+
The `values` collection should have the same number of elements as selected by `optic`.
31+
32+
This function is experimental and might change in the future.
33+
34+
See also [`getall`](@ref), [`set`](@ref). The former is dual to `setall`:
35+
36+
```jldoctest
37+
julia> using Accessors
38+
39+
julia> obj = (a=1, b=(2, 3));
40+
41+
julia> optic = @optic _ |> Elements() |> last;
42+
43+
julia> getall(obj, optic)
44+
(1, 3)
45+
46+
julia> setall(obj, optic, (4, 5))
47+
(a = 4, b = (2, 5))
48+
```
49+
"""
50+
function setall end
51+
52+
# implementations for individual noncomposite optics
53+
2354
getall(obj::Union{Tuple, AbstractVector}, ::Elements) = obj
2455
getall(obj::Union{NamedTuple}, ::Elements) = values(obj)
2556
getall(obj::AbstractArray, ::Elements) = vec(obj)
@@ -29,17 +60,63 @@ getall(obj, ::Properties) = getproperties(obj) |> values
2960
getall(obj, o::If) = o.modify_condition(obj) ? (obj,) : ()
3061
getall(obj, f) = (f(obj),)
3162

63+
function setall(obj, ::Properties, vs)
64+
names = propertynames(obj)
65+
setproperties(obj, NamedTuple{names}(NTuple{length(names)}(vs)))
66+
end
67+
setall(obj::NamedTuple{NS}, ::Elements, vs) where {NS} = NamedTuple{NS}(NTuple{length(NS)}(vs))
68+
setall(obj::NTuple{N, Any}, ::Elements, vs) where {N} = (@assert length(vs) == N; NTuple{N}(vs))
69+
setall(obj::AbstractArray, ::Elements, vs::AbstractArray) = (@assert length(obj) == length(vs); reshape(vs, size(obj)))
70+
setall(obj::AbstractArray, ::Elements, vs) = setall(obj, Elements(), collect(vs))
71+
setall(obj, o::If, vs) = error("Not supported")
72+
setall(obj, o, vs) = set(obj, o, only(vs))
73+
74+
75+
# implementations for composite optics
76+
77+
# A straightforward recursive approach doesn't actually infer,
78+
# see https://github.com/JuliaObjects/Accessors.jl/pull/64 and https://github.com/JuliaObjects/Accessors.jl/pull/68.
79+
# Instead, we need to generate separate functions for each recursion level.
3280

33-
# A recursive implementation of getall doesn't actually infer,
34-
# see https://github.com/JuliaObjects/Accessors.jl/pull/64.
35-
# Instead, we need to generate unrolled code explicitly.
3681
function getall(obj, optic::ComposedFunction)
3782
N = length(decompose(optic))
38-
_GetAll{N}()(obj, optic)
83+
_getall(obj, optic, Val(N))
84+
end
85+
86+
function setall(obj, optic::ComposedFunction, vs)
87+
N = length(decompose(optic))
88+
vss = to_nested_shape(vs, Val(getall_lengths(obj, optic, Val(N))), Val(N))
89+
_setall(obj, optic, vss, Val(N))
90+
end
91+
92+
93+
# _getall: the actual workhorse for getall
94+
_getall(_, _, ::Val{N}) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/64.")
95+
_getall(obj, optic, ::Val{1}) = getall(obj, optic)
96+
for i in 2:10
97+
@eval function _getall(obj, optic, ::Val{$i})
98+
_reduce_concat(
99+
map(getall(obj, optic.inner)) do obj
100+
_getall(obj, optic.outer, Val($(i-1)))
101+
end
102+
)
103+
end
104+
end
105+
106+
# _setall: the actual workhorse for setall
107+
# takes values as a nested tuple with proper leaf lengths, prepared in setall above
108+
_setall(_, _, _, ::Val{N}) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/68.")
109+
_setall(obj, optic, vs, ::Val{1}) = setall(obj, optic, vs)
110+
for i in 2:10
111+
@eval function _setall(obj, optic, vs, ::Val{$i})
112+
setall(obj, optic.inner, map(getall(obj, optic.inner), vs) do obj, vss
113+
_setall(obj, optic.outer, vss, Val($(i - 1)))
114+
end)
115+
end
39116
end
40117

41-
struct _GetAll{N} end
42-
(::_GetAll{N})(_) where {N} = error("Too many chained optics: $N is not supported for now. See also https://github.com/JuliaObjects/Accessors.jl/pull/64.")
118+
119+
# helper functions
43120

44121
_concat(a::Tuple, b::Tuple) = (a..., b...)
45122
_concat(a::Tuple, b::AbstractVector) = vcat(collect(a), b)
@@ -51,26 +128,48 @@ _reduce_concat(xs::AbstractVector) = reduce(append!, xs; init=eltype(eltype(xs))
51128
_reduce_concat(xs::Tuple{AbstractVector, Vararg{AbstractVector}}) = reduce(vcat, xs)
52129
_reduce_concat(xs::AbstractVector{<:AbstractVector}) = reduce(vcat, xs)
53130

54-
function _generate_getall(N::Int)
55-
syms = [Symbol(:f_, i) for i in 1:N]
56-
57-
expr = :( getall(obj, $(syms[end])) )
58-
for s in syms[1:end - 1] |> reverse
59-
expr = :(
60-
_reduce_concat(
61-
map(getall(obj, $(s))) do obj
62-
$expr
63-
end
64-
)
65-
)
66-
end
131+
_staticlength(::NTuple{N, <:Any}) where {N} = Val(N)
132+
_staticlength(x::AbstractVector) = length(x)
67133

68-
:(function (::_GetAll{$N})(obj, optic)
69-
($(syms...),) = deopcompose(optic)
70-
$expr
71-
end)
134+
getall_lengths(obj, optic, ::Val{1}) = _staticlength(getall(obj, optic))
135+
for i in 2:10
136+
@eval function getall_lengths(obj, optic, ::Val{$i})
137+
# convert to Tuple: vectors cannot be put into Val
138+
map(getall(obj, optic.inner) |> Tuple) do o
139+
getall_lengths(o, optic.outer, Val($(i - 1)))
140+
end
141+
end
72142
end
73143

144+
_val(N::Int) = N
145+
_val(::Val{N}) where {N} = N
146+
147+
nestedsum(ls::Int) = ls
148+
nestedsum(ls::Val) = ls
149+
nestedsum(ls::Tuple) = sum(_val nestedsum, ls)
150+
151+
# to_nested_shape() definition uses both @eval and @generated
152+
#
153+
# @eval is needed because the code for different recursion depths should be different for inference,
154+
# not the same method with different parameters.
155+
#
156+
# @generated is used to unpack target lengths from the second argument at compile time to make to_nested_shape() as cheap as possible.
157+
#
158+
# Note: to_nested_shape() only operates on plain Julia types and won't be affected by user lens definition, unlike setall for example.
159+
# That's why it's safe to make it @generated.
160+
to_nested_shape(vs, ::Val{LS}, ::Val{1}) where {LS} = (@assert length(vs) == _val(LS); vs)
74161
for i in 2:10
75-
eval(_generate_getall(i))
162+
@eval @generated function to_nested_shape(vs, ls::Val{LS}, ::Val{$i}) where {LS}
163+
vi = 1
164+
subs = map(LS) do lss
165+
n = nestedsum(lss)
166+
elems = map(vi:vi+_val(n)-1) do j
167+
:( vs[$j] )
168+
end
169+
res = :( to_nested_shape(($(elems...),), $(Val(lss)), $(Val($(i - 1)))) )
170+
vi += _val(n)
171+
res
172+
end
173+
:( ($(subs...),) )
174+
end
76175
end

src/optics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
export @optic
22
export PropertyLens, IndexLens
3-
export set, modify, delete, insert, getall
3+
export set, modify, delete, insert, getall, setall
44
export , opcompose, var"⨟"
55
export Elements, Recursive, If, Properties
66
export setproperties

src/testing.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,19 @@ function test_modify_law(f, lens, obj)
2222
obj_setfget = set(obj, lens, val)
2323
@test obj_modify == obj_setfget
2424
end
25+
26+
function test_getsetall_laws(optic, obj, vals1, vals2; cmp=(==))
27+
28+
# setall ⨟ getall
29+
vals = getall(obj, optic)
30+
@test cmp(setall(obj, optic, vals), obj)
31+
32+
# getall ⨟ setall
33+
obj1 = setall(obj, optic, vals1)
34+
@test cmp(collect(getall(obj1, optic)), collect(vals1))
35+
36+
# setall idempotent
37+
obj12 = setall(obj1, optic, vals2)
38+
obj2 = setall(obj12, optic, vals2)
39+
@test obj12 == obj2
40+
end

test/test_getsetall.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,49 @@ if VERSION >= v"1.6" # for ComposedFunction
6464
obj = ([1, 2], [:a, :b])
6565
@test [1, 2, :a, :b] == @inferred getall(obj, @optic _ |> Elements() |> Elements())
6666
end
67+
68+
@testset "setall" begin
69+
for o in [Elements(), Properties()]
70+
@test (a=2, b=3) === @inferred setall((a=1, b="2"), o, (2, 3))
71+
@test (a=2, b="3") === @inferred setall((a=1, b="2"), o, (2, "3"))
72+
@test (a=2, b=3) === @inferred setall((a=1, b="2"), o, [2, 3])
73+
end
74+
@test (2, 3) === @inferred setall((1, "2"), Elements(), (2, 3))
75+
@test (2, "3") === @inferred setall((1, "2"), Elements(), (2, "3"))
76+
@test (2, 3) === @inferred setall((1, "2"), Elements(), [2, 3])
77+
@test [2, 3] == @inferred setall([1, "2"], Elements(), (2, 3))
78+
@test [2, "3"] == @inferred setall([1, "2"], Elements(), (2, "3"))
79+
@test [2, 3] == @inferred setall([1, "2"], Elements(), [2, 3])
80+
81+
obj = (a=1, b=2.0, c='3')
82+
@test (a="aa", b=2.0, c='3') === @inferred setall(obj, @optic(_.a), ("aa",))
83+
@test (a=9, b=19.0, c='4') === @inferred setall(obj, @optic(_ |> Elements() |> _ + 1), (10, 20.0, '5'))
84+
85+
obj = (a=1, b=((c=3, d=4), (c=5, d=6)))
86+
@test (a=1, b=(:x, :y)) === @inferred setall(obj, @optic(_.b |> Elements()), (:x, :y))
87+
@test (a=1, b=((c=:x, d=4), (c=:y, d=6))) === @inferred setall(obj, @optic(_.b |> Elements() |> _.c), (:x, :y))
88+
@test (a=1, b=((c=:x, d="y"), (c=:z, d=10))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties()), (:x, "y", :z, 10))
89+
@test (a=1, b=((c=-3., d=-4.), (c=-5., d=-6.))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties() |> _ * 3), (-9, -12, -15, -18))
90+
@test (a=1, b=((c=-3., d=-4.), (c=-5., d=-6.))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties() |> _ * 3), [-9, -12, -15, -18])
91+
92+
obj = ([1, 2], 3:5, (6,))
93+
@test obj == setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
94+
@test ([2, 3], 4:6, (7,)) == setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)
95+
# can this infer?..
96+
@test_broken obj == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
97+
@test_broken ([2, 3], 4:6, (7,)) == @inferred setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)
98+
end
99+
100+
@testset "getall/setall consistency" begin
101+
for (optic, obj, vals1, vals2) in [
102+
(Elements(), (1, "2"), (2, 3), (4, 5)),
103+
(Properties(), (a=1, b="2"), (2, 3), (4, 5)),
104+
(@optic(_.b |> Elements() |> Properties() |> _ * 3), (a=1, b=((c=3, d=4), (c=5, d=6))), 1:4, (-9, -12, -15, -18)),
105+
]
106+
Accessors.test_getsetall_laws(optic, obj, vals1, vals2)
107+
end
108+
end
109+
67110
end
68111

69112
end

0 commit comments

Comments
 (0)