Skip to content

Commit 9ed3188

Browse files
committed
support vectors
1 parent 165fcc8 commit 9ed3188

File tree

2 files changed

+25
-12
lines changed

2 files changed

+25
-12
lines changed

src/getsetall.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ getall(obj, f) = (f(obj),)
3131

3232
setall(obj, ::Properties, vs::Tuple) = setproperties(obj, NamedTuple{propertynames(obj)}(vs))
3333
setall(obj::NamedTuple{NS}, ::Elements, vs::Tuple) where {NS} = NamedTuple{NS}(vs)
34-
setall(obj::NTuple{N, Any}, ::Elements, vs::NTuple{N, Any}) where {N} = vs
35-
setall(obj, o::If, vs::Tuple) = error("Not supported")
36-
setall(obj, o, vs::Tuple) = set(obj, o, only(vs))
34+
setall(obj::NTuple{N, Any}, ::Elements, vs) where {N} = (@assert length(vs) == N; Tuple(vs))
35+
setall(obj::AbstractArray, ::Elements, vs::AbstractArray) = (@assert length(obj) == length(vs); reshape(vs, size(obj)))
36+
setall(obj::AbstractArray, ::Elements, vs) = setall(obj, Elements(), collect(vs))
37+
setall(obj, o::If, vs) = error("Not supported")
38+
setall(obj, o, vs) = set(obj, o, only(vs))
3739

3840

3941
# A recursive implementation of getall doesn't actually infer,
@@ -46,7 +48,7 @@ end
4648

4749
function setall(obj, optic::ComposedFunction, vs)
4850
N = length(decompose(optic))
49-
vss = to_nested_shape(vs, typeof(getall_lengths(obj, optic, Val(N))), Val(N))
51+
vss = to_nested_shape(vs, Val(getall_lengths(obj, optic, Val(N))), Val(N))
5052
_setall(obj, optic, vss, Val(N))
5153
end
5254

@@ -91,12 +93,12 @@ end
9193

9294
_staticlength(::Number) = Val(1)
9395
_staticlength(::NTuple{N, <:Any}) where {N} = Val(N)
94-
# _staticlength(x::AbstractVector) = length(x)
96+
_staticlength(x::AbstractVector) = length(x)
9597

98+
_val(N::Int) = N
9699
_val(::Val{N}) where {N} = N
97100
_val(::Type{Val{N}}) where {N} = N
98101

99-
_staticsum(f, x) = sum(_val f, x) |> Val
100102

101103
getall_lengths(obj, optic, ::Val{1}) = _staticlength(getall(obj, optic))
102104
for i in 2:10
@@ -108,20 +110,21 @@ for i in 2:10
108110
end
109111

110112

111-
nestedsum(ls::Type{L}) where {L <: Val} = L
112-
nestedsum(ls::Type{LS}) where {LS <: Tuple} = _staticsum(nestedsum, LS.parameters)
113+
nestedsum(ls::Int) = ls
114+
nestedsum(ls::Val) = ls
115+
nestedsum(ls::Tuple) = sum(_val nestedsum, ls)
113116

114117

115-
to_nested_shape(vs, ls::Type{LS}, ::Val{1}) where {LS <: Val} = (@assert length(vs) == _val(LS); vs)
118+
to_nested_shape(vs, ::Val{LS}, ::Val{1}) where {LS} = (@assert length(vs) == _val(LS); vs)
116119
for i in 2:10
117-
@eval @generated function to_nested_shape(vs, ls::Type{LS}, ::Val{$i}) where {LS <: Tuple}
120+
@eval @generated function to_nested_shape(vs, ls::Val{LS}, ::Val{$i}) where {LS}
118121
vi = 1
119-
subs = map(LS.parameters) do lss
122+
subs = map(LS) do lss
120123
n = nestedsum(lss)
121124
elems = map(vi:vi+_val(n)-1) do j
122125
:( vs[$j] )
123126
end
124-
res = :( to_nested_shape(($(elems...),), $lss, $(Val($(i - 1)))) )
127+
res = :( to_nested_shape(($(elems...),), $(Val(lss)), $(Val($(i - 1)))) )
125128
vi = vi + _val(n)
126129
res
127130
end

test/test_getsetall.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ if VERSION >= v"1.6" # for ComposedFunction
6666
end
6767

6868
@testset "setall" begin
69+
@test (2,) === @inferred setall((1,), Elements(), (2,))
70+
@test (2,) === setall((1,), Elements(), [2,])
71+
@test [2,] == @inferred setall([1,], Elements(), (2,))
72+
@test [2,] == @inferred setall([1,], Elements(), [2,])
73+
6974
obj = (a=1, b=2.0, c='3')
7075
@test (a="aa", b=2.0, c='3') === @inferred setall(obj, @optic(_.a), ("aa",))
7176
@test (a="aa", b=1, c='5') === @inferred setall(obj, Properties(), ("aa", 1, '5'))
@@ -77,6 +82,11 @@ end
7782
@test (a=1, b=((c=:x, d=4), (c=:y, d=6))) === @inferred setall(obj, @optic(_.b |> Elements() |> _.c), (:x, :y))
7883
@test (a=1, b=((c=:x, d="y"), (c=:z, d=10))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties()), (:x, "y", :z, 10))
7984
@test (a=1, b=((c=-3., d=-4.), (c=-5., d=-6.))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties() |> _ * 3), (-9, -12, -15, -18))
85+
@test (a=1, b=((c=-3., d=-4.), (c=-5., d=-6.))) === @inferred setall(obj, @optic(_.b |> Elements() |> Properties() |> _ * 3), [-9, -12, -15, -18])
86+
87+
obj = ([1, 2], 3:5, (6,))
88+
@test [1, 2, 3, 4, 5, 6] == @inferred getall(obj, @optic _ |> Elements() |> Elements())
89+
@test [2, 3, 4, 5, 6, 7] == @inferred getall(obj, @optic _ |> Elements() |> Elements() |> _ + 1)
8090
end
8191

8292
end

0 commit comments

Comments
 (0)