Skip to content

Commit 8ac6d4e

Browse files
use ArrayInterface
1 parent 081b14e commit 8ac6d4e

File tree

3 files changed

+52
-52
lines changed

3 files changed

+52
-52
lines changed

lib/ArrayInterfaceCUDA/src/ArrayInterfaceCUDA.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,23 @@ using CUDA
66

77
const CanonicalInt = Union{Int,StaticInt}
88

9-
ArrayInterfaceCore.fast_scalar_indexing(::Type{<:CUDA.CuArray}) = false
10-
@inline ArrayInterfaceCore.allowed_getindex(x::CUDA.CuArray, i...) = CUDA.@allowscalar(x[i...])
11-
@inline ArrayInterfaceCore.allowed_setindex!(x::CUDA.CuArray, v, i...) = (CUDA.@allowscalar(x[i...] = v))
9+
ArrayInterface.fast_scalar_indexing(::Type{<:CUDA.CuArray}) = false
10+
@inline ArrayInterface.allowed_getindex(x::CUDA.CuArray, i...) = CUDA.@allowscalar(x[i...])
11+
@inline ArrayInterface.allowed_setindex!(x::CUDA.CuArray, v, i...) = (CUDA.@allowscalar(x[i...] = v))
1212

1313
function Base.setindex(x::CUDA.CuArray, v, i::Int)
1414
_x = copy(x)
15-
ArrayInterfaceCore.allowed_setindex!(_x, v, i)
15+
ArrayInterface.allowed_setindex!(_x, v, i)
1616
return _x
1717
end
1818

19-
function ArrayInterfaceCore.restructure(x::CUDA.CuArray, y)
20-
reshape(Adapt.adapt(ArrayInterfaceCore.parameterless_type(x), y), Base.size(x)...)
19+
function ArrayInterface.restructure(x::CUDA.CuArray, y)
20+
reshape(Adapt.adapt(ArrayInterface.parameterless_type(x), y), Base.size(x)...)
2121
end
2222

23-
ArrayInterfaceCore.device(::Type{<:CUDA.CuArray}) = ArrayInterfaceCore.GPU()
23+
ArrayInterface.device(::Type{<:CUDA.CuArray}) = ArrayInterface.GPU()
2424

25-
function ArrayInterfaceCore.lu_instance(A::CuMatrix{T}) where {T}
25+
function ArrayInterface.lu_instance(A::CuMatrix{T}) where {T}
2626
CUDA.CUSOLVER.CuQR(similar(A, 0, 0), similar(A, 0))
2727
end
2828

lib/ArrayInterfaceOffsetArrays/src/ArrayInterfaceOffsetArrays.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,29 +22,29 @@ function relative_offsets(A::OffsetArrays.OffsetArray, dim::Int)
2222
return getfield(relative_offsets(A), dim)
2323
end
2424
end
25-
ArrayInterfaceCore.parent_type(::Type{<:OffsetArrays.OffsetArray{T,N,A}}) where {T,N,A} = A
25+
ArrayInterface.parent_type(::Type{<:OffsetArrays.OffsetArray{T,N,A}}) where {T,N,A} = A
2626
function _offset_axis_type(::Type{T}, dim::StaticInt{D}) where {T,D}
27-
OffsetArrays.IdOffsetRange{Int,ArrayInterfaceCore.axes_types(T, dim)}
27+
OffsetArrays.IdOffsetRange{Int,ArrayInterface.axes_types(T, dim)}
2828
end
29-
function ArrayInterfaceCore.axes_types(::Type{T}) where {T<:OffsetArrays.OffsetArray}
30-
Static.eachop_tuple(_offset_axis_type, Static.nstatic(Val(ndims(T))), ArrayInterfaceCore.parent_type(T))
29+
function ArrayInterface.axes_types(::Type{T}) where {T<:OffsetArrays.OffsetArray}
30+
Static.eachop_tuple(_offset_axis_type, Static.nstatic(Val(ndims(T))), ArrayInterface.parent_type(T))
3131
end
32-
function ArrayInterfaceCore.known_offsets(::Type{A}) where {A<:OffsetArrays.OffsetArray}
32+
function ArrayInterface.known_offsets(::Type{A}) where {A<:OffsetArrays.OffsetArray}
3333
ntuple(identity -> nothing, Val(ndims(A)))
3434
end
35-
function ArrayInterfaceCore.offsets(A::OffsetArrays.OffsetArray)
36-
map(+, ArrayInterfaceCore.offsets(parent(A)), relative_offsets(A))
35+
function ArrayInterface.offsets(A::OffsetArrays.OffsetArray)
36+
map(+, ArrayInterface.offsets(parent(A)), relative_offsets(A))
3737
end
38-
@inline function ArrayInterfaceCore.offsets(A::OffsetArrays.OffsetArray, dim)
39-
d = ArrayInterfaceCore.to_dims(A, dim)
40-
ArrayInterfaceCore.offsets(parent(A), d) + relative_offsets(A, d)
38+
@inline function ArrayInterface.offsets(A::OffsetArrays.OffsetArray, dim)
39+
d = ArrayInterface.to_dims(A, dim)
40+
ArrayInterface.offsets(parent(A), d) + relative_offsets(A, d)
4141
end
42-
@inline function ArrayInterfaceCore.axes(A::OffsetArrays.OffsetArray)
43-
map(OffsetArrays.IdOffsetRange, ArrayInterfaceCore.axes(parent(A)), relative_offsets(A))
42+
@inline function ArrayInterface.axes(A::OffsetArrays.OffsetArray)
43+
map(OffsetArrays.IdOffsetRange, ArrayInterface.axes(parent(A)), relative_offsets(A))
4444
end
45-
@inline function ArrayInterfaceCore.axes(A::OffsetArrays.OffsetArray, dim)
46-
d = ArrayInterfaceCore.to_dims(A, dim)
47-
OffsetArrays.IdOffsetRange(ArrayInterfaceCore.axes(parent(A), d), relative_offsets(A, d))
45+
@inline function ArrayInterface.axes(A::OffsetArrays.OffsetArray, dim)
46+
d = ArrayInterface.to_dims(A, dim)
47+
OffsetArrays.IdOffsetRange(ArrayInterface.axes(parent(A), d), relative_offsets(A, d))
4848
end
4949

