Skip to content

Commit 8de8728

Browse files
committed
manual recursion
1 parent aee5552 commit 8de8728

File tree

1 file changed

+61
-65
lines changed

1 file changed

+61
-65
lines changed

src/getsetall.jl

Lines changed: 61 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ function getall(obj, optic::ComposedFunction)
4444
_GetAll{N}()(obj, optic)
4545
end
4646

47-
# function setall(obj, optic::ComposedFunction, vs::Tuple)
48-
# N = length(decompose(optic))
49-
# _SetAll{N}()(obj, optic, vs)
50-
# end
47+
function setall(obj, optic::ComposedFunction, vs::Tuple)
48+
N = length(decompose(optic))
49+
_SetAll{N}()(obj, optic, vs)
50+
end
5151

5252

5353
struct _GetAll{N} end
@@ -98,20 +98,30 @@ _val(::Type{Val{N}}) where {N} = N
9898
_staticsum(f, x) = sum(_val f, x) |> Val
9999

100100
getall_lengths(obj, optic) = _staticlength(getall(obj, optic))
101-
getall_lengths(obj, optic::ComposedFunction) = map(getall(obj, optic.inner)) do o
102-
getall_lengths(o, optic.outer)
101+
function getall_lengths(obj, optic::ComposedFunction, ::Val{2})
102+
map(getall(obj, optic.inner)) do o
103+
getall_lengths(o, optic.outer)
104+
end
105+
end
106+
function getall_lengths(obj, optic::ComposedFunction, ::Val{3})
107+
map(getall(obj, optic.inner)) do o
108+
getall_lengths(o, optic.outer, Val(2))
109+
end
110+
end
111+
function getall_lengths(obj, optic::ComposedFunction, ::Val{4})
112+
map(getall(obj, optic.inner)) do o
113+
getall_lengths(o, optic.outer, Val(3))
114+
end
103115
end
104116

105117

106-
# nestedsum(ls::Val) = ls
107-
# nestedsum(ls) = _staticsum(nestedsum, ls)
108118
nestedsum(ls::Type{L}) where {L <: Val} = L
109119
nestedsum(ls::Type{LS}) where {LS <: Tuple} = _staticsum(nestedsum, LS.parameters)
110120

111121

