Skip to content

Commit 6c35f03

Browse files
committed
Fix stride_rank and contiguous_axis for reinterpret(reshape,...) arrays.
1 parent 3eed3f6 commit 6c35f03

File tree

2 files changed

+60
-6
lines changed

2 files changed

+60
-6
lines changed

src/stridelayout.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,32 @@ stride_rank(x, i) = stride_rank(x)[i]
218218
function stride_rank(::Type{R}) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}}
219219
return nstatic(Val(N))
220220
end
221+
if VERSION v"1.6.0-DEV.1581"
222+
@inline function stride_rank(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
223+
_stride_rank_reinterpret(stride_rank(B), gt(StaticInt{NB}(), StaticInt{NA}()))
224+
end
225+
@inline _stride_rank_reinterpret(sr, ::False) = (One(), map(Base.Fix2(+,One()),sr)...)
226+
@inline _stride_rank_reinterpret(sr::Tuple{One,Vararg}, ::True) = map(Base.Fix2(-,One()), tail(sr))
227+
# if the leading dim's `stride_rank` is not one, then that means the individual elements are split across an axis, which ArrayInterface
228+
# doesn't currently have a means of representing.
229+
@inline function contiguous_axis(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
230+
_reinterpret_contiguous_axis(stride_rank(B), dense_dims(B), contiguous_axis(B), gt(StaticInt{NB}(), StaticInt{NA}()))
231+
end
232+
@inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::False) = One()
233+
@inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::True) = Zero()
234+
@generated function _reinterpret_contiguous_axis(t::Tuple{One,Vararg{StaticInt,N}}, d::Tuple{True,Vararg{StaticBool,N}}, ::One, ::True) where {N}
235+
for n in 1:N
236+
if t.parameters[n+1].parameters[1] === 2
237+
if d.parameters[n+1] === True
238+
return :(StaticInt{$n}())
239+
else
240+
return :(Zero())
241+
end
242+
end
243+
end
244+
:(Zero())
245+
end
246+
end
221247

222248
function stride_rank(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
223249
_reshaped_striderank(is_column_major(P), Val{N}(), Val{M}())

test/runtests.jl

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,17 +320,23 @@ using OffsetArrays
320320
x = zeros(100);
321321
# R = reshape(view(x, 1:100), (10,10));
322322
# A = zeros(3,4,5);
323-
A = Wrapper(reshape(view(x, 1:60), (3,4,5)))
323+
A = Wrapper(reshape(view(x, 1:60), (3,4,5)));
324324
B = A .== 0;
325-
D1 = view(A, 1:2:3, :, :) # first dimension is discontiguous
326-
D2 = view(A, :, 2:2:4, :) # first dimension is contiguous
325+
D1 = view(A, 1:2:3, :, :); # first dimension is discontiguous
326+
D2 = view(A, :, 2:2:4, :); # first dimension is contiguous
327+
C1 = reinterpret(reshape, Float64, PermutedDimsArray(Array{Complex{Float64}}(undef, 3,4,5), (2,1,3)));
328+
C2 = reinterpret(reshape, Complex{Float64}, PermutedDimsArray(view(A,1:2,:,:), (1,3,2)));
329+
C3 = reinterpret(reshape, Complex{Float64}, PermutedDimsArray(Wrapper(reshape(view(x, 1:24), (2,3,4))), (1,3,2)));
327330

328331
@test @inferred(ArrayInterface.defines_strides(x))
329332
@test @inferred(ArrayInterface.defines_strides(A))
330333
@test @inferred(ArrayInterface.defines_strides(D1))
331334
@test !@inferred(ArrayInterface.defines_strides(view(A, :, [1,2],1)))
332335
@test @inferred(ArrayInterface.defines_strides(DenseWrapper{Int,2,Matrix{Int}}))
333-
336+
@test @inferred(ArrayInterface.defines_strides(C1))
337+
@test @inferred(ArrayInterface.defines_strides(C2))
338+
@test @inferred(ArrayInterface.defines_strides(C3))
339+
334340
@test @inferred(device(A)) === ArrayInterface.CPUPointer()
335341
@test @inferred(device(B)) === ArrayInterface.CPUIndex()
336342
@test @inferred(device(-1:19)) === ArrayInterface.CPUIndex()
@@ -346,7 +352,9 @@ using OffsetArrays
346352
@test @inferred(device(OffsetArray(@MArray(zeros(2,2,2)),8,-2,-5))) === ArrayInterface.CPUPointer()
347353
@test isnothing(device("Hello, world!"))
348354
@test @inferred(device(DenseWrapper{Int,2,Matrix{Int}})) === ArrayInterface.CPUPointer()
349-
355+
@test @inferred(device(C1)) === ArrayInterface.CPUPointer()
356+
@test @inferred(device(C2)) === ArrayInterface.CPUPointer()
357+
@test @inferred(device(C3)) === ArrayInterface.CPUPointer()
350358
#=
351359
@btime ArrayInterface.contiguous_axis($(reshape(view(zeros(100), 1:60), (3,4,5))))
352360
0.047 ns (0 allocations: 0 bytes)
@@ -372,7 +380,10 @@ using OffsetArrays
372380
@test @inferred(contiguous_axis(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
373381
@test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :))) === nothing
374382
@test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :)')) === nothing
375-
383+
@test @inferred(contiguous_axis(C1)) === StaticInt(1)
384+
@test @inferred(contiguous_axis(C2)) === StaticInt(0)
385+
@test @inferred(contiguous_axis(C3)) === StaticInt(2)
386+
376387
@test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) == (true,false,false)
377388
@test @inferred(ArrayInterface.contiguous_axis_indicator(A)) == (true,false,false)
378389
@test @inferred(ArrayInterface.contiguous_axis_indicator(B)) == (true,false,false)
@@ -386,6 +397,9 @@ using OffsetArrays
386397
@test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) == (true,false)
387398
@test @inferred(ArrayInterface.contiguous_axis_indicator(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,[1,3,4]]))) == (false,true,false)
388399
@test @inferred(ArrayInterface.contiguous_axis_indicator(DummyZeros(3,4))) === nothing
400+
@test @inferred(ArrayInterface.contiguous_axis_indicator(C1)) == (true,false,false,false)
401+
@test @inferred(ArrayInterface.contiguous_axis_indicator(C2)) == (false,false)
402+
@test @inferred(ArrayInterface.contiguous_axis_indicator(C3)) == (false,true)
389403

