@@ -31,9 +31,11 @@ getall(obj, f) = (f(obj),)
31
31
32
32
setall (obj, :: Properties , vs:: Tuple ) = setproperties (obj, NamedTuple {propertynames(obj)} (vs))
33
33
setall (obj:: NamedTuple{NS} , :: Elements , vs:: Tuple ) where {NS} = NamedTuple {NS} (vs)
34
- setall (obj:: NTuple{N, Any} , :: Elements , vs:: NTuple{N, Any} ) where {N} = vs
35
- setall (obj, o:: If , vs:: Tuple ) = error (" Not supported" )
36
- setall (obj, o, vs:: Tuple ) = set (obj, o, only (vs))
34
+ setall (obj:: NTuple{N, Any} , :: Elements , vs) where {N} = (@assert length (vs) == N; Tuple (vs))
35
+ setall (obj:: AbstractArray , :: Elements , vs:: AbstractArray ) = (@assert length (obj) == length (vs); reshape (vs, size (obj)))
36
+ setall (obj:: AbstractArray , :: Elements , vs) = setall (obj, Elements (), collect (vs))
37
+ setall (obj, o:: If , vs) = error (" Not supported" )
38
+ setall (obj, o, vs) = set (obj, o, only (vs))
37
39
38
40
39
41
# A recursive implementation of getall doesn't actually infer,
46
48
47
49
function setall (obj, optic:: ComposedFunction , vs)
48
50
N = length (decompose (optic))
49
- vss = to_nested_shape (vs, typeof (getall_lengths (obj, optic, Val (N))), Val (N))
51
+ vss = to_nested_shape (vs, Val (getall_lengths (obj, optic, Val (N))), Val (N))
50
52
_setall (obj, optic, vss, Val (N))
51
53
end
52
54
91
93
92
94
_staticlength (:: Number ) = Val (1 )
93
95
_staticlength (:: NTuple{N, <:Any} ) where {N} = Val (N)
94
- # _staticlength(x::AbstractVector) = length(x)
96
+ _staticlength (x:: AbstractVector ) = length (x)
95
97
98
+ _val (N:: Int ) = N
96
99
_val (:: Val{N} ) where {N} = N
97
100
_val (:: Type{Val{N}} ) where {N} = N
98
101
99
- _staticsum (f, x) = sum (_val ∘ f, x) |> Val
100
102
101
103
getall_lengths (obj, optic, :: Val{1} ) = _staticlength (getall (obj, optic))
102
104
for i in 2 : 10
@@ -108,20 +110,21 @@ for i in 2:10
108
110
end
109
111
110
112
111
- nestedsum (ls:: Type{L} ) where {L <: Val } = L
112
- nestedsum (ls:: Type{LS} ) where {LS <: Tuple } = _staticsum (nestedsum, LS. parameters)
113
+ nestedsum (ls:: Int ) = ls
114
+ nestedsum (ls:: Val ) = ls
115
+ nestedsum (ls:: Tuple ) = sum (_val ∘ nestedsum, ls)
113
116
114
117
115
- to_nested_shape (vs, ls :: Type {LS} , :: Val{1} ) where {LS <: Val } = (@assert length (vs) == _val (LS); vs)
118
+ to_nested_shape (vs, :: Val {LS} , :: Val{1} ) where {LS} = (@assert length (vs) == _val (LS); vs)
116
119
for i in 2 : 10
117
- @eval @generated function to_nested_shape (vs, ls:: Type {LS} , :: Val{$i} ) where {LS <: Tuple }
120
+ @eval @generated function to_nested_shape (vs, ls:: Val {LS} , :: Val{$i} ) where {LS}
118
121
vi = 1
119
- subs = map (LS. parameters ) do lss
122
+ subs = map (LS) do lss
120
123
n = nestedsum (lss)
121
124
elems = map (vi: vi+ _val (n)- 1 ) do j
122
125
:( vs[$ j] )
123
126
end
124
- res = :( to_nested_shape (($ (elems... ),), $ lss, $ (Val ($ (i - 1 )))) )
127
+ res = :( to_nested_shape (($ (elems... ),), $ ( Val ( lss)) , $ (Val ($ (i - 1 )))) )
125
128
vi = vi + _val (n)
126
129
res
127
130
end
0 commit comments