@@ -12,6 +12,9 @@ import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex
12
12
import ArrayInterfaceCore: ismutable, can_change_size, can_setindex, deleteat, insert
13
13
# constants
14
14
import ArrayInterfaceCore: MatAdjTrans, VecAdjTrans, UpTri, LoTri
15
+ #
16
+ import ArrayInterfaceCore: AbstractDevice, AbstractCPU, CPUPointer, CPUTuple, CheckParent,
17
+ CPUIndex, GPU
15
18
16
19
using Static
17
20
using Static: Zero, One, nstatic, eq, ne, gt, ge, lt, le, eachop, eachop_tuple,
71
74
@inline static_last (x) = Static. maybe_static (known_last, last, x)
72
75
@inline static_step (x) = Static. maybe_static (known_step, step, x)
73
76
77
+ """
78
+ device(::Type{T}) -> AbstractDevice
79
+
80
+ Indicates the most efficient way to access elements from the collection in low-level code.
81
+ For `GPUArrays`, will return `ArrayInterfaceCore.GPU()`.
82
+ For `AbstractArray` supporting a `pointer` method, returns `ArrayInterfaceCore.CPUPointer()`.
83
+ For other `AbstractArray`s and `Tuple`s, returns `ArrayInterfaceCore.CPUIndex()`.
84
+ Otherwise, returns `nothing`.
85
+ """
86
+ device (A) = device (typeof (A))
87
+ device (:: Type ) = nothing
88
+ device (:: Type{<:Tuple} ) = CPUTuple ()
89
+ device (:: Type{T} ) where {T<: Array } = CPUPointer ()
90
+ device (:: Type{T} ) where {T<: AbstractArray } = _device (has_parent (T), T)
91
+ function _device (:: True , :: Type{T} ) where {T}
92
+ if defines_strides (T)
93
+ return device (parent_type (T))
94
+ else
95
+ return _not_pointer (device (parent_type (T)))
96
+ end
97
+ end
98
+ _not_pointer (:: CPUPointer ) = CPUIndex ()
99
+ _not_pointer (x) = x
100
+ _device (:: False , :: Type{T} ) where {T<: DenseArray } = CPUPointer ()
101
+ _device (:: False , :: Type{T} ) where {T} = CPUIndex ()
102
+
74
103
include (" array_index.jl" )
75
104
include (" axes.jl" )
76
105
include (" broadcast.jl" )
0 commit comments