@@ -44,10 +44,10 @@ function getall(obj, optic::ComposedFunction)
44
44
_GetAll {N} ()(obj, optic)
45
45
end
46
46
47
- function setall (obj, optic:: ComposedFunction , vs:: Tuple )
48
- N = length (decompose (optic))
49
- _SetAll {N} ()(obj, optic, vs)
50
- end
47
+ # function setall(obj, optic::ComposedFunction, vs::Tuple)
48
+ # N = length(decompose(optic))
49
+ # _SetAll{N}()(obj, optic, vs)
50
+ # end
51
51
52
52
53
53
struct _GetAll{N} end
@@ -88,22 +88,67 @@ for i in 2:10
88
88
end
89
89
90
90
91
+ _staticlength (:: Number ) = Val (1 )
92
+ _staticlength (:: NTuple{N, <:Any} ) where {N} = Val (N)
93
+ # _staticlength(x::Vector) = length(x)
94
+
95
+ _val (:: Val{N} ) where {N} = N
96
+ _val (:: Type{Val{N}} ) where {N} = N
97
+
98
+ _staticsum (f, x) = sum (_val ∘ f, x) |> Val
99
+
100
+ getall_lengths (obj, optic) = _staticlength (getall (obj, optic))
101
+ getall_lengths (obj, optic:: ComposedFunction ) = map (getall (obj, optic. inner)) do o
102
+ getall_lengths (o, optic. outer)
103
+ end
104
+
105
+
106
+ # nestedsum(ls::Val) = ls
107
+ # nestedsum(ls) = _staticsum(nestedsum, ls)
108
+ nestedsum (ls:: Type{L} ) where {L <: Val } = L
109
+ nestedsum (ls:: Type{LS} ) where {LS <: Tuple } = _staticsum (nestedsum, LS. parameters)
110
+
111
+
112
+ to_nested_shape (vs, ls:: Type{LS} ) where {LS <: Val } = (@assert length (vs) == _val (LS); vs)
113
+ to_nested_shape (vs, ls:: LS ) where {LS <: Tuple } = to_nested_shape (vs, typeof (ls))
114
+ @generated function to_nested_shape (vs, ls:: Type{LS} ) where {LS <: Tuple }
115
+ i = 1
116
+ subs = map (LS. parameters) do lss
117
+ n = nestedsum (lss)
118
+ elems = map (i: i+ _val (n)- 1 ) do j
119
+ :( vs[$ j] )
120
+ end
121
+ res = :( to_nested_shape (($ (elems... ),), $ lss) )
122
+ i = i + _val (n)
123
+ res
124
+ end
125
+ :( ($ (subs... ),) )
126
+ end
127
+
128
+
129
+ function setall (obj, optic:: ComposedFunction , vs:: Tuple )
130
+ vss = to_nested_shape (vs, getall_lengths (obj, optic))
131
+ # @info "" vs getall_lengths(obj, optic) vss
132
+ _setall (obj, optic. inner, vss)
133
+ end
134
+
135
+ _setall (obj, optic, vs) = setall (obj, optic, vs)
136
+ _setall (obj, optic:: ComposedFunction , vs) =
137
+ setall (obj, optic. inner, map (getall (obj, optic. inner), vs) do obj, vsss
138
+ # @info "" obj vsss setall(obj, optic.outer, vsss)
139
+ _setall (obj, optic. outer, vsss)
140
+ end )
141
+
91
142
92
143
struct _SetAll{N} end
93
144
94
- # don't review SetAll{2} and {3}
95
145
function (:: _SetAll{2} )(obj, optic, vs)
96
- (f1, f2) = deopcompose (optic)
97
-
98
- ins_old = getall (obj, f1)
99
- vss = reduce (ins_old; init= ((), vs)) do (acc, vs), o
100
- vs_cur, vs_rest = _split_n (vs, getall (o, f2))
101
- (acc... , vs_cur), vs_rest
102
- end |> first
103
- ins = map (ins_old, vss) do o, vs_cur
104
- setall (o, f2, vs_cur)
105
- end
106
- setall (obj, f1, ins)
146
+ vss = to_nested_shape (vs, getall_lengths (obj, optic))
147
+ # @info "" vs getall_lengths(obj, optic) vss
148
+ setall (obj, optic. inner, map (getall (obj, optic. inner), vss) do obj, vsss
149
+ # @info "" obj vsss setall(obj, optic.outer, vsss)
150
+ setall (obj, optic. outer, vsss)
151
+ end )
107
152
end
108
153
109
154
function (:: _SetAll{3} )(obj, optic, vs)
0 commit comments