Skip to content

Commit b0e4d48

Browse files
fast scalar indexing trait
1 parent d7aa0f0 commit b0e4d48

File tree

3 files changed

+41
-18
lines changed

3 files changed

+41
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "1.0.0"
3+
version = "1.1.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ Returns an array of for the sparsity colors of a matrix type `A`. Also includes
4444
an abstract type `ColoringAlgorithm` for `matrix_colors(A,alg::ColoringAlgorithm)`
4545
of non-structured matrices.
4646

47+
## fast_scalar_indexing(A)
48+
49+
A trait function for whether scalar indexing is fast on a given array type.
50+
4751
## List of things to add
4852

4953
- https://github.com/JuliaLang/julia/issues/22216

src/ArrayInterface.jl

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,23 @@ ismutable(::Type{<:Number}) = false
2424
Query whether a type can use `setindex!`
2525
"""
2626
can_setindex(x) = true
27+
can_setindex(x::AbstractArray) = can_setindex(typeof(x))
28+
29+
"""
30+
fast_scalar_indexing(x)
31+
32+
Query whether an array type has fast scalar indexing
33+
"""
34+
fast_scalar_indexing(x) = true
35+
fast_scalar_indexing(x::AbstractArray) = fast_scalar_indexing(typeof(x))
2736

2837
"""
2938
isstructured(x::DataType)
3039
3140
Query whether a type is a representation of a structured matrix
3241
"""
3342
isstructured(x) = false
43+
isstructured(x::AbstractArray) = isstructured(typeof(x))
3444
isstructured(::Symmetric) = true
3545
isstructured(::Hermitian) = true
3646
isstructured(::UpperTriangular) = true
@@ -45,13 +55,14 @@ isstructured(::Diagonal) = true
4555
4656
determine whether `findstructralnz` accepts the parameter `x`
4757
"""
48-
has_sparsestruct(x)=false
49-
has_sparsestruct(x::AbstractArray)=false
50-
has_sparsestruct(x::SparseMatrixCSC)=true
51-
has_sparsestruct(x::Diagonal)=true
52-
has_sparsestruct(x::Bidiagonal)=true
53-
has_sparsestruct(x::Tridiagonal)=true
54-
has_sparsestruct(x::SymTridiagonal)=true
58+
has_sparsestruct(x) = false
59+
has_sparsestruct(x::AbstractArray) = has_sparsestruct(typeof(x))
60+
has_sparsestruct(x::Type{<:AbstractArray}) = false
61+
has_sparsestruct(x::Type{<:SparseMatrixCSC}) = true
62+
has_sparsestruct(x::Type{<:Diagonal}) = true
63+
has_sparsestruct(x::Type{<:Bidiagonal}) = true
64+
has_sparsestruct(x::Type{<:Tridiagonal}) = true
65+
has_sparsestruct(x::Type{<:SymTridiagonal}) = true
5566

5667
"""
5768
findstructralnz(x::AbstractArray)
@@ -132,7 +143,8 @@ abstract type ColoringAlgorithm end
132143
colors of the matrix.
133144
"""
134145
fast_matrix_colors(A) = false
135-
fast_matrix_colors(A::Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}) = true
146+
fast_matrix_colors(A::AbstractArray) = fast_matrix_colors(typeof(A))
147+
fast_matrix_colors(A::Type{<:Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}}) = true
136148

137149
"""
138150
matrix_colors(A::Union{Array,UpperTriangular,LowerTriangular})
@@ -170,16 +182,23 @@ function __init__()
170182

171183
@require LabelledArrays="2ee39098-c373-598a-b85f-a56591580800" begin
172184
ismutable(::Type{<:LabelledArrays.LArray{T,N,Syms}}) where {T,N,Syms} = ismutable(T)
185+
can_setindex(::Type{<:LabelledArrays.SLArray}) = false
186+
end
187+
188+
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
189+
ismutable(::Type{<:Tracker.TrackedArray}) = false
190+
can_setindex(::Type{<:Tracker.TrackedArray}) = false
191+
fast_scalar_indexing(::Type{<:Tracker.TrackedArray}) = false
173192
end
174193

175-
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin
176-
ismutable(::Type{<:Flux.Tracker.TrackedArray}) = false
177-
can_setindex(::Type{<:Flux.Tracker.TrackedArray}) = false
194+
@require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
195+
fast_scalar_indexing(::Type{<:CuArrays.CuArray}) = false
178196
end
179197

180198
@require BandedMatrices="aae01518-5342-5314-be14-df237901396f" begin
181-
is_structured(::BandedMatrices.BandedMatrix) = true
182-
fast_matrix_colors(::BandedMatrices.BandedMatrix) = true
199+
is_structured(::Type{<:BandedMatrices.BandedMatrix}) = true
200+
fast_matrix_colors(::Type{<:BandedMatrices.BandedMatrix}) = true
201+
183202
function matrix_colors(A::BandedMatrices.BandedMatrix)
184203
u,l=bandwidths(A)
185204
width=u+l+1
@@ -189,10 +208,10 @@ function __init__()
189208
end
190209

191210
@require BlockBandedMatrices="aae01518-5342-5314-be14-df237901396f" begin
192-
is_structured(::BandedMatrices.BlockBandedMatrix) = true
193-
is_structured(::BandedMatrices.BandedBlockBandedMatrix) = true
194-
fast_matrix_colors(::BlockBandedMatrices.BlockBandedMatrix) = true
195-
fast_matrix_colors(::BlockBandedMatrices.BandedBlockBandedMatrix) = true
211+
is_structured(::Type{<:BandedMatrices.BlockBandedMatrix}) = true
212+
is_structured(::Type{<:BandedMatrices.BandedBlockBandedMatrix}) = true
213+
fast_matrix_colors(::Type{<:BlockBandedMatrices.BlockBandedMatrix}) = true
214+
fast_matrix_colors(::Type{<:BlockBandedMatrices.BandedBlockBandedMatrix}) = true
196215

197216
function matrix_colors(A::BlockBandedMatrices.BlockBandedMatrix)
198217
l,u=blockbandwidths(A)

0 commit comments

Comments
 (0)