Skip to content

Commit c044265

Browse files
authored
Move SparseArrays to an extension (#158)
* Move SparseArrays to an extension * Don't import Base explicitly * Add explicit tests * Bump version to v1.3.0
1 parent a2c0e7d commit c044265

File tree

5 files changed

+61
-7
lines changed

5 files changed

+61
-7
lines changed

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,22 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1010

11+
[weakdeps]
12+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
13+
14+
[extensions]
15+
ArrayLayoutsSparseArraysExt = "SparseArrays"
16+
1117
[compat]
1218
FillArrays = "1.2.1"
1319
julia = "1.6"
1420

1521
[extras]
1622
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1723
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
24+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1825
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1926
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2027

2128
[targets]
22-
test = ["Base64", "Random", "StableRNGs", "Test"]
29+
test = ["Base64", "Random", "StableRNGs", "SparseArrays", "Test"]

ext/ArrayLayoutsSparseArraysExt.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
module ArrayLayoutsSparseArraysExt
2+
3+
using ArrayLayouts
4+
using ArrayLayouts: _copyto!
5+
using SparseArrays
6+
import LinearAlgebra
7+
8+
import Base: copyto!
9+
10+
# ambiguity from sparsematrix.jl
11+
copyto!(dest::LayoutMatrix, src::SparseArrays.AbstractSparseMatrixCSC) =
12+
_copyto!(dest, src)
13+
14+
copyto!(dest::SubArray{<:Any,2,<:LayoutMatrix}, src::SparseArrays.AbstractSparseMatrixCSC) =
15+
_copyto!(dest, src)
16+
17+
@inline LinearAlgebra.dot(a::LayoutArray{<:Number}, b::SparseArrays.SparseVectorUnion{<:Number}) =
18+
ArrayLayouts.dot(a,b)
19+
20+
@inline LinearAlgebra.dot(a::SparseArrays.SparseVectorUnion{<:Number}, b::LayoutArray{<:Number}) =
21+
ArrayLayouts.dot(a,b)
22+
23+
end

src/ArrayLayouts.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module ArrayLayouts
22
using Base: _typed_hcat
3-
using Base, Base.Broadcast, LinearAlgebra, FillArrays, SparseArrays
3+
using Base.Broadcast, LinearAlgebra, FillArrays
44
using LinearAlgebra.BLAS
55

66
using Base: AbstractCartesianIndex, OneTo, oneto, RangeIndex, ReinterpretArray, ReshapedArray,
@@ -275,9 +275,6 @@ copyto!(dest::AbstractMatrix, src::AdjOrTrans{<:Any,<:LayoutArray}) = _copyto!(d
275275
copyto!(dest::SubArray{<:Any,2,<:LayoutArray}, src::AdjOrTrans{<:Any,<:LayoutArray}) = _copyto!(dest, src)
276276
copyto!(dest::SubArray{<:Any,2,<:LayoutMatrix}, src::SubArray{<:Any,2,<:AdjOrTrans{<:Any,<:LayoutArray}}) = _copyto!(dest, src)
277277
copyto!(dest::AbstractMatrix, src::SubArray{<:Any,2,<:AdjOrTrans{<:Any,<:LayoutArray}}) = _copyto!(dest, src)
278-
# ambiguity from sparsematrix.jl
279-
copyto!(dest::LayoutMatrix, src::SparseArrays.AbstractSparseMatrixCSC) = _copyto!(dest, src)
280-
copyto!(dest::SubArray{<:Any,2,<:LayoutMatrix}, src::SparseArrays.AbstractSparseMatrixCSC) = _copyto!(dest, src)
281278
if isdefined(LinearAlgebra, :copymutable_oftype)
282279
LinearAlgebra.copymutable_oftype(A::Union{LayoutArray,Symmetric{<:Any,<:LayoutMatrix},Hermitian{<:Any,<:LayoutMatrix},
283280
UpperOrLowerTriangular{<:Any,<:LayoutMatrix},
@@ -417,4 +414,9 @@ Base.typed_vcat(::Type{T}, A::LayoutVecOrMats, B::LayoutVecOrMats, C::AbstractVe
417414
Base.typed_hcat(::Type{T}, A::LayoutVecOrMats, B::LayoutVecOrMats, C::AbstractVecOrMat...) where T = typed_hcat(T, A, B, C...)
418415
Base.typed_vcat(::Type{T}, A::AbstractVecOrMat, B::LayoutVecOrMats, C::AbstractVecOrMat...) where T = typed_vcat(T, A, B, C...)
419416
Base.typed_hcat(::Type{T}, A::AbstractVecOrMat, B::LayoutVecOrMats, C::AbstractVecOrMat...) where T = typed_hcat(T, A, B, C...)
417+
418+
if !isdefined(Base, :get_extension)
419+
include("../ext/ArrayLayoutsSparseArraysExt.jl")
420+
end
421+
420422
end # module

src/mul.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,6 @@ dot(a, b) = materialize(Dot(a, b))
362362
@inline LinearAlgebra.dot(a::AbstractArray, b::LayoutArray) = dot(a,b)
363363
@inline LinearAlgebra.dot(a::LayoutVector, b::AbstractFill{<:Any,1}) = FillArrays._fill_dot_rev(a,b)
364364
@inline LinearAlgebra.dot(a::AbstractFill{<:Any,1}, b::LayoutVector) = FillArrays._fill_dot(a,b)
365-
@inline LinearAlgebra.dot(a::LayoutArray{<:Number}, b::SparseArrays.SparseVectorUnion{<:Number}) = dot(a,b)
366-
@inline LinearAlgebra.dot(a::SparseArrays.SparseVectorUnion{<:Number}, b::LayoutArray{<:Number}) = dot(a,b)
367365

368366
@inline LinearAlgebra.dot(a::SubArray{<:Any,N,<:LayoutArray}, b::AbstractArray) where N = dot(a,b)
369367
@inline LinearAlgebra.dot(a::SubArray{<:Any,N,<:LayoutArray}, b::LayoutArray) where N = dot(a,b)

test/test_layoutarray.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,29 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
480480
end
481481
end
482482

483+
@testset "sparse" begin
484+
@testset "MyVector" begin
485+
V = MyVector([1:4;])
486+
V2 = MyVector(2*[1:4;])
487+
S = 2*sparse(V)
488+
copyto!(V, S)
489+
@test S == V2
490+
V = MyVector([1:4;])
491+
copyto!(view(V, :), S)
492+
@test S == V2
493+
end
494+
@testset "MyMatrix" begin
495+
M = MyMatrix(reshape([1:4;], 2, 2))
496+
M2 = MyMatrix(reshape(2*[1:4;], 2, 2))
497+
S = 2*sparse(M)
498+
copyto!(M, S)
499+
@test S == M2
500+
M = MyMatrix(reshape([1:4;], 2, 2))
501+
copyto!(view(M, :, :), S)
502+
@test S == M2
503+
end
504+
end
505+
483506
@testset "mul! with subarrays" begin
484507
A = MyMatrix(randn(3,3))
485508
V = view(A, 1:3, 1:3)
@@ -500,6 +523,7 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
500523
@test mul!(copy(B), A, V, 2.0, 3.0) 2A * A + 3B
501524
@test mul!(MyMatrix(copy(B)), A, V, 2.0, 3.0) 2A * A + 3B
502525
@test mul!(copy(x), V, x, 2.0, 3.0) 2A * x + 3x
526+
503527
end
504528
end
505529

0 commit comments

Comments
 (0)