112122
to_nested_shape(vs, ls::Type{LS}) where {LS <: Val} = (@assert length(vs) == _val(LS); vs)
113-
to_nested_shape(vs, ls::LS) where {LS <: Tuple} = to_nested_shape(vs, typeof(ls))
114-
@generated function to_nested_shape(vs, ls::Type{LS}) where {LS <: Tuple}
123+
to_nested_shape(vs, ls::LS, VN) where {LS <: Tuple} = to_nested_shape(vs, typeof(ls), VN)
124+
@generated function to_nested_shape(vs, ls::Type{LS}, ::Val{2}) where {LS <: Tuple}
115125
i = 1
116126
subs = map(LS.parameters) do lss
117127
n = nestedsum(lss)
@@ -124,69 +134,55 @@ to_nested_shape(vs, ls::LS) where {LS <: Tuple} = to_nested_shape(vs, typeof(ls)
124134
end
125135
:( ($(subs...),) )
126136
end
127-
128-
129-
function setall(obj, optic::ComposedFunction, vs::Tuple)
130-
vss = to_nested_shape(vs, getall_lengths(obj, optic))
131-
# @info "" vs getall_lengths(obj, optic) vss
132-
_setall(obj, optic.inner, vss)
137+
@generated function to_nested_shape(vs, ls::Type{LS}, ::Val{3}) where {LS <: Tuple}
138+
i = 1
139+
subs = map(LS.parameters) do lss
140+
n = nestedsum(lss)
141+
elems = map(i:i+_val(n)-1) do j
142+
:( vs[$j] )
143+
end
144+
res = :( to_nested_shape(($(elems...),), $lss, Val(2)) )
145+
i = i + _val(n)
146+
res
147+
end
148+
:( ($(subs...),) )
149+
end
150+
@generated function to_nested_shape(vs, ls::Type{LS}, ::Val{4}) where {LS <: Tuple}
151+
i = 1
152+
subs = map(LS.parameters) do lss
153+
n = nestedsum(lss)
154+
elems = map(i:i+_val(n)-1) do j
155+
:( vs[$j] )
156+
end
157+
res = :( to_nested_shape(($(elems...),), $lss, Val(3)) )
158+
i = i + _val(n)
159+
res
160+
end
161+
:( ($(subs...),) )
133162
end
134163

164+
135165
_setall(obj, optic, vs) = setall(obj, optic, vs)
136-
_setall(obj, optic::ComposedFunction, vs) =
137-
setall(obj, optic.inner, map(getall(obj, optic.inner), vs) do obj, vsss
138-
# @info "" obj vsss setall(obj, optic.outer, vsss)
139-
_setall(obj, optic.outer, vsss)
166+
_setall(obj, optic::ComposedFunction, vs, ::Val{2}) =
167+
setall(obj, optic.inner, map(getall(obj, optic.inner), vs) do obj, vss
168+
_setall(obj, optic.outer, vss)
140169
end)
141-
142-
143-
struct _SetAll{N} end
144-
145-
function (::_SetAll{2})(obj, optic, vs)
146-
vss = to_nested_shape(vs, getall_lengths(obj, optic))
147-
# @info "" vs getall_lengths(obj, optic) vss
148-
setall(obj, optic.inner, map(getall(obj, optic.inner), vss) do obj, vsss
149-
# @info "" obj vsss setall(obj, optic.outer, vsss)
150-
setall(obj, optic.outer, vsss)
170+
_setall(obj, optic::ComposedFunction, vs, ::Val{3}) =
171+
setall(obj, optic.inner, map(getall(obj, optic.inner), vs) do obj, vss
172+
_setall(obj, optic.outer, vss, Val(2))
173+
end)
174+
_setall(obj, optic::ComposedFunction, vs, ::Val{4}) =
175+
setall(obj, optic.inner, map(getall(obj, optic.inner), vs) do obj, vss
176+
_setall(obj, optic.outer, vss, Val(3))
151177
end)
152-
end
153178

154-
function (::_SetAll{3})(obj, optic, vs)
155-
(f1, f2, f3) = deopcompose(optic)
156179

157-
ins_old_1 = getall(obj, f1)
158-
vss_1 = _split_getall(ins_old_1, f3 f2, vs)
159-
ins_1 = map(ins_old_1, vss_1) do o, vs_cur
160-
ins_old_2 = getall(o, f2)
161-
vss_2 = _split_getall(ins_old_2, f3, vs_cur)
162-
ins_2 = map(ins_old_2, vss_2) do o, vs_cur
163-
setall(o, f3, vs_cur)
164-
end
165-
setall(o, f2, ins_2)
166-
end
167-
setall(obj, f1, ins_1)
168-
end
180+
struct _SetAll{N} end
169181

170-
# only review SetAll{4} and helpers below
171-
function (::_SetAll{4})(o, optic, vs)
172-
(f1, f2, f3, f4) = deopcompose(optic)
173-
174-
ins_old_1 = getall(o, f1)
175-
vss_1 = _split_getall(ins_old_1, f4 f3 f2, vs)
176-
ins_1 = map(ins_old_1, vss_1) do o, vs_cur
177-
ins_old_2 = getall(o, f2)
178-
vss_2 = _split_getall(ins_old_2, f4 f3, vs_cur)
179-
ins_2 = map(ins_old_2, vss_2) do o, vs_cur
180-
ins_old_3 = getall(o, f3)
181-
vss_3 = _split_getall(ins_old_3, f4, vs_cur)
182-
ins_3 = map(ins_old_3, vss_3) do o, vs_cur
183-
setall(o, f4, vs_cur)
184-
end
185-
setall(o, f3, ins_3)
186-
end
187-
setall(o, f2, ins_2)
188-
end
189-
setall(o, f1, ins_1)
182+
function (::_SetAll{N})(obj, optic, vs) where {N}
183+
vss = to_nested_shape(vs, getall_lengths(obj, optic, Val(N)), Val(N))
184+
# @info "" vs getall_lengths(obj, optic, Val(N)) vss
185+
_setall(obj, optic, vss, Val(N))
190186
end
191187

192188
# split a into two parts: b-sized front and remaining

0 commit comments

Comments
 (0)