Skip to content

Commit 4ca0984

Browse files
authored
Merge pull request #143 from JuliaArrays/reinterpretstridesthroughparent
Fix ReinterpretArray's definition of strides
2 parents cd6e5d2 + 1194195 commit 4ca0984

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

src/stridelayout.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,27 @@ function strides(x)
443443
return Base.strides(x)
444444
end
445445
end
446+
@inline bmap(f::F, t::Tuple{}, x::Number) where {F} = ()
447+
@inline bmap(f::F, t::Tuple{T}, x::Number) where {F, T} = (f(first(t),x), )
448+
@inline bmap(f::F, t::Tuple, x::Number) where {F} = (f(first(t),x), bmap(f, Base.tail(t), x)...)
449+
if VERSION v"1.6.0-DEV.1581"
450+
@inline @inline function strides(A::Base.ReinterpretArray{R, N, T, B, true}) where {R,N,T,B}
451+
P = strides(parent(A))
452+
if sizeof(R) == sizeof(T)
453+
P
454+
elseif sizeof(R) > sizeof(T)
455+
x = Base.tail(P)
456+
fx = first(x)
457+
if fx isa Int
458+
(One(), bmap(Base.sdiv_int, Base.tail(x), fx)...)
459+
else
460+
(One(), bmap(÷, Base.tail(x), fx)...)
461+
end
462+
else
463+
(One(), bmap(*, P, StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
464+
end
465+
end
466+
end
446467
#@inline strides(A) = _strides(A, Base.strides(A), contiguous_axis(A))
447468

448469
strides(::AbstractRange) = (One(),)

test/runtests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,22 @@ end
647647
@test @inferred(ArrayInterface.dense_dims(view(Sr2,:,2))) === (True(),)
648648
@test @inferred(ArrayInterface.dense_dims(view(Sr2,:,2:3))) === (True(),True())
649649
@test @inferred(ArrayInterface.dense_dims(view(Sr2,2:3,:))) === (True(),False())
650+
651+
Ar2c = reinterpret(reshape, Complex{Float64}, view(rand(2, 5, 7), :, 2:4, 3:5));
652+
@test @inferred(ArrayInterface.strides(Ar2c)) === (StaticInt(1), 5)
653+
Ar2c_static = reinterpret(reshape, Complex{Float64}, view(@MArray(rand(2, 5, 7)), :, 2:4, 3:5));
654+
@test @inferred(ArrayInterface.strides(Ar2c_static)) === (StaticInt(1), StaticInt(5))
655+
656+
Ac2r = reinterpret(reshape, Float64, view(rand(ComplexF64, 5, 7), 2:4, 3:6));
657+
@test @inferred(ArrayInterface.strides(Ac2r)) === (StaticInt(1), StaticInt(2), 10)
658+
Ac2r_static = reinterpret(reshape, Float64, view(@MMatrix(rand(ComplexF64, 5, 7)), 2:4, 3:6));
659+
@test @inferred(ArrayInterface.strides(Ac2r_static)) === (StaticInt(1), StaticInt(2), StaticInt(10))
660+
661+
Ac2t = reinterpret(reshape, Tuple{Float64,Float64}, view(rand(ComplexF64, 5, 7), 2:4, 3:6));
662+
@test @inferred(ArrayInterface.strides(Ac2t)) === (StaticInt(1), 5)
663+
Ac2t_static = reinterpret(reshape, Tuple{Float64,Float64}, view(@MMatrix(rand(ComplexF64, 5, 7)), 2:4, 3:6));
664+
@test @inferred(ArrayInterface.strides(Ac2t_static)) === (StaticInt(1), StaticInt(5))
665+
650666
end
651667
end
652668

0 commit comments

Comments
 (0)