@@ -604,7 +604,7 @@ struct CPUIndex <: AbstractCPU end
604
604
struct GPU <: AbstractDevice end
605
605
606
606
"""
607
- device(::Type{T})
607
+ device(::Type{T})
608
608
609
609
Indicates the most efficient way to access elements from the collection in low-level code.
610
610
For `GPUArrays`, will return `ArrayInterface.GPU()`.
@@ -615,20 +615,25 @@ Otherwise, returns `nothing`.
615
615
device (A) = device (typeof (A))
616
616
device (:: Type ) = nothing
617
617
device (:: Type{<:Tuple} ) = CPUIndex ()
618
- # Relies on overloading for GPUArrays that have subtyped `StridedArray`.
619
- device (:: Type{<:StridedArray} ) = CPUPointer ()
620
- device (
621
- :: Type{<:SubArray{T,N,A,I}} ,
622
- ) where {T,N,A,I<: Tuple{Vararg{Union{Integer,AbstractRange}}} } = device (A)
623
- device (:: Type{<:SubArray} ) = CPUIndex ()
624
- function device (:: Type{T} ) where {T<: AbstractArray }
625
- P = parent_type (T)
626
- T === P ? CPUIndex () : device (P)
618
+ device (:: Type{T} ) where {T<: Array } = CPUPointer ()
619
+ device (:: Type{T} ) where {T<: AbstractArray } = CPUIndex ()
620
+ device (:: Type{T} ) where {T<: PermutedDimsArray } = device (parent_type (T))
621
+ device (:: Type{T} ) where {T<: Transpose } = device (parent_type (T))
622
+ device (:: Type{T} ) where {T<: Adjoint } = device (parent_type (T))
623
+ device (:: Type{T} ) where {T<: ReinterpretArray } = device (parent_type (T))
624
+ device (:: Type{T} ) where {T<: ReshapedArray } = device (parent_type (T))
625
+ function device (:: Type{T} ) where {T<: SubArray }
626
+ if defines_strides (T)
627
+ return device (parent_type (T))
628
+ else
629
+ return _not_pointer (device (parent_type (T)))
630
+ end
627
631
end
628
-
632
+ _not_pointer (:: CPUPointer ) = CPUIndex ()
633
+ _not_pointer (x) = x
629
634
630
635
"""
631
- defines_strides(::Type{T}) -> Bool
636
+ defines_strides(::Type{T}) -> Bool
632
637
633
638
Is strides(::T) defined?
634
639
"""
@@ -1058,6 +1063,9 @@ function __init__()
1058
1063
stride_rank (parent_type (A))
1059
1064
ArrayInterface. axes (A:: OffsetArrays.OffsetArray ) = Base. axes (A)
1060
1065
ArrayInterface. axes (A:: OffsetArrays.OffsetArray , dim:: Integer ) = Base. axes (A, dim)
1066
+ function ArrayInterface. device (:: Type{T} ) where {T<: OffsetArrays.OffsetArray }
1067
+ return device (parent_type (T))
1068
+ end
1061
1069
end
1062
1070
end
1063
1071
0 commit comments