390404
@test @inferred(contiguous_batch_size(@SArray(zeros(2,2,2)))) === ArrayInterface.StaticInt(0)
391405
@test @inferred(contiguous_batch_size(A)) === ArrayInterface.StaticInt(0)
@@ -398,6 +412,9 @@ using OffsetArrays
398412
@test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === ArrayInterface.StaticInt(-1)
399413
@test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.StaticInt(-1)
400414
@test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.StaticInt(0)
415+
@test @inferred(contiguous_batch_size(C1)) === ArrayInterface.StaticInt(0)
416+
@test @inferred(contiguous_batch_size(C2)) === ArrayInterface.StaticInt(0)
417+
@test @inferred(contiguous_batch_size(C3)) === ArrayInterface.StaticInt(0)
401418

402419
@test @inferred(stride_rank(@SArray(zeros(2,2,2)))) == (1, 2, 3)
403420
@test @inferred(stride_rank(A)) == (1,2,3)
@@ -417,6 +434,11 @@ using OffsetArrays
417434
@test @inferred(stride_rank(DummyZeros(3,4)')) === nothing
418435
@test @inferred(stride_rank(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
419436
@test @inferred(stride_rank(view(DummyZeros(3,4), 1, :))) === nothing
437+
@test @inferred(stride_rank(C1)) == (1,3,2,4)
438+
@test @inferred(stride_rank(C2)) == (2,1)
439+
@test @inferred(stride_rank(C3)) == (2,1)
440+
441+
420442
#=
421443
@btime ArrayInterface.is_column_major($(PermutedDimsArray(A,(3,1,2))))
422444
0.047 ns (0 allocations: 0 bytes)
@@ -442,6 +464,9 @@ using OffsetArrays
442464
@test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === True()
443465
@test @inferred(ArrayInterface.is_column_major(@view(PermutedDimsArray(A,(3,1,2))[:,2,1])')) === False()
444466
@test @inferred(ArrayInterface.is_column_major(2.3)) === False()
467+
@test @inferred(ArrayInterface.is_column_major(C1)) === False()
468+
@test @inferred(ArrayInterface.is_column_major(C2)) === False()
469+
@test @inferred(ArrayInterface.is_column_major(C3)) === False()
445470

446471
@test @inferred(dense_dims(@SArray(zeros(2,2,2)))) == (true,true,true)
447472
@test @inferred(dense_dims(A)) == (true,true,true)
@@ -467,6 +492,9 @@ using OffsetArrays
467492
@test @inferred(dense_dims(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
468493
@test @inferred(dense_dims(view(DummyZeros(3,4), :, 1))) === nothing
469494
@test @inferred(dense_dims(view(DummyZeros(3,4), :, 1)')) === nothing
495+
@test @inferred(dense_dims(C1)) == (true,true,true,true)
496+
@test @inferred(dense_dims(C2)) == (false,false)
497+
@test @inferred(dense_dims(C3)) == (true,true)
470498

471499
C = Array{Int8}(undef, 2,2,2,2);
472500
doubleperm = PermutedDimsArray(PermutedDimsArray(C,(4,2,3,1)), (4,2,1,3));

0 commit comments

Comments
 (0)