@@ -44,10 +44,10 @@ function getall(obj, optic::ComposedFunction)
44
44
_GetAll {N} ()(obj, optic)
45
45
end
46
46
47
- # function setall(obj, optic::ComposedFunction, vs::Tuple)
48
- # N = length(decompose(optic))
49
- # _SetAll{N}()(obj, optic, vs)
50
- # end
47
+ function setall (obj, optic:: ComposedFunction , vs:: Tuple )
48
+ N = length (decompose (optic))
49
+ _SetAll {N} ()(obj, optic, vs)
50
+ end
51
51
52
52
53
53
struct _GetAll{N} end
@@ -98,20 +98,30 @@ _val(::Type{Val{N}}) where {N} = N
98
98
_staticsum (f, x) = sum (_val ∘ f, x) |> Val
99
99
100
100
getall_lengths (obj, optic) = _staticlength (getall (obj, optic))
101
- getall_lengths (obj, optic:: ComposedFunction ) = map (getall (obj, optic. inner)) do o
102
- getall_lengths (o, optic. outer)
101
+ function getall_lengths (obj, optic:: ComposedFunction , :: Val{2} )
102
+ map (getall (obj, optic. inner)) do o
103
+ getall_lengths (o, optic. outer)
104
+ end
105
+ end
106
+ function getall_lengths (obj, optic:: ComposedFunction , :: Val{3} )
107
+ map (getall (obj, optic. inner)) do o
108
+ getall_lengths (o, optic. outer, Val (2 ))
109
+ end
110
+ end
111
+ function getall_lengths (obj, optic:: ComposedFunction , :: Val{4} )
112
+ map (getall (obj, optic. inner)) do o
113
+ getall_lengths (o, optic. outer, Val (3 ))
114
+ end
103
115
end
104
116
105
117
106
- # nestedsum(ls::Val) = ls
107
- # nestedsum(ls) = _staticsum(nestedsum, ls)
108
118
nestedsum (ls:: Type{L} ) where {L <: Val } = L
109
119
nestedsum (ls:: Type{LS} ) where {LS <: Tuple } = _staticsum (nestedsum, LS. parameters)
110
120
111
121
112
122
to_nested_shape (vs, ls:: Type{LS} ) where {LS <: Val } = (@assert length (vs) == _val (LS); vs)
113
- to_nested_shape (vs, ls:: LS ) where {LS <: Tuple } = to_nested_shape (vs, typeof (ls))
114
- @generated function to_nested_shape (vs, ls:: Type{LS} ) where {LS <: Tuple }
123
+ to_nested_shape (vs, ls:: LS , VN ) where {LS <: Tuple } = to_nested_shape (vs, typeof (ls), VN )
124
+ @generated function to_nested_shape (vs, ls:: Type{LS} , :: Val{2} ) where {LS <: Tuple }
115
125
i = 1
116
126
subs = map (LS. parameters) do lss
117
127
n = nestedsum (lss)
@@ -124,69 +134,55 @@ to_nested_shape(vs, ls::LS) where {LS <: Tuple} = to_nested_shape(vs, typeof(ls)
124
134
end
125
135
:( ($ (subs... ),) )
126
136
end
127
-
128
-
129
- function setall (obj, optic:: ComposedFunction , vs:: Tuple )
130
- vss = to_nested_shape (vs, getall_lengths (obj, optic))
131
- # @info "" vs getall_lengths(obj, optic) vss
132
- _setall (obj, optic. inner, vss)
137
+ @generated function to_nested_shape (vs, ls:: Type{LS} , :: Val{3} ) where {LS <: Tuple }
138
+ i = 1
139
+ subs = map (LS. parameters) do lss
140
+ n = nestedsum (lss)
141
+ elems = map (i: i+ _val (n)- 1 ) do j
142
+ :( vs[$ j] )
143
+ end
144
+ res = :( to_nested_shape (($ (elems... ),), $ lss, Val (2 )) )
145
+ i = i + _val (n)
146
+ res
147
+ end
148
+ :( ($ (subs... ),) )
149
+ end
150
+ @generated function to_nested_shape (vs, ls:: Type{LS} , :: Val{4} ) where {LS <: Tuple }
151
+ i = 1
152
+ subs = map (LS. parameters) do lss
153
+ n = nestedsum (lss)
154
+ elems = map (i: i+ _val (n)- 1 ) do j
155
+ :( vs[$ j] )
156
+ end
157
+ res = :( to_nested_shape (($ (elems... ),), $ lss, Val (3 )) )
158
+ i = i + _val (n)
159
+ res
160
+ end
161
+ :( ($ (subs... ),) )
133
162
end
134
163
164
+
135
165
_setall (obj, optic, vs) = setall (obj, optic, vs)
136
- _setall (obj, optic:: ComposedFunction , vs) =
137
- setall (obj, optic. inner, map (getall (obj, optic. inner), vs) do obj, vsss
138
- # @info "" obj vsss setall(obj, optic.outer, vsss)
139
- _setall (obj, optic. outer, vsss)
166
+ _setall (obj, optic:: ComposedFunction , vs, :: Val{2} ) =
167
+ setall (obj, optic. inner, map (getall (obj, optic. inner), vs) do obj, vss
168
+ _setall (obj, optic. outer, vss)
140
169
end )
141
-
142
-
143
- struct _SetAll{N} end
144
-
145
- function (:: _SetAll{2} )(obj, optic, vs)
146
- vss = to_nested_shape (vs, getall_lengths (obj, optic))
147
- # @info "" vs getall_lengths(obj, optic) vss
148
- setall (obj, optic. inner, map (getall (obj, optic. inner), vss) do obj, vsss
149
- # @info "" obj vsss setall(obj, optic.outer, vsss)
150
- setall (obj, optic. outer, vsss)
170
+ _setall (obj, optic:: ComposedFunction , vs, :: Val{3} ) =
171
+ setall (obj, optic. inner, map (getall (obj, optic. inner), vs) do obj, vss
172
+ _setall (obj, optic. outer, vss, Val (2 ))
173
+ end )
174
+ _setall (obj, optic:: ComposedFunction , vs, :: Val{4} ) =
175
+ setall (obj, optic. inner, map (getall (obj, optic. inner), vs) do obj, vss
176
+ _setall (obj, optic. outer, vss, Val (3 ))
151
177
end )
152
- end
153
178
154
- function (:: _SetAll{3} )(obj, optic, vs)
155
- (f1, f2, f3) = deopcompose (optic)
156
179
157
- ins_old_1 = getall (obj, f1)
158
- vss_1 = _split_getall (ins_old_1, f3 ∘ f2, vs)
159
- ins_1 = map (ins_old_1, vss_1) do o, vs_cur
160
- ins_old_2 = getall (o, f2)
161
- vss_2 = _split_getall (ins_old_2, f3, vs_cur)
162
- ins_2 = map (ins_old_2, vss_2) do o, vs_cur
163
- setall (o, f3, vs_cur)
164
- end
165
- setall (o, f2, ins_2)
166
- end
167
- setall (obj, f1, ins_1)
168
- end
180
+ struct _SetAll{N} end
169
181
170
- # only review SetAll{4} and helpers below
171
- function (:: _SetAll{4} )(o, optic, vs)
172
- (f1, f2, f3, f4) = deopcompose (optic)
173
-
174
- ins_old_1 = getall (o, f1)
175
- vss_1 = _split_getall (ins_old_1, f4 ∘ f3 ∘ f2, vs)
176
- ins_1 = map (ins_old_1, vss_1) do o, vs_cur
177
- ins_old_2 = getall (o, f2)
178
- vss_2 = _split_getall (ins_old_2, f4 ∘ f3, vs_cur)
179
- ins_2 = map (ins_old_2, vss_2) do o, vs_cur
180
- ins_old_3 = getall (o, f3)
181
- vss_3 = _split_getall (ins_old_3, f4, vs_cur)
182
- ins_3 = map (ins_old_3, vss_3) do o, vs_cur
183
- setall (o, f4, vs_cur)
184
- end
185
- setall (o, f3, ins_3)
186
- end
187
- setall (o, f2, ins_2)
188
- end
189
- setall (o, f1, ins_1)
182
+ function (:: _SetAll{N} )(obj, optic, vs) where {N}
183
+ vss = to_nested_shape (vs, getall_lengths (obj, optic, Val (N)), Val (N))
184
+ # @info "" vs getall_lengths(obj, optic, Val(N)) vss
185
+ _setall (obj, optic, vss, Val (N))
190
186
end
191
187
192
188
# split a into two parts: b-sized front and remaining
0 commit comments