Skip to content

Commit f05b8bd

Browse files
authored
Merge pull request #131 from JuliaArrays/cputuple
Add CPUTuple <: AbstractCPU as new device type
2 parents 7ce4be8 + b5de99d commit f05b8bd

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

src/ArrayInterface.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,7 @@ end
615615
abstract type AbstractDevice end
616616
abstract type AbstractCPU <: AbstractDevice end
617617
struct CPUPointer <: AbstractCPU end
618+
struct CPUTuple <: AbstractCPU end
618619
struct CheckParent end
619620
struct CPUIndex <: AbstractCPU end
620621
struct GPU <: AbstractDevice end
@@ -630,7 +631,7 @@ Otherwise, returns `nothing`.
630631
"""
631632
device(A) = device(typeof(A))
632633
device(::Type) = nothing
633-
device(::Type{<:Tuple}) = CPUIndex()
634+
device(::Type{<:Tuple}) = CPUTuple()
634635
device(::Type{T}) where {T<:Array} = CPUPointer()
635636
device(::Type{T}) where {T<:AbstractArray} = _device(has_parent(T), T)
636637
function _device(::True, ::Type{T}) where {T}
@@ -880,6 +881,7 @@ function __init__()
880881
known_length(::Type{A}) where {A <: StaticArrays.StaticArray} = known_length(StaticArrays.Length(A))
881882

882883
device(::Type{<:StaticArrays.MArray}) = CPUPointer()
884+
device(::Type{<:StaticArrays.SArray}) = CPUTuple()
883885
contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}()
884886
contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}()
885887
function stride_rank(::Type{T}) where {N,T<:StaticArrays.StaticArray{<:Any,<:Any,N}}

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,15 +320,15 @@ using OffsetArrays
320320
@test @inferred(device(A)) === ArrayInterface.CPUPointer()
321321
@test @inferred(device(B)) === ArrayInterface.CPUIndex()
322322
@test @inferred(device(-1:19)) === ArrayInterface.CPUIndex()
323-
@test @inferred(device((1,2,3))) === ArrayInterface.CPUIndex()
323+
@test @inferred(device((1,2,3))) === ArrayInterface.CPUTuple()
324324
@test @inferred(device(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.CPUPointer()
325325
@test @inferred(device(view(A, 1, :, 2:4))) === ArrayInterface.CPUPointer()
326326
@test @inferred(device(view(A, 1, :, 2:4)')) === ArrayInterface.CPUPointer()
327327
@test @inferred(device(OffsetArray(view(PermutedDimsArray(A, (3,1,2)), 1, :, 2:4)', 3, -173))) === ArrayInterface.CPUPointer()
328328
@test @inferred(device(view(OffsetArray(A,2,3,-12), 4, :, -11:-9))) === ArrayInterface.CPUPointer()
329329
@test @inferred(device(view(OffsetArray(A,2,3,-12), 3, :, [-11,-10,-9])')) === ArrayInterface.CPUIndex()
330-
@test @inferred(device(OffsetArray(@SArray(zeros(2,2,2)),-123,29,3231))) === ArrayInterface.CPUIndex()
331-
@test @inferred(device(OffsetArray(@view(@SArray(zeros(2,2,2))[1,1:2,:]),-3,4))) === ArrayInterface.CPUIndex()
330+
@test @inferred(device(OffsetArray(@SArray(zeros(2,2,2)),-123,29,3231))) === ArrayInterface.CPUTuple()
331+
@test @inferred(device(OffsetArray(@view(@SArray(zeros(2,2,2))[1,1:2,:]),-3,4))) === ArrayInterface.CPUTuple()
332332
@test @inferred(device(OffsetArray(@MArray(zeros(2,2,2)),8,-2,-5))) === ArrayInterface.CPUPointer()
333333
@test isnothing(device("Hello, world!"))
334334
@test @inferred(device(DenseWrapper{Int,2,Matrix{Int}})) === ArrayInterface.CPUPointer()

0 commit comments

Comments
 (0)