Skip to content

Commit 9cc5804

Browse files
authored
Fix size for SubArrays (#323)
This is in the same vein as #297, where we have previously been ignoring indices in SubArray that create additional dimensions. The focus her is on size and known_size. flatten_tuples was also modified to be optionally generated so that we can use it more extensively without penalty.
1 parent 4c682b0 commit 9cc5804

File tree

3 files changed

+78
-26
lines changed

3 files changed

+78
-26
lines changed

src/indexing.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ function _to_indices_expr(S::DataType, N::Int, ni, ns, is)
154154
end
155155
end
156156
push!(blk.args, Expr(:(=), :axs, :(lazy_axes(a))))
157-
push!(blk.args, :(_flatten_tuples($(indsexpr))))
157+
push!(blk.args, :(flatten_tuples($(indsexpr))))
158158
end
159159
return blk
160160
end
@@ -167,19 +167,27 @@ function _axis_expr(N::Int, d::Int)
167167
end
168168
end
169169

170-
@generated function _flatten_tuples(inds::I) where {I}
171-
t = Expr(:tuple)
172-
for i in 1:known_length(I)
173-
p = I.parameters[i]
174-
if p <: Tuple
175-
for j in 1:known_length(p)
176-
push!(t.args, :(@inbounds(getfield(getfield(inds, $i), $j))))
170+
@inline function flatten_tuples(inds::I) where {I}
171+
if @generated
172+
t = Expr(:tuple)
173+
for i in 1:fieldcount(I)
174+
p = fieldtype(I, i)
175+
if p <: Tuple
176+
for j in 1:fieldcount(p)
177+
push!(t.args, :(@inbounds(getfield(getfield(inds, $i), $j))))
178+
end
179+
else
180+
push!(t.args, :(@inbounds(getfield(inds, $i))))
177181
end
178-
else
179-
push!(t.args, :(@inbounds(getfield(inds, $i))))
180182
end
183+
Expr(:block, Expr(:meta, :inline), t)
184+
else
185+
out = ()
186+
for i in inds
187+
out = i isa Tuple ? (out..., i...) : (out..., i)
188+
end
189+
out
181190
end
182-
t
183191
end
184192

185193
"""

src/size.jl

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ size(a::Base.Broadcast.Broadcasted) = map(length, axes(a))
2727

2828
_maybe_size(::Base.HasShape{N}, a::A) where {N,A} = map(length, axes(a))
2929
_maybe_size(::Base.HasLength, a::A) where {A} = (length(a),)
30-
size(x::SubArray) = eachop(_sub_size, to_parent_dims(x), x.indices)
31-
_sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = length(getfield(x, dim))
30+
@inline size(x::SubArray) = flatten_tuples(map(size, x.indices))
3231
@inline size(B::VecAdjTrans) = (One(), length(parent(B)))
3332
@inline size(B::MatAdjTrans) = permute(size(parent(B)), to_parent_dims(B))
3433
@inline function size(B::PermutedDimsArray{T,N,I1}) where {T,N,I1}
@@ -58,6 +57,7 @@ size(x::Iterators.Pairs) = size(getfield(x, :itr))
5857
@inline function size(x::Iterators.ProductIterator)
5958
eachop(_sub_size, ntuple(static, StaticInt(ndims(x))), getfield(x, :iterators))
6059
end
60+
_sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = length(getfield(x, dim))
6161

6262
size(a, dim) = size(a, to_dims(a, dim))
6363
size(a::Array, dim::CanonicalInt) = Base.arraysize(a, convert(Int, dim))
@@ -73,14 +73,6 @@ function size(a::A, dim::CanonicalInt) where {A}
7373
end
7474
end
7575
end
76-
function size(A::SubArray, dim::CanonicalInt)
77-
pdim = to_parent_dims(A, dim)
78-
if pdim > ndims(parent_type(A))
79-
return size(parent(A), pdim)
80-
else
81-
return length(A.indices[pdim])
82-
end
83-
end
8476
size(x::Iterators.Zip) = Static.reduce_tup(promote_shape, map(size, getfield(x, :is)))
8577

8678
"""
@@ -92,6 +84,41 @@ compile time. If a dimension does not have a known size along a dimension then `
9284
returned in its position.
9385
"""
9486
known_size(x) = known_size(typeof(x))
87+
@inline known_size(@nospecialize T::Type{<:Number}) = ()
88+
@inline known_size(@nospecialize T::Type{<:VecAdjTrans}) = (1, known_length(parent_type(T)))
89+
@inline function known_size(@nospecialize T::Type{<:MatAdjTrans})
90+
s1, s2 = known_size(parent_type(T))
91+
(s2, s1)
92+
end
93+
function known_size(@nospecialize T::Type{<:Diagonal})
94+
s = known_length(parent_type(T))
95+
(s, s)
96+
end
97+
known_size(@nospecialize T::Type{<:Union{Symmetric,Hermitian}}) = known_size(parent_type(T))
98+
@inline function known_size(::Type{<:Base.ReinterpretArray{T,N,S,A,IsReshaped}}) where {T,N,S,A,IsReshaped}
99+
psize = known_size(A)
100+
if IsReshaped
101+
if sizeof(S) > sizeof(T)
102+
return (div(sizeof(S), sizeof(T)), psize...)
103+
elseif sizeof(S) < sizeof(T)
104+
return Base.tail(psize)
105+
else
106+
return psize
107+
end
108+
else
109+
if Base.issingletontype(T) || first(psize) === nothing
110+
return psize
111+
else
112+
return (div(first(psize) * sizeof(S), sizeof(T)), Base.tail(psize)...)
113+
end
114+
end
115+
end
116+
117+
@inline function known_size(::Type{<:PermutedDimsArray{<:Any,N,I1,<:Any,P}}) where {N,I1,P}
118+
sz = known_size(P)
119+
ntuple(i -> getfield(sz, getfield(I1, i)), Val{N}())
120+
end
121+
95122
@inline function known_size(::Type{T}) where {T}
96123
if is_forwarding_wrapper(T)
97124
return known_size(parent_type(T))
@@ -103,17 +130,31 @@ function _maybe_known_size(::Base.HasShape{N}, ::Type{T}) where {N,T}
103130
eachop(_known_size, ntuple(static, StaticInt(N)), axes_types(T))
104131
end
105132
_maybe_known_size(::Base.IteratorSize, ::Type{T}) where {T} = (known_length(T),)
106-
function known_size(::Type{T}) where {T<:AbstractRange}
107-
(_range_length(known_first(T), known_step(T), known_last(T)),)
108-
end
109133
known_size(::Type{<:Base.IdentityUnitRange{I}}) where {I} = known_size(I)
110134
known_size(::Type{<:Base.Generator{I}}) where {I} = known_size(I)
111135
known_size(::Type{<:Iterators.Reverse{I}}) where {I} = known_size(I)
112136
known_size(::Type{<:Iterators.Enumerate{I}}) where {I} = known_size(I)
113137
known_size(::Type{<:Iterators.Accumulate{<:Any,I}}) where {I} = known_size(I)
114138
known_size(::Type{<:Iterators.Pairs{<:Any,<:Any,I}}) where {I} = known_size(I)
115139
@inline function known_size(::Type{<:Iterators.ProductIterator{T}}) where {T}
116-
eachop(_known_size, ntuple(static, StaticInt(known_length(T))), T)
140+
ntuple(i -> known_length(T.parameters[i]), Val(known_length(T)))
141+
end
142+
@inline function known_size(@nospecialize T::Type{<:AbstractRange})
143+
if is_forwarding_wrapper(T)
144+
return known_size(parent_type(T))
145+
else
146+
return (_range_length(known_first(T), known_step(T), known_last(T)),)
147+
end
148+
end
149+
@inline function known_size(@nospecialize T::Type{<:Union{LinearIndices,CartesianIndices}})
150+
I = fieldtype(T, :indices)
151+
ntuple(i -> known_length(I.parameters[i]), Val(ndims(T)))
152+
end
153+
@inline function known_size(@nospecialize T::Type{<:SubArray})
154+
_known_sub_sizes(fieldtype(T, :indices))
155+
end
156+
@inline function _known_sub_sizes(T::Type{<:Tuple})
157+
flatten_tuples(ntuple(i -> known_size(T.parameters[i]), Val(known_length(T))))
117158
end
118159

119160
# 1. `Zip` doesn't check that its collections are compatible (same size) at construction,

test/size.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
@test @inferred(ArrayInterface.size(A2)) === (4, 3, 5)
1212
@test @inferred(ArrayInterface.size(A2r)) === (2, 3, 5)
1313

14+
@test @inferred(ArrayInterface.size(view(rand(4), reshape(1:4, 2, 2)))) == (2, 2)
1415
@test @inferred(ArrayInterface.size(irev)) === (StaticInt(2), StaticInt(3), StaticInt(4))
1516
@test @inferred(ArrayInterface.size(iprod)) === (StaticInt(2), StaticInt(3), StaticInt(4))
1617
@test @inferred(ArrayInterface.size(iflat)) === (static(72),)
@@ -38,6 +39,8 @@
3839
@test @inferred(ArrayInterface.size(Mp)) == size(Mp)
3940
@test @inferred(ArrayInterface.size(Mp2)) == size(Mp2)
4041

42+
@test @inferred(ArrayInterface.known_size(1)) === ()
43+
@test @inferred(ArrayInterface.known_size(view(rand(4), reshape(1:4, 2, 2)))) == (nothing, nothing)
4144
@test @inferred(ArrayInterface.known_size(A)) === (nothing, nothing, nothing)
4245
@test @inferred(ArrayInterface.known_size(Ap)) === (nothing, nothing)
4346
@test @inferred(ArrayInterface.known_size(Wrapper(Ap))) === (nothing, nothing)
@@ -98,4 +101,4 @@ end
98101
e = permutedims(d)
99102
@test @inferred(ArrayInterface.is_lazy_conjugate(e)) == false
100103
@test @inferred(ArrayInterface.is_lazy_conjugate([1, 2, 3]')) == false # We don't care about conj on `<:Real`
101-
end
104+
end

0 commit comments

Comments
 (0)