Skip to content

Commit 6a943d2

Browse files
Tests pass
1 parent 44ac1fb commit 6a943d2

34 files changed

+634
-570
lines changed

Project.toml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,28 @@ uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
33
version = "6.0.25"
44

55
[deps]
6+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
67
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
8+
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
79
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
811
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
912
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1013
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
1114

1215
[compat]
1316
Compat = "4"
1417
SnoopPrecompile = "1"
18+
Requires = "1"
1519
julia = "1.6"
1620

1721
[extensions]
1822
ArrayInterfaceBandedMatricesExt = "BandedMatrices"
1923
ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices"
2024
ArrayInterfaceCUDAExt = "CUDA"
2125
ArrayInterfaceGPUArraysExt = "GPUArraysCore"
22-
ArrayInterfaceOffsetArraysExt = "OffsetArrays"
23-
ArrayInterfaceStaticArraysExt = "StaticArrays"
26+
ArrayInterfaceOffsetArraysExt = ["OffsetArrays","Static"]
27+
ArrayInterfaceStaticArraysExt = ["StaticArrays","Static"]
2428
ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore"
2529
ArrayInterfaceStaticExt = "Static"
2630
ArrayInterfaceTrackerExt = "Tracker"
@@ -33,7 +37,6 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3337
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
3438
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
3539
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
36-
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
3740
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3841
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3942
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
@@ -45,7 +48,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4548
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4649

4750
[targets]
48-
test = ["SafeTestsets", "Pkg", "Test", "Aqua", "IfElse", "Random", "SparseArrays", "SuiteSparse", "Static", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "OffsetArrays", "StaticArrays", "StaticArraysCore", "Tracker"]
51+
test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "Static", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "OffsetArrays", "StaticArrays", "StaticArraysCore", "Tracker"]
4952

5053
[weakdeps]
5154
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"

ext/ArrayInterfaceBandedMatrices.jl renamed to ext/ArrayInterfaceBandedMatricesExt.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,9 @@
11
module ArrayInterfaceBandedMatricesExt
22

33
using ArrayInterface
4+
using ArrayInterface: BandedMatrixIndex
45
using BandedMatrices
56

6-
struct BandedMatrixIndex <: ArrayInterface.MatrixIndex
7-
count::Int
8-
rowsize::Int
9-
colsize::Int
10-
bandinds::Array{Int,1}
11-
bandsizes::Array{Int,1}
12-
isrow::Bool
13-
end
14-
157
Base.firstindex(i::BandedMatrixIndex) = 1
168
Base.lastindex(i::BandedMatrixIndex) = i.count
179
Base.length(i::BandedMatrixIndex) = lastindex(i)

ext/ArrayInterfaceBlockBandedMatrices.jl renamed to ext/ArrayInterfaceBlockBandedMatricesExt.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module ArrayInterfaceBlockBandedMatricesExt
22

33
using ArrayInterface
4-
using ArrayInterfaceBandedMatrices
4+
using ArrayInterface: BandedMatrixIndex
55
using BlockBandedMatrices
66
using BlockBandedMatrices.BlockArrays
77

