@@ -502,8 +502,10 @@ while still producing correct behavior when using valid cartesian indices, such
502
502
strides (A:: StrideIndex ) = getfield (A, :strides )
503
503
@inline strides (A:: Vector{<:Any} ) = (StaticInt (1 ),)
504
504
@inline strides (A:: Array{<:Any,N} ) where {N} = (StaticInt (1 ), Base. tail (Base. strides (A))... )
505
- function strides (x)
506
- if defines_strides (x)
505
+ @inline function strides (x:: X ) where {X}
506
+ if ! (parent_type (X) <: X )
507
+ return strides (parent (x))
508
+ elseif defines_strides (X)
507
509
return size_to_strides (size (x), One ())
508
510
else
509
511
return Base. strides (x)
@@ -519,11 +521,19 @@ function strides(A::ReshapedArray{T,N,P}) where {T, N, P<:AbstractVector}
519
521
return Base. strides (A)
520
522
end
521
523
end
524
+ function strides (A:: ReshapedArray{T,N,P} ) where {T, N, P}
525
+ if defines_strides (A)
526
+ return size_to_strides (size (A), static (1 ))
527
+ else
528
+ return Base. strides (A)
529
+ end
530
+ end
531
+
522
532
523
533
@inline bmap (f:: F , t:: Tuple{} , x:: Number ) where {F} = ()
524
534
@inline bmap (f:: F , t:: Tuple{T} , x:: Number ) where {F, T} = (f (first (t),x), )
525
535
@inline bmap (f:: F , t:: Tuple , x:: Number ) where {F} = (f (first (t),x), bmap (f, Base. tail (t), x)... )
526
- if VERSION ≥ v " 1.6.0-DEV.1581"
536
+ @static if VERSION ≥ v " 1.6.0-DEV.1581"
527
537
# from `reinterpret(reshape, ...)`
528
538
@inline function strides (A:: Base.ReinterpretArray{R, N, T, B, true} ) where {R,N,T,B}
529
539
P = strides (parent (A))
@@ -541,7 +551,6 @@ if VERSION ≥ v"1.6.0-DEV.1581"
541
551
(One (), bmap (* , P, StaticInt (sizeof (T)) ÷ StaticInt (sizeof (R)))... )
542
552
end
543
553
end
544
-
545
554
# plain `reinterpret(...)`
546
555
@inline function strides (A:: Base.ReinterpretArray{R, N, T, B, false} ) where {R,N,T,B}
547
556
P = strides (parent (A))
@@ -553,6 +562,18 @@ if VERSION ≥ v"1.6.0-DEV.1581"
553
562
(first (P), bmap (* , Base. tail (P), StaticInt (sizeof (T)) ÷ StaticInt (sizeof (R)))... )
554
563
end
555
564
end
565
+ else
566
+ # plain `reinterpret(...)`
567
+ @inline function strides (A:: Base.ReinterpretArray{R, N, T} ) where {R,N,T}
568
+ P = strides (parent (A))
569
+ if sizeof (R) == sizeof (T)
570
+ P
571
+ elseif sizeof (R) > sizeof (T)
572
+ (first (P), bmap (÷ , Base. tail (P), StaticInt (sizeof (R)) ÷ StaticInt (sizeof (T)))... )
573
+ else # sizeof(R) < sizeof(T)
574
+ (first (P), bmap (* , Base. tail (P), StaticInt (sizeof (T)) ÷ StaticInt (sizeof (R)))... )
575
+ end
576
+ end
556
577
end
557
578
# @inline strides(A) = _strides(A, Base.strides(A), contiguous_axis(A))
558
579
0 commit comments