Skip to content

Commit aee5552

Browse files
committed
straightforward solution
1 parent 2b061d8 commit aee5552

File tree

1 file changed

+61
-16
lines changed

1 file changed

+61
-16
lines changed

src/getsetall.jl

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ function getall(obj, optic::ComposedFunction)
4444
_GetAll{N}()(obj, optic)
4545
end
4646

47-
function setall(obj, optic::ComposedFunction, vs::Tuple)
48-
N = length(decompose(optic))
49-
_SetAll{N}()(obj, optic, vs)
50-
end
47+
# function setall(obj, optic::ComposedFunction, vs::Tuple)
48+
# N = length(decompose(optic))
49+
# _SetAll{N}()(obj, optic, vs)
50+
# end
5151

5252

5353
struct _GetAll{N} end
@@ -88,22 +88,67 @@ for i in 2:10
8888
end
8989

9090

91+
_staticlength(::Number) = Val(1)
92+
_staticlength(::NTuple{N, <:Any}) where {N} = Val(N)
93+
# _staticlength(x::Vector) = length(x)
94+
95+
_val(::Val{N}) where {N} = N
96+
_val(::Type{Val{N}}) where {N} = N
97+
98+
_staticsum(f, x) = sum(_val f, x) |> Val
99+
100+
getall_lengths(obj, optic) = _staticlength(getall(obj, optic))
101+
getall_lengths(obj, optic::ComposedFunction) = map(getall(obj, optic.inner)) do o
102+
getall_lengths(o, optic.outer)
103+
end
104+
105+
106+
# nestedsum(ls::Val) = ls
107+
# nestedsum(ls) = _staticsum(nestedsum, ls)
108+
nestedsum(ls::Type{L}) where {L <: Val} = L
109+
nestedsum(ls::Type{LS}) where {LS <: Tuple} = _staticsum(nestedsum, LS.parameters)
110+
111+
112+
to_nested_shape(vs, ls::Type{LS}) where {LS <: Val} = (@assert length(vs) == _val(LS); vs)
113+
to_nested_shape(vs, ls::LS) where {LS <: Tuple} = to_nested_shape(vs, typeof(ls))
114+
@generated function to_nested_shape(vs, ls::Type{LS}) where {LS <: Tuple}
115+
i = 1
116+
subs = map(LS.parameters) do lss
117+
n = nestedsum(lss)
118+
elems = map(i:i+_val(n)-1) do j
119+
:( vs[$j] )
120+
end
121+
res = :( to_nested_shape(($(elems...),), $lss) )
122+
i = i + _val(n)
123+
res
124+
end
125+
:( ($(subs...),) )
126+
end
127+
128+
129+
function setall(obj, optic::ComposedFunction, vs::Tuple)
130+
vss = to_nested_shape(vs, getall_lengths(obj, optic))
131+
# @info "" vs getall_lengths(obj, optic) vss
132+
_setall(obj, optic.inner, vss)
133+
end
134+
135+
_setall(obj, optic, vs) = setall(obj, optic, vs)
136+
_setall(obj, optic::ComposedFunction, vs) =
137+
setall(obj, optic.inner, map(getall(obj, optic.inner), vs) do obj, vsss
138+
# @info "" obj vsss setall(obj, optic.outer, vsss)
139+
_setall(obj, optic.outer, vsss)
140+
end)
141+
91142

92143
struct _SetAll{N} end
93144

94-
# don't review SetAll{2} and {3}
95145
function (::_SetAll{2})(obj, optic, vs)
96-
(f1, f2) = deopcompose(optic)
97-
98-
ins_old = getall(obj, f1)
99-
vss = reduce(ins_old; init=((), vs)) do (acc, vs), o
100-
vs_cur, vs_rest = _split_n(vs, getall(o, f2))
101-
(acc..., vs_cur), vs_rest
102-
end |> first
103-
ins = map(ins_old, vss) do o, vs_cur
104-
setall(o, f2, vs_cur)
105-
end
106-
setall(obj, f1, ins)
146+
vss = to_nested_shape(vs, getall_lengths(obj, optic))
147+
# @info "" vs getall_lengths(obj, optic) vss
148+
setall(obj, optic.inner, map(getall(obj, optic.inner), vss) do obj, vsss
149+
# @info "" obj vsss setall(obj, optic.outer, vsss)
150+
setall(obj, optic.outer, vsss)
151+
end)
107152
end
108153

109154
function (::_SetAll{3})(obj, optic, vs)

0 commit comments

Comments
 (0)