5050
end # module

lib/ArrayInterfaceStaticArrays/src/ArrayInterfaceStaticArrays.jl

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,57 +8,57 @@ using Static
88

99
const CanonicalInt = Union{Int,StaticInt}
1010

11-
ArrayInterfaceCore.ismutable(::Type{<:StaticArrays.StaticArray}) = false
12-
ArrayInterfaceCore.ismutable(::Type{<:StaticArrays.MArray}) = true
13-
ArrayInterfaceCore.ismutable(::Type{<:StaticArrays.SizedArray}) = true
11+
ArrayInterface.ismutable(::Type{<:StaticArrays.StaticArray}) = false
12+
ArrayInterface.ismutable(::Type{<:StaticArrays.MArray}) = true
13+
ArrayInterface.ismutable(::Type{<:StaticArrays.SizedArray}) = true
1414

15-
ArrayInterfaceCore.can_setindex(::Type{<:StaticArrays.StaticArray}) = false
16-
ArrayInterfaceCore.buffer(A::Union{StaticArrays.SArray,StaticArrays.MArray}) = getfield(A, :data)
15+
ArrayInterface.can_setindex(::Type{<:StaticArrays.StaticArray}) = false
16+
ArrayInterface.buffer(A::Union{StaticArrays.SArray,StaticArrays.MArray}) = getfield(A, :data)
1717

18-
function ArrayInterfaceCore.lu_instance(_A::StaticArrays.StaticMatrix{N,N}) where {N}
18+
function ArrayInterface.lu_instance(_A::StaticArrays.StaticMatrix{N,N}) where {N}
1919
A = StaticArrays.SArray(_A)
2020
L = LowerTriangular(A)
2121
U = UpperTriangular(A)
2222
p = StaticArrays.SVector{N,Int}(1:N)
2323
return StaticArrays.LU(L, U, p)
2424
end
2525

26-
function ArrayInterfaceCore.restructure(x::StaticArrays.SArray, y::StaticArrays.SArray)
26+
function ArrayInterface.restructure(x::StaticArrays.SArray, y::StaticArrays.SArray)
2727
reshape(y, StaticArrays.Size(x))
2828
end
29-
ArrayInterfaceCore.restructure(x::StaticArrays.SArray{S}, y) where {S} = StaticArrays.SArray{S}(y)
29+
ArrayInterface.restructure(x::StaticArrays.SArray{S}, y) where {S} = StaticArrays.SArray{S}(y)
3030