@@ -15,8 +15,8 @@ Base.firstindex(i::BlockBandedMatrixIndex) = 1
1515
Base.lastindex(i::BlockBandedMatrixIndex) = i.count
1616
Base.length(i::BlockBandedMatrixIndex) = lastindex(i)
1717
function BlockBandedMatrixIndex(nrowblock, ncolblock, rowsizes, colsizes, l, u)
18-
blockrowind = ArrayInterfaceBandedMatrices.BandedMatrixIndex(nrowblock, ncolblock, l, u, true)
19-
blockcolind = ArrayInterfaceBandedMatrices.BandedMatrixIndex(nrowblock, ncolblock, l, u, false)
18+
blockrowind = BandedMatrixIndex(nrowblock, ncolblock, l, u, true)
19+
blockcolind = BandedMatrixIndex(nrowblock, ncolblock, l, u, false)
2020
sortedinds = sort(
2121
[(blockrowind[i], blockcolind[i]) for i = 1:length(blockrowind)],
2222
by=x -> x[1],
@@ -100,7 +100,7 @@ struct BandedBlockBandedMatrixIndex <: ArrayInterface.MatrixIndex
100100
count::Int
101101
refinds::Array{Int,1}
102102
refcoords::Array{Int,1}# storing col or row inds at ref points
103-
reflocalinds::Array{ArrayInterfaceBandedMatrices.BandedMatrixIndex,1}
103+
reflocalinds::Array{BandedMatrixIndex,1}
104104
isrow::Bool
105105
end
106106
Base.firstindex(i::BandedBlockBandedMatrixIndex) = 1
@@ -128,8 +128,8 @@ function BandedBlockBandedMatrixIndex(
128128
lambda,
129129
mu,
130130
)
131-
blockrowind = ArrayInterfaceBandedMatrices.BandedMatrixIndex(nrowblock, ncolblock, l, u, true)
132-
blockcolind = ArrayInterfaceBandedMatrices.BandedMatrixIndex(nrowblock, ncolblock, l, u, false)
131+
blockrowind = BandedMatrixIndex(nrowblock, ncolblock, l, u, true)
132+
blockcolind = BandedMatrixIndex(nrowblock, ncolblock, l, u, false)
133133
sortedinds = sort(
134134
[(blockrowind[i], blockcolind[i]) for i = 1:length(blockrowind)],
135135
by=x -> x[1],
@@ -143,14 +143,14 @@ function BandedBlockBandedMatrixIndex(
143143
refinds = Array{Int,1}()
144144
refrowcoords = Array{Int,1}()
145145
refcolcoords = Array{Int,1}()
146-
reflocalrowinds = Array{ArrayInterfaceBandedMatrices.BandedMatrixIndex,1}()
147-
reflocalcolinds = Array{ArrayInterfaceBandedMatrices.BandedMatrixIndex,1}()
146+
reflocalrowinds = Array{BandedMatrixIndex,1}()
147+
reflocalcolinds = Array{BandedMatrixIndex,1}()
148148
for ind in sortedinds
149149
rowind, colind = ind
150150
localrowind =
151-
ArrayInterfaceBandedMatrices.BandedMatrixIndex(rowsizes[rowind], colsizes[colind], lambda, mu, true)
151+
BandedMatrixIndex(rowsizes[rowind], colsizes[colind], lambda, mu, true)
152152
localcolind =
153-
ArrayInterfaceBandedMatrices.BandedMatrixIndex(rowsizes[rowind], colsizes[colind], lambda, mu, false)
153+
BandedMatrixIndex(rowsizes[rowind], colsizes[colind], lambda, mu, false)
154154
push!(refinds, currenti)
155155
push!(refrowcoords, rowheights[rowind])
156156
push!(refcolcoords, colwidths[colind])
File renamed without changes.
File renamed without changes.

ext/ArrayInterfaceOffsetArrays.jl renamed to ext/ArrayInterfaceOffsetArraysExt.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function ArrayInterface.axes_types(::Type{T}) where {T<:OffsetArrays.OffsetArray
3131
ArrayInterface.parent_type(T)
3232
)
3333
end
34-
ArrayInterface.strides(A::OffsetArray) = ArrayInterface.strides(parent(A))
34+
ArrayInterface.static_strides(A::OffsetArray) = ArrayInterface.static_strides(parent(A))
3535
function ArrayInterface.known_offsets(::Type{A}) where {A<:OffsetArrays.OffsetArray}
3636
ntuple(identity -> nothing, Val(ndims(A)))
3737
end
@@ -42,12 +42,12 @@ end
4242
d = ArrayInterface.to_dims(A, dim)
4343
ArrayInterface.offsets(parent(A), d) + relative_offsets(A, d)
4444
end
45-
@inline function ArrayInterface.axes(A::OffsetArrays.OffsetArray)
46-
map(OffsetArrays.IdOffsetRange, ArrayInterface.axes(parent(A)), relative_offsets(A))
45+
@inline function ArrayInterface.static_axes(A::OffsetArrays.OffsetArray)
46+
map(OffsetArrays.IdOffsetRange, ArrayInterface.static_axes(parent(A)), relative_offsets(A))
4747
end
48-
@inline function ArrayInterface.axes(A::OffsetArrays.OffsetArray, dim)
48+
@inline function ArrayInterface.static_axes(A::OffsetArrays.OffsetArray, dim)
4949
d = ArrayInterface.to_dims(A, dim)
50-
OffsetArrays.IdOffsetRange(ArrayInterface.axes(parent(A), d), relative_offsets(A, d))
50+
OffsetArrays.IdOffsetRange(ArrayInterface.static_axes(parent(A), d), relative_offsets(A, d))
5151
end
5252
function ArrayInterface.stride_rank(T::Type{<:OffsetArray})
5353
ArrayInterface.stride_rank(ArrayInterface.parent_type(T))

ext/ArrayInterfaceStaticArrays.jl renamed to ext/ArrayInterfaceStaticArraysExt.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using ArrayInterface
55
using LinearAlgebra
66
using StaticArrays
77
using Static
8-
import ArrayInterfaceStaticArraysCore
8+
using Static: StaticInt
99

1010
const CanonicalInt = Union{Int,StaticInt}
1111

@@ -20,6 +20,7 @@ function ArrayInterface.known_length(::Type{A}) where {A<:StaticArrays.StaticArr
2020
ArrayInterface.known_length(StaticArrays.Length(A))
2121
end
2222

23+
@inline ArrayInterface.static_length(x::StaticArrays.StaticArray) = Static.maybe_static(ArrayInterface.known_length, Base.length, x)
2324
ArrayInterface.device(::Type{<:StaticArrays.MArray}) = ArrayInterface.CPUPointer()
2425
ArrayInterface.device(::Type{<:StaticArrays.SArray}) = ArrayInterface.CPUTuple()
2526
ArrayInterface.contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}()
@@ -36,15 +37,15 @@ ArrayInterface.defines_strides(::Type{<:StaticArrays.MArray}) = true
3637
@generated function ArrayInterface.axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S}
3738
Tuple{[StaticArrays.SOneTo{s} for s in S.parameters]...}
3839
end
39-
@generated function ArrayInterface.size(A::StaticArrays.StaticArray{S}) where {S}
40+
@generated function ArrayInterface.static_size(A::StaticArrays.StaticArray{S}) where {S}
4041
t = Expr(:tuple)
4142
Sp = S.parameters
4243
for n = 1:length(Sp)
4344
push!(t.args, Expr(:call, Expr(:curly, :StaticInt, Sp[n])))
4445
end
4546
return t
4647
end
47-
@generated function ArrayInterface.strides(A::StaticArrays.StaticArray{S}) where {S}
48+
@generated function ArrayInterface.static_strides(A::StaticArrays.StaticArray{S}) where {S}
4849
t = Expr(:tuple, Expr(:call, Expr(:curly, :StaticInt, 1)))
4950
Sp = S.parameters
5051
x = 1
@@ -54,7 +55,7 @@ end
5455
return t
5556
end
5657
if StaticArrays.SizedArray{Tuple{8,8},Float64,2,2} isa UnionAll
57-
@inline ArrayInterface.strides(B::StaticArrays.SizedArray{S,T,M,N,A}) where {S,T,M,N,A<:SubArray} = ArrayInterface.strides(B.data)
58+
@inline ArrayInterface.static_strides(B::StaticArrays.SizedArray{S,T,M,N,A}) where {S,T,M,N,A<:SubArray} = ArrayInterface.static_strides(B.data)
5859
ArrayInterface.parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N,A}}) where {S,T,M,N,A} = A
5960
else
6061
ArrayInterface.parent_type(::Type{<:StaticArrays.SizedArray{S,T,M,N}}) where {S,T,M,N} = Array{T,N}

