Skip to content

Commit 59f479f

Browse files
committed
Move AbsractDevice related methods to ArrayInterfaceCore
These don't need `StaticInt` and are semantically related to `AbstractDevice` types in ArrayInterfaceCore: * `defines_strides` * `device` * `stride_preserving_index`
1 parent 6a4dff1 commit 59f479f

File tree

6 files changed

+62
-60
lines changed

6 files changed

+62
-60
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "6.0.21"
3+
version = "6.0.22"
44

55
[deps]
66
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

docs/src/api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
ArrayInterfaceCore.can_avx
99
ArrayInterfaceCore.can_change_size
1010
ArrayInterfaceCore.can_setindex
11+
ArrayInterfaceCore.device
12+
ArrayInterfaceCore.defines_strides
1113
ArrayInterfaceCore.fast_matrix_colors
1214
ArrayInterfaceCore.fast_scalar_indexing
1315
ArrayInterfaceCore.is_forwarding_wrapper
@@ -52,8 +54,6 @@ ArrayInterfaceCore.SetIndex!
5254
ArrayInterface.contiguous_axis
5355
ArrayInterface.contiguous_axis_indicator
5456
ArrayInterface.contiguous_batch_size
55-
ArrayInterface.defines_strides
56-
ArrayInterface.device
5757
ArrayInterface.dimnames
5858
ArrayInterface.has_dimnames
5959
ArrayInterface.has_parent

lib/ArrayInterfaceCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterfaceCore"
22
uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2"
3-
version = "0.1.16"
3+
version = "0.1.17"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,32 @@ struct CheckParent end
489489
struct CPUIndex <: AbstractCPU end
490490
struct GPU <: AbstractDevice end
491491

492+
"""
493+
device(::Type{T}) -> AbstractDevice
494+
495+
Indicates the most efficient way to access elements from the collection in low-level code.
496+
For `GPUArrays`, will return `ArrayInterface.GPU()`.
497+
For `AbstractArray` supporting a `pointer` method, returns `ArrayInterface.CPUPointer()`.
498+
For other `AbstractArray`s and `Tuple`s, returns `ArrayInterface.CPUIndex()`.
499+
Otherwise, returns `nothing`.
500+
"""
501+
device(A) = device(typeof(A))
502+
device(::Type) = nothing
503+
device(::Type{<:Tuple}) = CPUTuple()
504+
device(::Type{T}) where {T<:Array} = CPUPointer()
505+
device(::Type{T}) where {T<:AbstractArray} = _device(parent_type(T), T)
506+
function _device(::Type{P}, ::Type{T}) where {P,T}
507+
if defines_strides(T)
508+
return device(P)
509+
else
510+
return _not_pointer(device(P))
511+
end
512+
end
513+
_not_pointer(::CPUPointer) = CPUIndex()
514+
_not_pointer(x) = x
515+
_device(::Type{T}, ::Type{T}) where {T<:DenseArray} = CPUPointer()
516+
_device(::Type{T}, ::Type{T}) where {T} = CPUIndex()
517+
492518
"""
493519
can_avx(f) -> Bool
494520
@@ -836,4 +862,33 @@ indices_do_not_alias(::Type{Transpose{T,A}}) where {T, A <: AbstractArray{T}} =
836862
indices_do_not_alias(::Type{<:SubArray{<:Any,<:Any,A,I}}) where {
837863
A,I<:Tuple{Vararg{Union{Integer, UnitRange, Base.ReshapedUnitRange, Base.AbstractCartesianIndex}}}} = indices_do_not_alias(A)
838864

865+
"""
866+
defines_strides(::Type{T}) -> Bool
867+
868+
Is strides(::T) defined? It is assumed that types returning `true` also return a valid
869+
pointer on `pointer(::T)`.
870+
"""
871+
defines_strides(x) = defines_strides(typeof(x))
872+
_defines_strides(::Type{T}, ::Type{T}) where {T} = false
873+
_defines_strides(::Type{P}, ::Type{T}) where {P,T} = defines_strides(P)
874+
defines_strides(::Type{T}) where {T} = _defines_strides(parent_type(T), T)
875+
defines_strides(@nospecialize T::Type{<:StridedArray}) = true
876+
defines_strides(@nospecialize T::Type{<:BitArray}) = true
877+
@inline function defines_strides(@nospecialize T::Type{<:SubArray})
878+
stride_preserving_index(fieldtype(T, :indices))
879+
end
880+
881+
#=
882+
stride_preserving_index(::Type{T}) -> Bool
883+
884+
Returns `True` if strides between each element can still be derived when indexing with an
885+
instance of type `T`.
886+
=#
887+
stride_preserving_index(@nospecialize T::Type{<:AbstractRange}) = true
888+
stride_preserving_index(@nospecialize T::Type{<:Number}) = true
889+
@inline function stride_preserving_index(@nospecialize T::Type{<:Tuple})
890+
all(map_tuple_type(stride_preserving_index, T))
891+
end
892+
stride_preserving_index(@nospecialize T::Type) = false
893+
839894
end # module

src/ArrayInterface.jl

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff
66
issingular, isstructured, matrix_colors, restructure, lu_instance,
77
safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type,
88
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo,
9-
map_tuple_type, flatten_tuples, GetIndex, SetIndex!
9+
map_tuple_type, flatten_tuples, GetIndex, SetIndex!, defines_strides,
10+
stride_preserving_index
1011

1112
# ArrayIndex subtypes and methods
1213
import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex
@@ -16,7 +17,7 @@ import ArrayInterfaceCore: ismutable, can_change_size, can_setindex
1617
import ArrayInterfaceCore: MatAdjTrans, VecAdjTrans, UpTri, LoTri
1718
# device pieces
1819
import ArrayInterfaceCore: AbstractDevice, AbstractCPU, CPUPointer, CPUTuple, CheckParent,
19-
CPUIndex, GPU, can_avx
20+
CPUIndex, GPU, can_avx, device
2021

2122
import ArrayInterfaceCore: known_first, known_step, known_last
2223

@@ -109,32 +110,6 @@ has_parent(::Type{T}) where {T} = _has_parent(parent_type(T), T)
109110
_has_parent(::Type{T}, ::Type{T}) where {T} = False()
110111
_has_parent(::Type{T1}, ::Type{T2}) where {T1,T2} = True()
111112

112-
"""
113-
device(::Type{T}) -> AbstractDevice
114-
115-
Indicates the most efficient way to access elements from the collection in low-level code.
116-
For `GPUArrays`, will return `ArrayInterface.GPU()`.
117-
For `AbstractArray` supporting a `pointer` method, returns `ArrayInterface.CPUPointer()`.
118-
For other `AbstractArray`s and `Tuple`s, returns `ArrayInterface.CPUIndex()`.
119-
Otherwise, returns `nothing`.
120-
"""
121-
device(A) = device(typeof(A))
122-
device(::Type) = nothing
123-
device(::Type{<:Tuple}) = CPUTuple()
124-
device(::Type{T}) where {T<:Array} = CPUPointer()
125-
device(::Type{T}) where {T<:AbstractArray} = _device(has_parent(T), T)
126-
function _device(::True, ::Type{T}) where {T}
127-
if defines_strides(T)
128-
return device(parent_type(T))
129-
else
130-
return _not_pointer(device(parent_type(T)))
131-
end
132-
end
133-
_not_pointer(::CPUPointer) = CPUIndex()
134-
_not_pointer(x) = x
135-
_device(::False, ::Type{T}) where {T<:DenseArray} = CPUPointer()
136-
_device(::False, ::Type{T}) where {T} = CPUIndex()
137-
138113
"""
139114
is_lazy_conjugate(::AbstractArray) -> Bool
140115

