Skip to content

Commit 165fcc8

Browse files
committed
generate recursion
1 parent 8de8728 commit 165fcc8

File tree

1 file changed

+30
-84
lines changed

1 file changed

+30
-84
lines changed

src/getsetall.jl

Lines changed: 30 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ function getall(obj, optic::ComposedFunction)
4444
_GetAll{N}()(obj, optic)
4545
end
4646

47-
function setall(obj, optic::ComposedFunction, vs::Tuple)
47+
function setall(obj, optic::ComposedFunction, vs)
4848
N = length(decompose(optic))
49-
_SetAll{N}()(obj, optic, vs)
49+
vss = to_nested_shape(vs, typeof(getall_lengths(obj, optic, Val(N))), Val(N))
50+
_setall(obj, optic, vss, Val(N))
5051
end
5152

5253

@@ -90,27 +91,19 @@ end
9091

9192
_staticlength(::Number) = Val(1)
9293
_staticlength(::NTuple{N, <:Any}) where {N} = Val(N)
93-
# _staticlength(x::Vector) = length(x)
94+
# _staticlength(x::AbstractVector) = length(x)
9495

9596
_val(::Val{N}) where {N} = N
9697
_val(::Type{Val{N}}) where {N} = N
9798

9899
_staticsum(f, x) = sum(_val f, x) |> Val
99100

100-
getall_lengths(obj, optic) = _staticlength(getall(obj, optic))
101-
function getall_lengths(obj, optic::ComposedFunction, ::Val{2})
102-
map(getall(obj, optic.inner)) do o
103-
getall_lengths(o, optic.outer)
104-
end
105-
end
106-
function getall_lengths(obj, optic::ComposedFunction, ::Val{3})
107-
map(getall(obj, optic.inner)) do o
108-
getall_lengths(o, optic.outer, Val(2))
109-
end
110-
end
111-
function getall_lengths(obj, optic::ComposedFunction, ::Val{4})
112-
map(getall(obj, optic.inner)) do o
113-
getall_lengths(o, optic.outer, Val(3))
101+
getall_lengths(obj, optic, ::Val{1}) = _staticlength(getall(obj, optic))
102+
for i in 2:10
103+
@eval function getall_lengths(obj, optic::ComposedFunction, ::Val{$i})
104+
map(getall(obj, optic.inner)) do o
105+
getall_lengths(o, optic.outer, Val($(i - 1)))
106+
end
114107
end
115108
end
116109

@@ -119,77 +112,30 @@ nestedsum(ls::Type{L}) where {L <: Val} = L
119112
nestedsum(ls::Type{LS}) where {LS <: Tuple} = _staticsum(nestedsum, LS.parameters)
120113

121114

122-
to_nested_shape(vs, ls::Type{LS}) where {LS <: Val} = (@assert length(vs) == _val(LS); vs)
123-
to_nested_shape(vs, ls::LS, VN) where {LS <: Tuple} = to_nested_shape(vs, typeof(ls), VN)
124-
@generated function to_nested_shape(vs, ls::Type{LS}, ::Val{2}) where {LS <: Tuple}
125-
i = 1
126-
subs = map(LS.parameters) do lss
127-
n = nestedsum(lss)
128-
elems = map(i:i+_val(n)-1) do j
129-
:( vs[$j] )
130-
end
131-
res = :( to_nested_shape(($(elems...),), $lss) )
132-
i = i + _val(n)
133-
res
134-
end
135-
:( ($(subs...),) )
136-
end
137-
@generated function to_nested_shape(vs, ls::Type{LS}, ::Val{3}) where {LS <: Tuple}
138-
i = 1
139-
subs = map(LS.parameters) do lss
140-
n = nestedsum(lss)
141-
elems = map(i:i+_val(n)-1) do j
142-
:( vs[$j] )
143-
end
144-
res = :( to_nested_shape(($(elems...),), $lss, Val(2)) )
145-
i = i + _val(n)
146-
res
147-
end
148-
:( ($(subs...),) )
149-
end
150-
@generated function to_nested_shape(vs, ls::Type{LS}, ::Val{4}) where {LS <: Tuple}
151-
i = 1
152-
subs = map(LS.parameters) do lss
153-
n = nestedsum(lss)
154-
elems = map(i:i+_val(n)-1) do j
155-
:( vs[$j] )
115+
to_nested_shape(vs, ls::Type{LS}, ::Val{1}) where {LS <: Val} = (@assert length(vs) == _val(LS); vs)
116+
for i in 2:10
117+
@eval @generated function to_nested_shape(vs, ls::Type{LS}, ::Val{$i}) where {LS <: Tuple}
118+
vi = 1
119+
subs = map(LS.parameters) do lss
120+
n = nestedsum(lss)
121+
elems = map(vi:vi+_val(n)-1) do j
122+
:( vs[$j] )
123+
end
124+
res = :( to_nested_shape(($(elems...),), $lss, $(Val($(i - 1)))) )
125+
vi = vi + _val(n)
126+
res
156127
end
157-
res = :( to_nested_shape(($(elems...),), $lss, Val(3)) )
158-
i = i + _val(n)
159-
res
128+
:( ($(subs...),) )
160129
end
161-
:( ($(subs...),) )
162130
end
163131

164132

165-
_setall(obj, optic, vs) = setall(obj, optic, vs)
166-
_setall(obj, optic::ComposedFunction, vs, ::Val{2}) =
167-
setall(obj, optic.inner, map(getall(obj, optic.inner), vs) do obj, vss
168-
_setall(obj, optic.outer, vss)
169-
end)
170-
_setall(obj, optic::ComposedFunction, vs, ::Val{3}) =
171-
setall(obj, optic.inner, map(getall(obj, optic.inner), vs) do obj, vss
172-
_setall(obj, optic.outer, vss, Val(2))
173-
end)
174-
_setall(obj, optic::ComposedFunction, vs, ::Val{4}) =
175-
setall(obj, optic.inner, map(getall(obj, optic.inner), vs) do obj, vss
176-
_setall(obj, optic.outer, vss, Val(3))
177-
end)
178-
179-
180-
struct _SetAll{N} end
181-
182-
function (::_SetAll{N})(obj, optic, vs) where {N}
183-
vss = to_nested_shape(vs, getall_lengths(obj, optic, Val(N)), Val(N))
184-
# @info "" vs getall_lengths(obj, optic, Val(N)) vss
185-
_setall(obj, optic, vss, Val(N))
133+
_setall(obj, optic, vs, ::Val{1}) = setall(obj, optic, vs)
134+
for i in 2:10
135+
@eval function _setall(obj, optic::ComposedFunction, vs, ::Val{$i})
136+
setall(obj, optic.inner, map(getall(obj, optic.inner), vs) do obj, vss
137+
_setall(obj, optic.outer, vss, Val($(i - 1)))
138+
end)
139+
end
186140
end
187141

188-
# split a into two parts: b-sized front and remaining
189-
_split_n(a::NTuple{Na, Any}, b::NTuple{Nb, Any}) where {Na, Nb} = ntuple(i -> a[i], Nb), ntuple(i -> a[Nb + i], Na - Nb)
190-
191-
# split vs into parts sized according to getall(o, f)
192-
_split_getall(ins_old, f, vs) = foldl(ins_old; init=((), vs)) do (acc, vs_), o
193-
vs_cur, vs_rest = _split_n(vs_, getall(o, f))
194-
(acc..., vs_cur), vs_rest
195-
end |> first

0 commit comments

Comments
 (0)