31-
ArrayInterfaceCore.known_first(::Type{<:StaticArrays.SOneTo}) = 1
32-
ArrayInterfaceCore.known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N
33-
ArrayInterfaceCore.known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N
34-
ArrayInterfaceCore.known_length(::Type{StaticArrays.Length{L}}) where {L} = L
35-
function ArrayInterfaceCore.known_length(::Type{A}) where {A<:StaticArrays.StaticArray}
36-
ArrayInterfaceCore.known_length(StaticArrays.Length(A))
31+
ArrayInterface.known_first(::Type{<:StaticArrays.SOneTo}) = 1
32+
ArrayInterface.known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N
33+
ArrayInterface.known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N
34+
ArrayInterface.known_length(::Type{StaticArrays.Length{L}}) where {L} = L
35+
function ArrayInterface.known_length(::Type{A}) where {A<:StaticArrays.StaticArray}
36+
ArrayInterface.known_length(StaticArrays.Length(A))
3737
end
3838

39-
ArrayInterfaceCore.device(::Type{<:StaticArrays.MArray}) = ArrayInterfaceCore.CPUPointer()
40-
ArrayInterfaceCore.device(::Type{<:StaticArrays.SArray}) = ArrayInterfaceCore.CPUTuple()
41-
ArrayInterfaceCore.contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}()
42-
ArrayInterfaceCore.contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}()
43-
ArrayInterfaceCore.stride_rank(::Type{T}) where {N,T<:StaticArray{<:Any,<:Any,N}} = Static.nstatic(Val(N))
44-
function ArrayInterfaceCore.dense_dims(::Type{<:StaticArray{S,T,N}}) where {S,T,N}
45-
ArrayInterfaceCore._all_dense(Val(N))
39+
ArrayInterface.device(::Type{<:StaticArrays.MArray}) = ArrayInterface.CPUPointer()
40+
ArrayInterface.device(::Type{<:StaticArrays.SArray}) = ArrayInterface.CPUTuple()
41+
ArrayInterface.contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}()
42+
ArrayInterface.contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}()
43+
ArrayInterface.stride_rank(::Type{T}) where {N,T<:StaticArray{<:Any,<:Any,N}} = Static.nstatic(Val(N))
44+
function ArrayInterface.dense_dims(::Type{<:StaticArray{S,T,N}}) where {S,T,N}
45+
ArrayInterface._all_dense(Val(N))
4646
end
47-
ArrayInterfaceCore.defines_strides(::Type{<:StaticArrays.SArray}) = true
48-
ArrayInterfaceCore.defines_strides(::Type{<:StaticArrays.MArray}) = true
47+
ArrayInterface.defines_strides(::Type{<:StaticArrays.SArray}) = true
48+
ArrayInterface.defines_strides(::Type{<:StaticArrays.MArray}) = true
4949

50-
@generated function ArrayInterfaceCore.axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S}
50+
@generated function ArrayInterface.axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S}
5151
Tuple{[StaticArrays.SOneTo{s} for s in S.parameters]...}
5252
end
53-
@generated function ArrayInterfaceCore.size(A::StaticArrays.StaticArray{S}) where {S}
53+
@generated function ArrayInterface.size(A::StaticArrays.StaticArray{S}) where {S}
5454
t = Expr(:tuple)
5555
Sp = S.parameters
5656
for n = 1:length(Sp)
5757
push!(t.args, Expr(:call, Expr(:curly, :StaticInt, Sp[n])))
5858
end
5959
return t
6060
end
61-
@generated function ArrayInterfaceCore.strides(A::StaticArrays.StaticArray{S}) where {S}
61+
@generated function ArrayInterface.strides(A::StaticArrays.StaticArray{S}) where {S}
6262
t = Expr(:tuple, Expr(:call, Expr(:curly, :StaticInt, 1)))
6363
Sp = S.parameters
6464
x = 1
@@ -68,10 +68,10 @@ end
6868
return t
6969
end
7070
if StaticArrays.SizedArray{Tuple{8,8},Float64,2,2} isa UnionAll
71-
@inline ArrayInterfaceCore.strides(B::StaticArrays.SizedArray{S,T,M,N,A}) where {S,T,M,N,A<:SubArray} = ArrayInterfaceCore.strides(B.data)
72-
ArrayInterfaceCore.parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N,A}}) where {S,T,M,N,A} = A
71+
@inline ArrayInterface.strides(B::StaticArrays.SizedArray{S,T,M,N,A}) where {S,T,M,N,A<:SubArray} = ArrayInterface.strides(B.data)
72+
ArrayInterface.parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N,A}}) where {S,T,M,N,A} = A
7373
else
74-
ArrayInterfaceCore.parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N}}) where {S,T,M,N} = Array{T,N}
74+
ArrayInterface.parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N}}) where {S,T,M,N} = Array{T,N}
7575
end
7676

7777
Adapt.adapt_storage(::Type{<:StaticArrays.SArray{S}}, xs::Array) where {S} = SArray{S}(xs)

0 commit comments

Comments
 (0)