ext/ArrayInterfaceStaticExt.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import ArrayInterface: allowed_getindex, allowed_setindex!, aos_to_soa, buffer,
1616

1717
# ArrayIndex subtypes and methods
1818
import ArrayInterface: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex,
19-
TridiagonalIndex
19+
TridiagonalIndex, StrideIndex
2020
# managing immutables
2121
import ArrayInterface: ismutable, can_change_size, can_setindex
2222
# constants
@@ -27,6 +27,11 @@ import ArrayInterface: AbstractDevice, AbstractCPU, CPUPointer, CPUTuple, CheckP
2727

2828
import ArrayInterface: known_first, known_step, known_last
2929

30+
import ArrayInterface: offsets, axes_types, offset1, indices, known_offsets, stride_rank, strides, dense_dims, contiguous_axis, known_length, contiguous_batch_size,
31+
contiguous_axis_indicator, is_column_major, _all_dense, AbstractArray2, dimnames, known_dimnames, known_offset1, known_strides, LazyAxis, lazy_axes, BroadcastAxis, broadcast_axis,
32+
to_axes, find_all_dimnames, to_dims, known_size, deleteat, insert, static_size, is_lazy_conjugate, static_stride, static_strides, has_dimnames, unsafe_reconstruct, static_to_indices,
33+
to_index, unsafe_setindex!, unsafe_getindex, static_axes, to_axis, is_dense, static_getindex
34+
3035
using Static
3136
using Static: Zero, One, nstatic, eq, ne, gt, ge, lt, le, eachop, eachop_tuple,
3237
permute, invariant_permutation, field_type, reduce_tup, find_first_eq,
@@ -52,24 +57,22 @@ _sub1(@nospecialize x) = x - oneunit(x)
5257
Tuple{X.parameters..., Y.parameters...}
5358
end
5459

