@@ -262,31 +262,29 @@ stride_rank(x, i) = stride_rank(x)[i]
262
262
function stride_rank (:: Type{R} ) where {T,N,S,A<: Array{S} ,R<: Base.ReinterpretArray{T,N,S,A} }
263
263
return nstatic (Val (N))
264
264
end
265
- if VERSION ≥ v " 1.6.0-DEV.1581"
266
- @inline function stride_rank (:: Type{A} ) where {NB, NA, B <: AbstractArray{<:Any,NB} ,A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true} }
265
+ @inline function stride_rank (:: Type{A} ) where {NB,NA,B<: AbstractArray{<:Any,NB} ,A<: Base.ReinterpretArray{<:Any,NA,<:Any,B,true} }
267
266
NA == NB ? stride_rank (B) : _stride_rank_reinterpret (stride_rank (B), gt (StaticInt {NB} (), StaticInt {NA} ()))
268
- end
269
- @inline _stride_rank_reinterpret (sr, :: False ) = (One (), map (Base. Fix2 (+ ,One ()),sr)... )
270
- @inline _stride_rank_reinterpret (sr:: Tuple{One,Vararg} , :: True ) = map (Base. Fix2 (- ,One ()), tail (sr))
271
- # if the leading dim's `stride_rank` is not one, then that means the individual elements are split across an axis, which ArrayInterface
272
- # doesn't currently have a means of representing.
273
- @inline function contiguous_axis (:: Type{A} ) where {NB, NA, B <: AbstractArray{<:Any,NB} ,A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true} }
267
+ end
268
+ @inline _stride_rank_reinterpret (sr, :: False ) = (One (), map (Base. Fix2 (+ , One ()), sr)... )
269
+ @inline _stride_rank_reinterpret (sr:: Tuple{One,Vararg} , :: True ) = map (Base. Fix2 (- , One ()), tail (sr))
270
+ # if the leading dim's `stride_rank` is not one, then that means the individual elements are split across an axis, which ArrayInterface
271
+ # doesn't currently have a means of representing.
272
+ @inline function contiguous_axis (:: Type{A} ) where {NB,NA,B <: AbstractArray{<:Any,NB} ,A<: Base.ReinterpretArray{<:Any,NA,<:Any,B, true} }
274
273
_reinterpret_contiguous_axis (stride_rank (B), dense_dims (B), contiguous_axis (B), gt (StaticInt {NB} (), StaticInt {NA} ()))
275
- end
276
- @inline _reinterpret_contiguous_axis (:: Any , :: Any , :: Any , :: False ) = One ()
277
- @inline _reinterpret_contiguous_axis (:: Any , :: Any , :: Any , :: True ) = Zero ()
278
- @generated function _reinterpret_contiguous_axis (t:: Tuple{One,Vararg{StaticInt,N}} , d:: Tuple{True,Vararg{StaticBool,N}} , :: One , :: True ) where {N}
274
+ end
275
+ @inline _reinterpret_contiguous_axis (:: Any , :: Any , :: Any , :: False ) = One ()
276
+ @inline _reinterpret_contiguous_axis (:: Any , :: Any , :: Any , :: True ) = Zero ()
277
+ @generated function _reinterpret_contiguous_axis (t:: Tuple{One,Vararg{StaticInt,N}} , d:: Tuple{True,Vararg{StaticBool,N}} , :: One , :: True ) where {N}
279
278
for n in 1 : N
280
- if t. parameters[n+ 1 ]. parameters[1 ] === 2
281
- if d. parameters[n+ 1 ] === True
282
- return :(StaticInt {$n} ())
283
- else
284
- return :(Zero ())
279
+ if t. parameters[n+ 1 ]. parameters[1 ] === 2
280
+ if d. parameters[n+ 1 ] === True
281
+ return :(StaticInt {$n} ())
282
+ else
283
+ return :(Zero ())
284
+ end
285
285
end
286
- end
287
286
end
288
287
:(Zero ())
289
- end
290
288
end
291
289
292
290
function stride_rank (:: Type {Base. ReshapedArray{T, N, P, Tuple{Vararg{Base. SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
411
409
function dense_dims (:: Type{S} ) where {N,NP,T,A<: AbstractArray{T,NP} ,I,S<: SubArray{T,N,A,I} }
412
410
return _dense_dims (S, dense_dims (A), Val (stride_rank (A)))
413
411
end
414
- if VERSION ≥ v " 1.6.0-DEV.1581"
415
- @inline function dense_dims (:: Type{A} ) where {NB, NA, B <: AbstractArray{<:Any,NB} ,A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true} }
416
- ddb = dense_dims (B)
417
- IfElse. ifelse (Static. le (StaticInt (NB), StaticInt (NA)), (True (), ddb... ), Base. tail (ddb))
418
- end
412
+ @inline function dense_dims (:: Type{A} ) where {NB, NA, B <: AbstractArray{<:Any,NB} ,A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true} }
413
+ ddb = dense_dims (B)
414
+ IfElse. ifelse (Static. le (StaticInt (NB), StaticInt (NA)), (True (), ddb... ), Base. tail (ddb))
419
415
end
420
416
421
417
_dense_dims (:: Type{S} , :: Nothing , :: Val{R} ) where {R,N,NP,T,A<: AbstractArray{T,NP} ,I,S<: SubArray{T,N,A,I} } = nothing
@@ -561,70 +557,127 @@ strides(A::StrideIndex) = getfield(A, :strides)
561
557
end
562
558
end
563
559
564
- # Fixes the example of https://github.com/JuliaArrays/ArrayInterface.jl/issues/160
565
- # TODO : Should be generalized to reshaped arrays wrapping more general array types
566
- function strides (A:: ReshapedArray{T,N,P} ) where {T, N, P<: AbstractVector }
567
- if defines_strides (A)
568
- return size_to_strides (size (A), first (strides (parent (A))))
560
+ _is_column_dense (:: A ) where {A<: AbstractArray } =
561
+ defines_strides (A) &&
562
+ (ndims (A) == 0 || Bool (is_dense (A)) && Bool (is_column_major (A)))
563
+
564
+ # Fixes the example of https://github.com/JuliaArrays/ArrayInterfaceCore.jl/issues/160
565
+ function strides (A:: ReshapedArray )
566
+ _is_column_dense (parent (A)) && return size_to_strides (size (A), One ())
567
+ pst = strides (parent (A))
568
+ psz = size (parent (A))
569
+ # Try dimension merging in order (starting from dim1).
570
+ # `sz1` and `st1` are the `size`/`stride` of dim1 after dimension merging.
571
+ # `n` indicates the last merged dimension.
572
+ # note: `st1` should be static if possible
573
+ sz1, st1, n = merge_adjacent_dim (psz, pst)
574
+ n == ndims (A. parent) && return size_to_strides (size (A), st1)
575
+ return _reshaped_strides (size (A), One (), sz1, st1, n, Dims (psz), Dims (pst))
576
+ end
577
+
578
+ @inline function _reshaped_strides (:: Dims{0} , reshaped, msz:: Int , _, :: Int , :: Dims , :: Dims )
579
+ reshaped == msz && return ()
580
+ throw (ArgumentError (" Input is not strided." ))
581
+ end
582
+ function _reshaped_strides (asz:: Dims , reshaped, msz:: Int , mst, n:: Int , apsz:: Dims , apst:: Dims )
583
+ st = reshaped * mst
584
+ reshaped = reshaped * asz[1 ]
585
+ if length (asz) > 1 && reshaped == msz && asz[2 ] != 1
586
+ msz, mst′, n = merge_adjacent_dim (apsz, apst, n + 1 )
587
+ reshaped = 1
588
+ else
589
+ mst′ = Int (mst)
590
+ end
591
+ sts = _reshaped_strides (tail (asz), reshaped, msz, mst′, n, apsz, apst)
592
+ return (st, sts... )
593
+ end
594
+
595
+ merge_adjacent_dim (:: Tuple{} , :: Tuple{} ) = 1 , One (), 0
596
+ merge_adjacent_dim (szs:: Tuple{Any} , sts:: Tuple{Any} ) = Int (szs[1 ]), sts[1 ], 1
597
+ function merge_adjacent_dim (szs:: Tuple , sts:: Tuple )
598
+ if szs[1 ] isa One # Just ignore dimension with size 1
599
+ sz, st, n = merge_adjacent_dim (tail (szs), tail (sts))
600
+ return sz, st, n + 1
601
+ elseif szs[2 ] isa One # Just ignore dimension with size 1
602
+ sz, st, n = merge_adjacent_dim ((szs[1 ], tail (tail (szs))... ), (sts[1 ], tail (tail (sts))... ))
603
+ return sz, st, n + 1
604
+ elseif (szs[1 ], szs[2 ], sts[1 ], sts[2 ]) isa NTuple{4 ,StaticInt} # the check could be done during compiling.
605
+ if sts[2 ] == sts[1 ] * szs[1 ]
606
+ szs′ = (szs[1 ] * szs[2 ], tail (tail (szs))... )
607
+ sts′ = (sts[1 ], tail (tail (sts))... )
608
+ sz, st, n = merge_adjacent_dim (szs′, sts′)
609
+ return sz, st, n + 1
610
+ else
611
+ return Int (szs[1 ]), sts[1 ], 1
612
+ end
613
+ else # the check can't be done during compiling.
614
+ sz, st, n = merge_adjacent_dim (Dims (szs), Dims (sts), 1 )
615
+ if (szs[1 ], sts[1 ]) isa NTuple{2 ,StaticInt} && szs[1 ] != 1
616
+ # But the 1st stride might still be static.
617
+ return sz, sts[1 ], n
618
+ else
619
+ return sz, st, n
620
+ end
621
+ end
622
+ end
623
+
624
+ function merge_adjacent_dim (psz:: Dims{N} , pst:: Dims{N} , n:: Int ) where {N}
625
+ sz, st = psz[n], pst[n]
626
+ while n < N
627
+ szₙ, stₙ = psz[n+ 1 ], pst[n+ 1 ]
628
+ if sz == 1
629
+ sz, st = szₙ, stₙ
630
+ elseif stₙ == st * sz
631
+ sz *= szₙ
632
+ elseif szₙ != 1
633
+ break
634
+ end
635
+ n += 1
636
+ end
637
+ return sz, st, n
638
+ end
639
+
640
+ # `strides` for `Base.ReinterpretArray`
641
+ function strides (A:: Base.ReinterpretArray{T,<:Any,S,<:AbstractArray{S},IsReshaped} ) where {T,S,IsReshaped}
642
+ _is_column_dense (parent (A)) && return size_to_strides (size (A), One ())
643
+ stp = strides (parent (A))
644
+ ET, ES = static (sizeof (T)), static (sizeof (S))
645
+ ET === ES && return stp
646
+ IsReshaped && ET < ES && return (One (), _reinterp_strides (stp, ET, ES)... )
647
+ first (stp) == 1 || throw (ArgumentError (" Parent must be contiguous in the 1st dimension!" ))
648
+ if IsReshaped
649
+ # The wrapper tell us `A`'s parent has static size in dim1.
650
+ # We can make the next stride static if the following dim is still dense.
651
+ sr = stride_rank (parent (A))
652
+ dd = dense_dims (parent (A))
653
+ stp′ = _new_static (stp, sr, dd, ET ÷ ES)
654
+ return _reinterp_strides (tail (stp′), ET, ES)
569
655
else
570
- return Base . strides (A )
656
+ return ( One (), _reinterp_strides ( tail (stp), ET, ES) ... )
571
657
end
572
658
end
573
- function strides (A:: ReshapedArray{T,N,P} ) where {T, N, P}
574
- if defines_strides (A)
575
- return size_to_strides (size (A), static (1 ))
659
+ _new_static (P,_,_,_) = P # This should never be called, just in case.
660
+ @generated function _new_static (p:: P , :: SR , :: DD , :: StaticInt{S} ) where {S,N,P<: NTuple{N,Union{Int,StaticInt}} ,SR<: NTuple{N,StaticInt} ,DD<: NTuple{N,StaticBool} }
661
+ sr = fieldtypes (SR)
662
+ j = findfirst (T -> T () == sr[1 ]()+ 1 , sr)
663
+ if ! isnothing (j) && ! (fieldtype (P, j) <: StaticInt ) && fieldtype (DD, j) === True
664
+ return :(tuple ($ ((i == j ? :(static ($ S)) : :(p[$ i]) for i in 1 : N). .. )))
576
665
else
577
- return Base. strides (A)
578
- end
579
- end
580
-
581
-
582
- @inline bmap (f:: F , t:: Tuple{} , x:: Number ) where {F} = ()
583
- @inline bmap (f:: F , t:: Tuple{T} , x:: Number ) where {F, T} = (f (first (t),x), )
584
- @inline bmap (f:: F , t:: Tuple , x:: Number ) where {F} = (f (first (t),x), bmap (f, Base. tail (t), x)... )
585
- @static if VERSION ≥ v " 1.6.0-DEV.1581"
586
- # from `reinterpret(reshape, ...)`
587
- @inline function strides (A:: Base.ReinterpretArray{R, N, T, B, true} ) where {R,N,T,B}
588
- P = strides (parent (A))
589
- if sizeof (R) == sizeof (T)
590
- P
591
- elseif sizeof (R) > sizeof (T)
592
- x = Base. tail (P)
593
- fx = first (x)
594
- if fx isa Int
595
- (One (), bmap (Base. sdiv_int, Base. tail (x), fx)... )
596
- else
597
- (One (), bmap (÷ , Base. tail (x), fx)... )
598
- end
666
+ return :(p)
667
+ end
668
+ end
669
+ @inline function _reinterp_strides (stp:: Tuple , els:: StaticInt , elp:: StaticInt )
670
+ if elp % els == 0
671
+ N = elp ÷ els
672
+ return map (i -> N * i, stp)
599
673
else
600
- (One (), bmap (* , P, StaticInt (sizeof (T)) ÷ StaticInt (sizeof (R)))... )
601
- end
602
- end
603
- # plain `reinterpret(...)`
604
- @inline function strides (A:: Base.ReinterpretArray{R, N, T, B, false} ) where {R,N,T,B}
605
- P = strides (parent (A))
606
- if sizeof (R) == sizeof (T)
607
- P
608
- elseif sizeof (R) > sizeof (T)
609
- (first (P), bmap (÷ , Base. tail (P), StaticInt (sizeof (R)) ÷ StaticInt (sizeof (T)))... )
610
- else # sizeof(R) < sizeof(T)
611
- (first (P), bmap (* , Base. tail (P), StaticInt (sizeof (T)) ÷ StaticInt (sizeof (R)))... )
612
- end
613
- end
614
- else
615
- # plain `reinterpret(...)`
616
- @inline function strides (A:: Base.ReinterpretArray{R, N, T} ) where {R,N,T}
617
- P = strides (parent (A))
618
- if sizeof (R) == sizeof (T)
619
- P
620
- elseif sizeof (R) > sizeof (T)
621
- (first (P), bmap (÷ , Base. tail (P), StaticInt (sizeof (R)) ÷ StaticInt (sizeof (T)))... )
622
- else # sizeof(R) < sizeof(T)
623
- (first (P), bmap (* , Base. tail (P), StaticInt (sizeof (T)) ÷ StaticInt (sizeof (R)))... )
624
- end
625
- end
626
- end
627
- # @inline strides(A) = _strides(A, Base.strides(A), contiguous_axis(A))
674
+ return map (stp) do i
675
+ d, r = divrem (elp * i, els)
676
+ iszero (r) || throw (ArgumentError (" Parent's strides could not be exactly divided!" ))
677
+ d
678
+ end
679
+ end
680
+ end
628
681
629
682
strides (:: AbstractRange ) = (One (),)
630
683
function strides (x:: VecAdjTrans )
0 commit comments