Skip to content

Commit f8b80d9

Browse files
Tokazamachriselrod
andauthored
Assume parent array has same strides (#192)
* Assume parent array has same strides * Try to fix reinterpret failures * Tests no longer broken Co-authored-by: Chris Elrod <[email protected]>
1 parent 138b1cc commit f8b80d9

File tree

3 files changed

+29
-13
lines changed

3 files changed

+29
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "3.1.24"
3+
version = "3.1.25"
44

55
[deps]
66
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"

src/stridelayout.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -502,8 +502,10 @@ while still producing correct behavior when using valid cartesian indices, such
502502
strides(A::StrideIndex) = getfield(A, :strides)
503503
@inline strides(A::Vector{<:Any}) = (StaticInt(1),)
504504
@inline strides(A::Array{<:Any,N}) where {N} = (StaticInt(1), Base.tail(Base.strides(A))...)
505-
function strides(x)
506-
if defines_strides(x)
505+
@inline function strides(x::X) where {X}
506+
if !(parent_type(X) <: X)
507+
return strides(parent(x))
508+
elseif defines_strides(X)
507509
return size_to_strides(size(x), One())
508510
else
509511
return Base.strides(x)
@@ -519,11 +521,19 @@ function strides(A::ReshapedArray{T,N,P}) where {T, N, P<:AbstractVector}
519521
return Base.strides(A)
520522
end
521523
end
524+
function strides(A::ReshapedArray{T,N,P}) where {T, N, P}
525+
if defines_strides(A)
526+
return size_to_strides(size(A), static(1))
527+
else
528+
return Base.strides(A)
529+
end
530+
end
531+
522532

523533
@inline bmap(f::F, t::Tuple{}, x::Number) where {F} = ()
524534
@inline bmap(f::F, t::Tuple{T}, x::Number) where {F, T} = (f(first(t),x), )
525535
@inline bmap(f::F, t::Tuple, x::Number) where {F} = (f(first(t),x), bmap(f, Base.tail(t), x)...)
526-
if VERSION v"1.6.0-DEV.1581"
536+
@static if VERSION v"1.6.0-DEV.1581"
527537
# from `reinterpret(reshape, ...)`
528538
@inline function strides(A::Base.ReinterpretArray{R, N, T, B, true}) where {R,N,T,B}
529539
P = strides(parent(A))
@@ -541,7 +551,6 @@ if VERSION ≥ v"1.6.0-DEV.1581"
541551
(One(), bmap(*, P, StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
542552
end
543553
end
544-
545554
# plain `reinterpret(...)`
546555
@inline function strides(A::Base.ReinterpretArray{R, N, T, B, false}) where {R,N,T,B}
547556
P = strides(parent(A))
@@ -553,6 +562,18 @@ if VERSION ≥ v"1.6.0-DEV.1581"
553562
(first(P), bmap(*, Base.tail(P), StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
554563
end
555564
end
565+
else
566+
# plain `reinterpret(...)`
567+
@inline function strides(A::Base.ReinterpretArray{R, N, T}) where {R,N,T}
568+
P = strides(parent(A))
569+
if sizeof(R) == sizeof(T)
570+
P
571+
elseif sizeof(R) > sizeof(T)
572+
(first(P), bmap(÷, Base.tail(P), StaticInt(sizeof(R)) ÷ StaticInt(sizeof(T)))...)
573+
else # sizeof(R) < sizeof(T)
574+
(first(P), bmap(*, Base.tail(P), StaticInt(sizeof(T)) ÷ StaticInt(sizeof(R)))...)
575+
end
576+
end
556577
end
557578
#@inline strides(A) = _strides(A, Base.strides(A), contiguous_axis(A))
558579

test/runtests.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ end
767767
end
768768

769769
@testset "Reinterpreted reshaped views" begin
770-
u_base = randn(1, 4, 4, 5)
770+
u_base = randn(1, 4, 4, 5);
771771
u_vectors = reshape(reinterpret(SVector{1, eltype(u_base)}, u_base),
772772
Base.tail(size(u_base))...)
773773
u_view = view(u_vectors, 2, :, 3)
@@ -778,13 +778,8 @@ end
778778
@test @inferred(ArrayInterface.strides(u_base)) == (StaticInt(1), 1, 4, 16)
779779
@test @inferred(ArrayInterface.strides(u_vectors)) == (StaticInt(1), 4, 16)
780780
@test @inferred(ArrayInterface.strides(u_view)) == (4,)
781-
if VERSION v"1.6.0-DEV.1581"
782-
@test @inferred(ArrayInterface.strides(u_view_reinterpreted)) == (4,)
783-
@test @inferred(ArrayInterface.strides(u_view_reshaped)) == (4, 4)
784-
else
785-
@test_broken @inferred(ArrayInterface.strides(u_view_reinterpreted)) == (4,)
786-
@test_broken @inferred(ArrayInterface.strides(u_view_reshaped)) == (4, 4)
787-
end
781+
@test @inferred(ArrayInterface.strides(u_view_reinterpreted)) == (4,)
782+
@test @inferred(ArrayInterface.strides(u_view_reshaped)) == (4, 4)
788783
end
789784

790785
@test ArrayInterface.can_avx(ArrayInterface.can_avx) == false

0 commit comments

Comments
 (0)