Skip to content

Commit bc560cf

Browse files
authored
Merge pull request #123 from SciML/device
Fix `device`
2 parents f957da5 + f7efa58 commit bc560cf

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

src/ArrayInterface.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ struct CPUIndex <: AbstractCPU end
604604
struct GPU <: AbstractDevice end
605605

606606
"""
607-
device(::Type{T})
607+
device(::Type{T})
608608
609609
Indicates the most efficient way to access elements from the collection in low-level code.
610610
For `GPUArrays`, will return `ArrayInterface.GPU()`.
@@ -615,20 +615,25 @@ Otherwise, returns `nothing`.
615615
device(A) = device(typeof(A))
616616
device(::Type) = nothing
617617
device(::Type{<:Tuple}) = CPUIndex()
618-
# Relies on overloading for GPUArrays that have subtyped `StridedArray`.
619-
device(::Type{<:StridedArray}) = CPUPointer()
620-
device(
621-
::Type{<:SubArray{T,N,A,I}},
622-
) where {T,N,A,I<:Tuple{Vararg{Union{Integer,AbstractRange}}}} = device(A)
623-
device(::Type{<:SubArray}) = CPUIndex()
624-
function device(::Type{T}) where {T<:AbstractArray}
625-
P = parent_type(T)
626-
T === P ? CPUIndex() : device(P)
618+
device(::Type{T}) where {T<:Array} = CPUPointer()
619+
device(::Type{T}) where {T<:AbstractArray} = CPUIndex()
620+
device(::Type{T}) where {T<:PermutedDimsArray} = device(parent_type(T))
621+
device(::Type{T}) where {T<:Transpose} = device(parent_type(T))
622+
device(::Type{T}) where {T<:Adjoint} = device(parent_type(T))
623+
device(::Type{T}) where {T<:ReinterpretArray} = device(parent_type(T))
624+
device(::Type{T}) where {T<:ReshapedArray} = device(parent_type(T))
625+
function device(::Type{T}) where {T<:SubArray}
626+
if defines_strides(T)
627+
return device(parent_type(T))
628+
else
629+
return _not_pointer(device(parent_type(T)))
630+
end
627631
end
628-
632+
_not_pointer(::CPUPointer) = CPUIndex()
633+
_not_pointer(x) = x
629634

630635
"""
631-
defines_strides(::Type{T}) -> Bool
636+
defines_strides(::Type{T}) -> Bool
632637
633638
Is strides(::T) defined?
634639
"""
@@ -1058,6 +1063,9 @@ function __init__()
10581063
stride_rank(parent_type(A))
10591064
ArrayInterface.axes(A::OffsetArrays.OffsetArray) = Base.axes(A)
10601065
ArrayInterface.axes(A::OffsetArrays.OffsetArray, dim::Integer) = Base.axes(A, dim)
1066+
function ArrayInterface.device(::Type{T}) where {T<:OffsetArrays.OffsetArray}
1067+
return device(parent_type(T))
1068+
end
10611069
end
10621070
end
10631071

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ struct Wrapper{T,N,P<:AbstractArray{T,N}} <: ArrayInterface.AbstractArray2{T,N}
296296
end
297297
ArrayInterface.parent_type(::Type{<:Wrapper{T,N,P}}) where {T,N,P} = P
298298
Base.parent(x::Wrapper) = x.parent
299+
ArrayInterface.device(::Type{T}) where {T<:Wrapper} = ArrayInterface.device(parent_type(T))
299300

300301
using OffsetArrays
301302
@testset "Memory Layout" begin

0 commit comments

Comments
 (0)