Skip to content

Commit b39bd6a

Browse files
committed
Make SubArrays that aren't StridedArrays be CPUIndex.
1 parent 142345a commit b39bd6a

File tree

3 files changed

+10
-1
lines changed

3 files changed

+10
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "2.14.1"
3+
version = "2.14.2"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/ArrayInterface.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,13 @@ device(::Type) = nothing
567567
device(::Type{<:Tuple}) = CPUIndex()
568568
# Relies on overloading for GPUArrays that have subtyped `StridedArray`.
569569
device(::Type{<:StridedArray}) = CPUPointer()
570+
function device(::Type{T}) where {T <: SubArray}
571+
if T <: StridedArray
572+
device(parent_type(T))
573+
else
574+
CPUIndex()
575+
end
576+
end
570577
function device(::Type{T}) where {T <: AbstractArray}
571578
P = parent_type(T)
572579
T === P ? CPUIndex() : device(P)

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ Base.getindex(::DummyZeros{T}, inds...) where {T} = zero(T)
272272
@test device(PermutedDimsArray(A,(3,1,2))) === ArrayInterface.CPUPointer()
273273
@test device(view(A, 1, :, 2:4)) === ArrayInterface.CPUPointer()
274274
@test device(view(A, 1, :, 2:4)') === ArrayInterface.CPUPointer()
275+
@test device(view(A, 1, :, [2,3,4])) === ArrayInterface.CPUIndex()
276+
@test device(view(A, 1, :, [2,3,4])') === ArrayInterface.CPUIndex()
275277
@test device(@SArray(zeros(2,2,2))) === ArrayInterface.CPUIndex()
276278
@test device(@view(@SArray(zeros(2,2,2))[1,1:2,:])) === ArrayInterface.CPUIndex()
277279
@test device(@MArray(zeros(2,2,2))) === ArrayInterface.CPUPointer()

0 commit comments

Comments
 (0)