Skip to content

Commit 470c7f8

Browse files
committed
Add CPUTuple <: AbstractCPU as new device type
1 parent 7ce4be8 commit 470c7f8

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/ArrayInterface.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,7 @@ end
615615
abstract type AbstractDevice end
616616
abstract type AbstractCPU <: AbstractDevice end
617617
struct CPUPointer <: AbstractCPU end
618+
struct CPUTuple <: AbstractCPU end
618619
struct CheckParent end
619620
struct CPUIndex <: AbstractCPU end
620621
struct GPU <: AbstractDevice end
@@ -630,7 +631,7 @@ Otherwise, returns `nothing`.
630631
"""
631632
device(A) = device(typeof(A))
632633
device(::Type) = nothing
633-
device(::Type{<:Tuple}) = CPUIndex()
634+
device(::Type{<:Tuple}) = CPUTuple()
634635
device(::Type{T}) where {T<:Array} = CPUPointer()
635636
device(::Type{T}) where {T<:AbstractArray} = _device(has_parent(T), T)
636637
function _device(::True, ::Type{T}) where {T}
@@ -880,6 +881,7 @@ function __init__()
880881
known_length(::Type{A}) where {A <: StaticArrays.StaticArray} = known_length(StaticArrays.Length(A))
881882

882883
device(::Type{<:StaticArrays.MArray}) = CPUPointer()
884+
device(::Type{<:StaticArrays.SArray}) = CPUTuple()
883885
contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}()
884886
contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}()
885887
function stride_rank(::Type{T}) where {N,T<:StaticArrays.StaticArray{<:Any,<:Any,N}}

0 commit comments

Comments
 (0)