@@ -44,9 +44,10 @@ function getall(obj, optic::ComposedFunction)
44
44
_GetAll {N} ()(obj, optic)
45
45
end
46
46
47
- function setall (obj, optic:: ComposedFunction , vs:: Tuple )
47
+ function setall (obj, optic:: ComposedFunction , vs)
48
48
N = length (decompose (optic))
49
- _SetAll {N} ()(obj, optic, vs)
49
+ vss = to_nested_shape (vs, typeof (getall_lengths (obj, optic, Val (N))), Val (N))
50
+ _setall (obj, optic, vss, Val (N))
50
51
end
51
52
52
53
90
91
91
92
_staticlength (:: Number ) = Val (1 )
92
93
_staticlength (:: NTuple{N, <:Any} ) where {N} = Val (N)
93
- # _staticlength(x::Vector ) = length(x)
94
+ # _staticlength(x::AbstractVector ) = length(x)
94
95
95
96
_val (:: Val{N} ) where {N} = N
96
97
_val (:: Type{Val{N}} ) where {N} = N
97
98
98
99
_staticsum (f, x) = sum (_val ∘ f, x) |> Val
99
100
100
- getall_lengths (obj, optic) = _staticlength (getall (obj, optic))
101
- function getall_lengths (obj, optic:: ComposedFunction , :: Val{2} )
102
- map (getall (obj, optic. inner)) do o
103
- getall_lengths (o, optic. outer)
104
- end
105
- end
106
- function getall_lengths (obj, optic:: ComposedFunction , :: Val{3} )
107
- map (getall (obj, optic. inner)) do o
108
- getall_lengths (o, optic. outer, Val (2 ))
109
- end
110
- end
111
- function getall_lengths (obj, optic:: ComposedFunction , :: Val{4} )
112
- map (getall (obj, optic. inner)) do o
113
- getall_lengths (o, optic. outer, Val (3 ))
101
+ getall_lengths (obj, optic, :: Val{1} ) = _staticlength (getall (obj, optic))
102
+ for i in 2 : 10
103
+ @eval function getall_lengths (obj, optic:: ComposedFunction , :: Val{$i} )
104
+ map (getall (obj, optic. inner)) do o
105
+ getall_lengths (o, optic. outer, Val ($ (i - 1 )))
106
+ end
114
107
end
115
108
end
116
109
@@ -119,77 +112,30 @@ nestedsum(ls::Type{L}) where {L <: Val} = L
119
112
nestedsum (ls:: Type{LS} ) where {LS <: Tuple } = _staticsum (nestedsum, LS. parameters)
120
113
121
114
122
- to_nested_shape (vs, ls:: Type{LS} ) where {LS <: Val } = (@assert length (vs) == _val (LS); vs)
123
- to_nested_shape (vs, ls:: LS , VN) where {LS <: Tuple } = to_nested_shape (vs, typeof (ls), VN)
124
- @generated function to_nested_shape (vs, ls:: Type{LS} , :: Val{2} ) where {LS <: Tuple }
125
- i = 1
126
- subs = map (LS. parameters) do lss
127
- n = nestedsum (lss)
128
- elems = map (i: i+ _val (n)- 1 ) do j
129
- :( vs[$ j] )
130
- end
131
- res = :( to_nested_shape (($ (elems... ),), $ lss) )
132
- i = i + _val (n)
133
- res
134
- end
135
- :( ($ (subs... ),) )
136
- end
137
- @generated function to_nested_shape (vs, ls:: Type{LS} , :: Val{3} ) where {LS <: Tuple }
138
- i = 1
139
- subs = map (LS. parameters) do lss
140
- n = nestedsum (lss)
141
- elems = map (i: i+ _val (n)- 1 ) do j
142
- :( vs[$ j] )
143
- end
144
- res = :( to_nested_shape (($ (elems... ),), $ lss, Val (2 )) )
145
- i = i + _val (n)
146
- res
147
- end
148
- :( ($ (subs... ),) )
149
- end
150
- @generated function to_nested_shape (vs, ls:: Type{LS} , :: Val{4} ) where {LS <: Tuple }
151
- i = 1
152
- subs = map (LS. parameters) do lss
153
- n = nestedsum (lss)
154
- elems = map (i: i+ _val (n)- 1 ) do j
155
- :( vs[$ j] )
115
+ to_nested_shape (vs, ls:: Type{LS} , :: Val{1} ) where {LS <: Val } = (@assert length (vs) == _val (LS); vs)
116
+ for i in 2 : 10
117
+ @eval @generated function to_nested_shape (vs, ls:: Type{LS} , :: Val{$i} ) where {LS <: Tuple }
118
+ vi = 1
119
+ subs = map (LS. parameters) do lss
120
+ n = nestedsum (lss)
121
+ elems = map (vi: vi+ _val (n)- 1 ) do j
122
+ :( vs[$ j] )
123
+ end
124
+ res = :( to_nested_shape (($ (elems... ),), $ lss, $ (Val ($ (i - 1 )))) )
125
+ vi = vi + _val (n)
126
+ res
156
127
end
157
- res = :( to_nested_shape (($ (elems... ),), $ lss, Val (3 )) )
158
- i = i + _val (n)
159
- res
128
+ :( ($ (subs... ),) )
160
129
end
161
- :( ($ (subs... ),) )
162
130
end
163
131
164
132
165
- _setall (obj, optic, vs) = setall (obj, optic, vs)
166
- _setall (obj, optic:: ComposedFunction , vs, :: Val{2} ) =
167
- setall (obj, optic. inner, map (getall (obj, optic. inner), vs) do obj, vss
168
- _setall (obj, optic. outer, vss)
169
- end )
170
- _setall (obj, optic:: ComposedFunction , vs, :: Val{3} ) =
171
- setall (obj, optic. inner, map (getall (obj, optic. inner), vs) do obj, vss
172
- _setall (obj, optic. outer, vss, Val (2 ))
173
- end )
174
- _setall (obj, optic:: ComposedFunction , vs, :: Val{4} ) =
175
- setall (obj, optic. inner, map (getall (obj, optic. inner), vs) do obj, vss
176
- _setall (obj, optic. outer, vss, Val (3 ))
177
- end )
178
-
179
-
180
- struct _SetAll{N} end
181
-
182
- function (:: _SetAll{N} )(obj, optic, vs) where {N}
183
- vss = to_nested_shape (vs, getall_lengths (obj, optic, Val (N)), Val (N))
184
- # @info "" vs getall_lengths(obj, optic, Val(N)) vss
185
- _setall (obj, optic, vss, Val (N))
133
+ _setall (obj, optic, vs, :: Val{1} ) = setall (obj, optic, vs)
134
+ for i in 2 : 10
135
+ @eval function _setall (obj, optic:: ComposedFunction , vs, :: Val{$i} )
136
+ setall (obj, optic. inner, map (getall (obj, optic. inner), vs) do obj, vss
137
+ _setall (obj, optic. outer, vss, Val ($ (i - 1 )))
138
+ end )
139
+ end
186
140
end
187
141
188
- # split a into two parts: b-sized front and remaining
189
- _split_n (a:: NTuple{Na, Any} , b:: NTuple{Nb, Any} ) where {Na, Nb} = ntuple (i -> a[i], Nb), ntuple (i -> a[Nb + i], Na - Nb)
190
-
191
- # split vs into parts sized according to getall(o, f)
192
- _split_getall (ins_old, f, vs) = foldl (ins_old; init= ((), vs)) do (acc, vs_), o
193
- vs_cur, vs_rest = _split_n (vs_, getall (o, f))
194
- (acc... , vs_cur), vs_rest
195
- end |> first
0 commit comments