
Commit 9a422f8

Slightly increase support for reshaped arrays
1 parent: f44c82b

File tree: 3 files changed (+78, -51 lines)

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "ArrayInterface"
 uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
-version = "2.14.12"
+version = "2.14.13"
 
 [deps]
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/stridelayout.jl

Lines changed: 73 additions & 49 deletions
@@ -87,7 +87,7 @@ end
 
 
 """
-    contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
+    contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{Val}}
 
 Returns a tuple boolean `Val`s indicating whether that axis is contiguous.
 """
@@ -98,53 +98,6 @@ contiguous_axis_indicator(::Nothing, ::Val) = nothing
 Base.@pure contiguous_axis_indicator(::Contiguous{N}, ::Val{D}) where {N,D} =
     ntuple(d -> Val{d == N}(), Val{D}())
 
-"""
-If the contiguous dimension is not the dimension with `StrideRank{1}`:
-"""
-struct ContiguousBatch{N} end
-Base.@pure ContiguousBatch(N::Int) = ContiguousBatch{N}()
-_get(::ContiguousBatch{N}) where {N} = N
-
-"""
-    contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
-
-Returns the Base.size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`.
-If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`.
-If `contiguous_axis(T) == -1`, it will return `ContiguousBatch{-1}()`.
-If unknown, it will return `nothing`.
-"""
-contiguous_batch_size(x) = contiguous_batch_size(typeof(x))
-contiguous_batch_size(::Type) = nothing
-contiguous_batch_size(::Type{Array{T,N}}) where {T,N} = ContiguousBatch{0}()
-contiguous_batch_size(::Type{<:Tuple}) = ContiguousBatch{0}()
-contiguous_batch_size(
-    ::Type{<:Union{Transpose{T,A},Adjoint{T,A}}},
-) where {T,A<:AbstractVecOrMat{T}} = contiguous_batch_size(A)
-contiguous_batch_size(
-    ::Type{<:PermutedDimsArray{T,N,I1,I2,A}},
-) where {T,N,I1,I2,A<:AbstractArray{T,N}} = contiguous_batch_size(A)
-function contiguous_batch_size(
-    ::Type{S},
-) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
-    _contiguous_batch_size(S, contiguous_batch_size(A), contiguous_axis(A))
-end
-_contiguous_batch_size(::Any, ::Any, ::Any) = nothing
-@generated function _contiguous_batch_size(
-    ::Type{S},
-    ::ContiguousBatch{B},
-    ::Contiguous{C},
-) where {B,C,N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
-    if I.parameters[C] <: AbstractUnitRange
-        Expr(:call, Expr(:curly, :ContiguousBatch, B))
-    else
-        Expr(:call, Expr(:curly, :ContiguousBatch, -1))
-    end
-end
-
-contiguous_batch_size(
-    ::Type{R},
-) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}} = ContiguousBatch{0}()
-
 struct StrideRank{R} end
 Base.@pure StrideRank(R::NTuple{<:Any,Int}) = StrideRank{R}()
 _get(::StrideRank{R}) where {R} = R
@@ -230,6 +183,67 @@ stride_rank(x, i) = stride_rank(x)[i]
 stride_rank(::Type{R}) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}} =
     StrideRank{ntuple(identity, Val{N}())}()
 
+function stride_rank(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
+
+    _reshaped_striderank(is_column_major(P), Val{N}(), Val{M}())
+end
+_reshaped_striderank(::Val{true}, ::Val{N}, ::Val{0}) where {N} = StrideRank{ntuple(identity, Val{N}())}()
+_reshaped_striderank(_, __, ___) = nothing
+
+
+"""
+If the contiguous dimension is not the dimension with `StrideRank{1}`:
+"""
+struct ContiguousBatch{N} end
+Base.@pure ContiguousBatch(N::Int) = ContiguousBatch{N}()
+_get(::ContiguousBatch{N}) where {N} = N
+
+"""
+    contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
+
+Returns the Base.size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`.
+If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`.
+If `contiguous_axis(T) == -1`, it will return `ContiguousBatch{-1}()`.
+If unknown, it will return `nothing`.
+"""
+contiguous_batch_size(x) = contiguous_batch_size(typeof(x))
+contiguous_batch_size(::Type{T}) where {T} = _contiguous_batch_size(contiguous_axis(T), stride_rank(T))
+_contiguous_batch_size(_, __) = nothing
+@generated function _contiguous_batch_size(::Contiguous{D}, ::StrideRank{R}) where {D,R}
+    isone(R[D]) ? :(ContiguousBatch{0}()) : :nothing
+end
+
+contiguous_batch_size(::Type{Array{T,N}}) where {T,N} = ContiguousBatch{0}()
+contiguous_batch_size(::Type{<:Tuple}) = ContiguousBatch{0}()
+contiguous_batch_size(
+    ::Type{<:Union{Transpose{T,A},Adjoint{T,A}}},
+) where {T,A<:AbstractVecOrMat{T}} = contiguous_batch_size(A)
+contiguous_batch_size(
+    ::Type{<:PermutedDimsArray{T,N,I1,I2,A}},
+) where {T,N,I1,I2,A<:AbstractArray{T,N}} = contiguous_batch_size(A)
+function contiguous_batch_size(
+    ::Type{S},
+) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
+    _contiguous_batch_size(S, contiguous_batch_size(A), contiguous_axis(A))
+end
+_contiguous_batch_size(::Any, ::Any, ::Any) = nothing
+@generated function _contiguous_batch_size(
+    ::Type{S},
+    ::ContiguousBatch{B},
+    ::Contiguous{C},
+) where {B,C,N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
+    if I.parameters[C] <: AbstractUnitRange
+        Expr(:call, Expr(:curly, :ContiguousBatch, B))
+    else
+        Expr(:call, Expr(:curly, :ContiguousBatch, -1))
+    end
+end
+
+contiguous_batch_size(
+    ::Type{R},
+) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}} = ContiguousBatch{0}()
+
+
 """
     is_column_major(A) -> Val{true/false}()
 
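For context, here is a rough sketch (not part of the diff) of what the new `stride_rank` method and the rewritten `contiguous_batch_size` fallback are expected to do for the reshaped-view case exercised in the updated tests below. The results in the comments are expectations, assuming `is_column_major` reports a contiguous vector view as column major:

using ArrayInterface: stride_rank

x = zeros(100);
A = reshape(view(x, 1:60), (3, 4, 5));  # a Base.ReshapedArray over a linearly indexed SubArray, so M == 0

# The new method asks whether the parent is column major and, if so, assigns
# the identity rank ordering; anything else falls back to `nothing`.
stride_rank(A)  # expected: StrideRank{(1, 2, 3)}()

# The rewritten generic fallback now composes the two traits directly instead of
# always returning `nothing` for types without a dedicated method:
ArrayInterface._contiguous_batch_size(ArrayInterface.Contiguous{1}(), ArrayInterface.StrideRank{(1, 2, 3)}())
# expected: ContiguousBatch{0}(), since the contiguous axis has stride rank 1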
@@ -260,7 +274,8 @@ An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A
 """
 dense_dims(x) = dense_dims(typeof(x))
 dense_dims(::Type) = nothing
-dense_dims(::Type{Array{T,N}}) where {T,N} = DenseDims{ntuple(_ -> true, Val{N}())}()
+_all_dense(::Val{N}) where {N} = DenseDims{ntuple(_ -> true, Val{N}())}()
+dense_dims(::Type{Array{T,N}}) where {T,N} = _all_dense(Val{N}())
 dense_dims(::Type{<:Tuple}) = DenseDims{(true,)}()
 function dense_dims(
     ::Type{<:Union{Transpose{T,A},Adjoint{T,A}}},
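A quick illustrative check (not from the commit) that the `_all_dense` refactor is behavior-preserving for plain arrays; both calls are assumed to be reachable as `ArrayInterface.dense_dims` and the new internal `ArrayInterface._all_dense`:

# expected to hold both before and after the refactor:
ArrayInterface.dense_dims(Array{Float64,3})  # DenseDims{(true, true, true)}()
# the same value is now produced by the shared helper:
ArrayInterface._all_dense(Val(3))            # DenseDims{(true, true, true)}()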
@@ -306,6 +321,15 @@ _dense_dims(::Any, ::Any) = nothing
     length(dense_tup.args) == N ? Expr(:call, Expr(:curly, :DenseDims, dense_tup)) : nothing
 end
 
+function dense_dims(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
+
+    _reshaped_dense_dims(dense_dims(P), is_column_major(P), Val{N}(), Val{M}())
+end
+_reshaped_dense_dims(_, __, ___, ____) = nothing
+@generated function _reshaped_dense_dims(::DenseDims{D}, ::Val{true}, ::Val{N}, ::Val{0}) where {D,N}
+    all(D) ? :(_all_dense(Val{$N}())) : :nothing
+end
+
 permute(t::NTuple{N}, I::NTuple{N,Int}) where {N} = ntuple(n -> t[I[n]], Val{N}())
 @generated function permute(t::Tuple{Vararg{Any,N}}, ::Val{I}) where {N,I}
     t = Expr(:tuple)
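And a similar sketch (again illustrative, not part of the diff) of the path the new `dense_dims` method takes for a reshaped view; the final value assumes the parent view is reported as fully dense and column major:

using ArrayInterface: dense_dims

x = zeros(100);
A = reshape(view(x, 1:60), (3, 4, 5));

# dense_dims(P) and is_column_major(P) are queried on the parent first; only when
# every parent dimension is dense and the parent is column major does the reshape
# count as fully dense, otherwise the result is `nothing`.
dense_dims(A)  # expected: DenseDims{(true, true, true)}()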

test/runtests.jl

Lines changed: 4 additions & 1 deletion
@@ -265,7 +265,10 @@ Base.getindex(::DummyZeros{T}, inds...) where {T} = zero(T)
 
 using OffsetArrays
 @testset "Memory Layout" begin
-    A = zeros(3,4,5);
+    x = zeros(100);
+    # R = reshape(view(x, 1:100), (10,10));
+    # A = zeros(3,4,5);
+    A = reshape(view(x, 1:60), (3,4,5))
     D1 = view(A, 1:2:3, :, :) # first dimension is discontiguous
     D2 = view(A, :, 2:2:4, :) # first dimension is contiguous
     @test @inferred(device(A)) === ArrayInterface.CPUPointer()
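The kind of assertion this new test setup makes possible, sketched here for illustration (these particular `@test` lines are not in the diff, and the expected trait values assume the reshaped view resolves as column major and fully dense):

using Test, ArrayInterface
using ArrayInterface: stride_rank, dense_dims

x = zeros(100);
A = reshape(view(x, 1:60), (3, 4, 5));

# With the new ReshapedArray methods these queries should infer a concrete trait
# value instead of falling back to `nothing`:
@test @inferred(stride_rank(A)) === ArrayInterface.StrideRank{(1, 2, 3)}()
@test @inferred(dense_dims(A)) === ArrayInterface.DenseDims{(true, true, true)}()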