55-
abstract type AbstractArray2{T, N} <: AbstractArray{T, N} end
56-
57-
Base.size(A::AbstractArray2) = map(Int, ArrayInterface.size(A))
58-
Base.size(A::AbstractArray2, dim) = Int(ArrayInterface.size(A, dim))
60+
Base.size(A::AbstractArray2) = map(Int, ArrayInterface.static_size(A))
61+
Base.size(A::AbstractArray2, dim) = Int(ArrayInterface.static_size(A, dim))
5962

6063
function Base.axes(A::AbstractArray2)
61-
is_forwarding_wrapper(A) && return ArrayInterface.axes(parent(A))
64+
is_forwarding_wrapper(A) && return ArrayInterface.static_axes(parent(A))
6265
throw(ArgumentError("Subtypes of `AbstractArray2` must define an axes method"))
6366
end
6467
function Base.axes(A::AbstractArray2, dim::Union{Symbol, StaticSymbol})
65-
axes(A, to_dims(A, dim))
68+
static_axes(A, to_dims(A, dim))
6669
end
6770

6871
function Base.strides(A::AbstractArray2)
69-
defines_strides(A) && return map(Int, ArrayInterface.strides(A))
72+
defines_strides(A) && return map(Int, ArrayInterface.static_strides(A))
7073
throw(MethodError(Base.strides, (A,)))
7174
end
72-
Base.strides(A::AbstractArray2, dim) = Int(ArrayInterface.strides(A, dim))
75+
Base.strides(A::AbstractArray2, dim) = Int(ArrayInterface.static_strides(A, dim))
7376

7477
function Base.IndexStyle(::Type{T}) where {T <: AbstractArray2}
7578
is_forwarding_wrapper(T) ? IndexStyle(parent_type(T)) : IndexCartesian()
@@ -78,14 +81,14 @@ end
7881
function Base.length(A::AbstractArray2)
7982
len = known_length(A)
8083
if len === nothing
81-
return Int(prod(size(A)))
84+
return Int(prod(static_size(A)))
8285
else
8386
return Int(len)
8487
end
8588
end
8689