src/stridelayout.jl

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,4 @@
11

2-
"""
3-
defines_strides(::Type{T}) -> Bool
4-
5-
Is strides(::T) defined? It is assumed that types returning `true` also return a valid
6-
pointer on `pointer(::T)`.
7-
"""
8-
defines_strides(x) = defines_strides(typeof(x))
9-
_defines_strides(::Type{T}, ::Type{T}) where {T} = false
10-
_defines_strides(::Type{P}, ::Type{T}) where {P,T} = defines_strides(P)
11-
defines_strides(::Type{T}) where {T} = _defines_strides(parent_type(T), T)
12-
defines_strides(@nospecialize T::Type{<:StridedArray}) = true
13-
defines_strides(@nospecialize T::Type{<:BitArray}) = true
14-
@inline function defines_strides(@nospecialize T::Type{<:SubArray})
15-
stride_preserving_index(fieldtype(T, :indices))
16-
end
17-
#=
18-
stride_preserving_index(::Type{T}) -> StaticBool
19-
20-
Returns `True` if strides between each element can still be derived when indexing with an
21-
instance of type `T`.
22-
=#
23-
stride_preserving_index(@nospecialize T::Type{<:AbstractRange}) = true
24-
stride_preserving_index(@nospecialize T::Type{<:Number}) = true
25-
@inline function stride_preserving_index(@nospecialize T::Type{<:Tuple})
26-
all(map_tuple_type(stride_preserving_index, T))
27-
end
28-
stride_preserving_index(@nospecialize T::Type) = false
29-
302
"""
313
known_offsets(::Type{T}) -> Tuple
324
known_offsets(::Type{T}, dim) -> Union{Int,Nothing}

0 commit comments

Comments
 (0)