Skip to content

Commit adeb1d9

Browse files
authored
Merge pull request #152 from JuliaArrays/reinterpretstridelayout
Fix stride_rank and contiguous_axis for reinterpret(reshape,...) arrays.
2 parents 3eed3f6 + 6fa21f4 commit adeb1d9

File tree

3 files changed

+71
-7
lines changed

3 files changed

+71
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "3.1.12"
3+
version = "3.1.13"
44

55
[deps]
66
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"

src/stridelayout.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,32 @@ stride_rank(x, i) = stride_rank(x)[i]
218218
function stride_rank(::Type{R}) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}}
219219
return nstatic(Val(N))
220220
end
221+
if VERSION v"1.6.0-DEV.1581"
222+
@inline function stride_rank(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
223+
_stride_rank_reinterpret(stride_rank(B), gt(StaticInt{NB}(), StaticInt{NA}()))
224+
end
225+
@inline _stride_rank_reinterpret(sr, ::False) = (One(), map(Base.Fix2(+,One()),sr)...)
226+
@inline _stride_rank_reinterpret(sr::Tuple{One,Vararg}, ::True) = map(Base.Fix2(-,One()), tail(sr))
227+
# if the leading dim's `stride_rank` is not one, then that means the individual elements are split across an axis, which ArrayInterface
228+
# doesn't currently have a means of representing.
229+
@inline function contiguous_axis(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
230+
_reinterpret_contiguous_axis(stride_rank(B), dense_dims(B), contiguous_axis(B), gt(StaticInt{NB}(), StaticInt{NA}()))
231+
end
232+
@inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::False) = One()
233+
@inline _reinterpret_contiguous_axis(::Any, ::Any, ::Any, ::True) = Zero()
234+
@generated function _reinterpret_contiguous_axis(t::Tuple{One,Vararg{StaticInt,N}}, d::Tuple{True,Vararg{StaticBool,N}}, ::One, ::True) where {N}
235+
for n in 1:N
236+
if t.parameters[n+1].parameters[1] === 2
237+
if d.parameters[n+1] === True
238+
return :(StaticInt{$n}())
239+
else
240+
return :(Zero())
241+
end
242+
end
243+
end
244+
:(Zero())
245+
end
246+
end
221247

222248
function stride_rank(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
223249
_reshaped_striderank(is_column_major(P), Val{N}(), Val{M}())

test/runtests.jl

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,17 +320,17 @@ using OffsetArrays
320320
x = zeros(100);
321321
# R = reshape(view(x, 1:100), (10,10));
322322
# A = zeros(3,4,5);
323-
A = Wrapper(reshape(view(x, 1:60), (3,4,5)))
323+
A = Wrapper(reshape(view(x, 1:60), (3,4,5)));
324324
B = A .== 0;
325-
D1 = view(A, 1:2:3, :, :) # first dimension is discontiguous
326-
D2 = view(A, :, 2:2:4, :) # first dimension is contiguous
325+
D1 = view(A, 1:2:3, :, :); # first dimension is discontiguous
326+
D2 = view(A, :, 2:2:4, :); # first dimension is contiguous
327327

328328
@test @inferred(ArrayInterface.defines_strides(x))
329329
@test @inferred(ArrayInterface.defines_strides(A))
330330
@test @inferred(ArrayInterface.defines_strides(D1))
331331
@test !@inferred(ArrayInterface.defines_strides(view(A, :, [1,2],1)))
332332
@test @inferred(ArrayInterface.defines_strides(DenseWrapper{Int,2,Matrix{Int}}))
333-
333+
334334
@test @inferred(device(A)) === ArrayInterface.CPUPointer()
335335
@test @inferred(device(B)) === ArrayInterface.CPUIndex()
336336
@test @inferred(device(-1:19)) === ArrayInterface.CPUIndex()
@@ -346,7 +346,6 @@ using OffsetArrays
346346
@test @inferred(device(OffsetArray(@MArray(zeros(2,2,2)),8,-2,-5))) === ArrayInterface.CPUPointer()
347347
@test isnothing(device("Hello, world!"))
348348
@test @inferred(device(DenseWrapper{Int,2,Matrix{Int}})) === ArrayInterface.CPUPointer()
349-
350349
#=
351350
@btime ArrayInterface.contiguous_axis($(reshape(view(zeros(100), 1:60), (3,4,5))))
352351
0.047 ns (0 allocations: 0 bytes)
@@ -372,7 +371,7 @@ using OffsetArrays
372371
@test @inferred(contiguous_axis(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
373372
@test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :))) === nothing
374373
@test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :)')) === nothing
375-
374+
376375
@test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) == (true,false,false)
377376
@test @inferred(ArrayInterface.contiguous_axis_indicator(A)) == (true,false,false)
378377
@test @inferred(ArrayInterface.contiguous_axis_indicator(B)) == (true,false,false)
@@ -417,6 +416,8 @@ using OffsetArrays
417416
@test @inferred(stride_rank(DummyZeros(3,4)')) === nothing
418417
@test @inferred(stride_rank(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
419418
@test @inferred(stride_rank(view(DummyZeros(3,4), 1, :))) === nothing
419+
420+
420421
#=
421422
@btime ArrayInterface.is_column_major($(PermutedDimsArray(A,(3,1,2))))
422423
0.047 ns (0 allocations: 0 bytes)
@@ -478,6 +479,43 @@ using OffsetArrays
478479
Am = @MMatrix rand(2,10);
479480
@test @inferred(ArrayInterface.strides(view(Am,1,:))) === (StaticInt(2),)
480481

482+
if VERSION v"1.6.0-DEV.1581" # reinterpret(reshape,...) tests
483+
C1 = reinterpret(reshape, Float64, PermutedDimsArray(Array{Complex{Float64}}(undef, 3,4,5), (2,1,3)));
484+
C2 = reinterpret(reshape, Complex{Float64}, PermutedDimsArray(view(A,1:2,:,:), (1,3,2)));
485+
C3 = reinterpret(reshape, Complex{Float64}, PermutedDimsArray(Wrapper(reshape(view(x, 1:24), (2,3,4))), (1,3,2)));
486+
487+
@test @inferred(ArrayInterface.defines_strides(C1))
488+
@test @inferred(ArrayInterface.defines_strides(C2))
489+
@test @inferred(ArrayInterface.defines_strides(C3))
490+
491+
@test @inferred(device(C1)) === ArrayInterface.CPUPointer()
492+
@test @inferred(device(C2)) === ArrayInterface.CPUPointer()
493+
@test @inferred(device(C3)) === ArrayInterface.CPUPointer()
494+
495+
@test @inferred(contiguous_batch_size(C1)) === ArrayInterface.StaticInt(0)
496+
@test @inferred(contiguous_batch_size(C2)) === ArrayInterface.StaticInt(0)
497+
@test @inferred(contiguous_batch_size(C3)) === ArrayInterface.StaticInt(0)
498+
499+
@test @inferred(stride_rank(C1)) == (1,3,2,4)
500+
@test @inferred(stride_rank(C2)) == (2,1)
501+
@test @inferred(stride_rank(C3)) == (2,1)
502+
503+
@test @inferred(contiguous_axis(C1)) === StaticInt(1)
504+
@test @inferred(contiguous_axis(C2)) === StaticInt(0)
505+
@test @inferred(contiguous_axis(C3)) === StaticInt(2)
506+
507+
@test @inferred(ArrayInterface.contiguous_axis_indicator(C1)) == (true,false,false,false)
508+
@test @inferred(ArrayInterface.contiguous_axis_indicator(C2)) == (false,false)
509+
@test @inferred(ArrayInterface.contiguous_axis_indicator(C3)) == (false,true)
510+
511+
@test @inferred(ArrayInterface.is_column_major(C1)) === False()
512+
@test @inferred(ArrayInterface.is_column_major(C2)) === False()
513+
@test @inferred(ArrayInterface.is_column_major(C3)) === False()
514+
515+
@test @inferred(dense_dims(C1)) == (true,true,true,true)
516+
@test @inferred(dense_dims(C2)) == (false,false)
517+
@test @inferred(dense_dims(C3)) == (true,true)
518+
end
481519
end
482520

483521
@testset "Static-Dynamic Size, Strides, and Offsets" begin

0 commit comments

Comments
 (0)