87-
@propagate_inbounds Base.getindex(A::AbstractArray2, args...) = getindex(A, args...)
88-
@propagate_inbounds Base.getindex(A::AbstractArray2; kwargs...) = getindex(A; kwargs...)
90+
@propagate_inbounds Base.getindex(A::AbstractArray2, args...) = static_getindex(A, args...)
91+
@propagate_inbounds Base.getindex(A::AbstractArray2; kwargs...) = static_getindex(A; kwargs...)
8992

9093
@propagate_inbounds function Base.setindex!(A::AbstractArray2, val, args...)
9194
return setindex!(A, val, args...)
@@ -102,7 +105,7 @@ end
102105
@inbounds(CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a))))[i])
103106
end
104107
@inline function _to_linear(a, i::Tuple{IntType, Vararg{IntType}})
105-
_strides2int(offsets(a), size_to_strides(size(a), static(1)), i) + static(1)
108+
_strides2int(offsets(a), size_to_strides(static_size(a), static(1)), i) + static(1)
106109
end
107110

108111
"""
@@ -167,7 +170,7 @@ Returns a new instance of `collection` with `item` inserted into at the given `i
167170
"""
168171
Base.@propagate_inbounds function insert(collection, index, item)
169172
@boundscheck checkbounds(collection, index)
170-
ret = similar(collection, length(collection) + 1)
173+
ret = similar(collection, static_length(collection) + 1)
171174
@inbounds for i in firstindex(ret):(index - 1)
172175
ret[i] = collection[i]
173176
end
@@ -213,7 +216,7 @@ Base.@propagate_inbounds function deleteat(collection::Tuple{Vararg{Any, N}},
213216
end
214217

215218
function unsafe_deleteat(src::AbstractVector, index)
216-
dst = similar(src, length(src) - 1)
219+
dst = similar(src, static_length(src) - 1)
217220
@inbounds for i in indices(dst)
218221
if i < index
219222
dst[i] = src[i]
@@ -225,7 +228,7 @@ function unsafe_deleteat(src::AbstractVector, index)
225228
end
226229

227230
@inline function unsafe_deleteat(src::AbstractVector, inds::AbstractVector)
228-
dst = similar(src, length(src) - length(inds))
231+
dst = similar(src, static_length(src) - static_length(inds))
229232
dst_index = firstindex(dst)
230233
@inbounds for src_index in indices(src)
231234
if !in(src_index, inds)
@@ -237,9 +240,9 @@ end
237240
end
238241

239242
@inline function unsafe_deleteat(src::Tuple, inds::AbstractVector)
240-
dst = Vector{eltype(src)}(undef, length(src) - length(inds))
243+
dst = Vector{eltype(src)}(undef, static_length(src) - static_length(inds))
241244
dst_index = firstindex(dst)
242-
@inbounds for src_index in static(1):length(src)
245+
@inbounds for src_index in static(1):static_length(src)
243246
if !in(src_index, inds)
244247
dst[dst_index] = src[src_index]
245248
dst_index += one(dst_index)
@@ -253,7 +256,7 @@ end
253256
@inline function unsafe_deleteat(x::Tuple, i)
254257
if i === one(i)
255258
return tail(x)
256-
elseif i == length(x)
259+
elseif i == static_length(x)
257260
return Base.front(x)
258261
else
259262
return (first(x), unsafe_deleteat(tail(x), i - one(i))...)

ext/ArrayInterfaceTracker.jl renamed to ext/ArrayInterfaceTrackerExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module ArrayInterfaceTracker
1+
module ArrayInterfaceTrackerExt
22

33
using ArrayInterface
44
using Tracker

0 commit comments

Comments
 (0)