Skip to content

Commit 4778db2

Browse files
Merge pull request #4 from rasmushenningsson/mul_fallbacks
Better mul! fallbacks
2 parents 5733363 + 6538cb4 commit 4778db2

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

src/ThreadedSparseArrays.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ for f in [:rowvals, :nonzeros, :getcolptr]
5353
@eval SparseArrays.$(f)(A::ThreadedSparseMatrixCSC) = SparseArrays.$(f)(A.A)
5454
end
5555

56+
# For non-threaded implementations, fallback to sparse methods and not generic matmul.
57+
mul!(C::AbstractVector, A::ThreadedSparseMatrixCSC, B::AbstractVector, α::Number, β::Number) = mul!(C, A.A, B, α, β)
58+
mul!(C::AbstractMatrix, A::ThreadedSparseMatrixCSC, B::AbstractMatrix, α::Number, β::Number) = mul!(C, A.A, B, α, β)
59+
mul!(C::AbstractVector, A::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, B::AbstractVector, α::Number, β::Number) = mul!(C, adjoint(A.parent.A), B, α, β)
60+
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, B::AbstractMatrix, α::Number, β::Number) = mul!(C, adjoint(A.parent.A), B, α, β)
61+
mul!(C::AbstractVector, A::Transpose{<:Any,<:ThreadedSparseMatrixCSC}, B::AbstractVector, α::Number, β::Number) = mul!(C, transpose(A.parent.A), B, α, β)
62+
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:ThreadedSparseMatrixCSC}, B::AbstractMatrix, α::Number, β::Number) = mul!(C, transpose(A.parent.A), B, α, β)
63+
5664
function mul!(C::StridedVecOrMat, A::ThreadedSparseMatrixCSC, B::Union{StridedVector,AdjOrTransDenseMatrix}, α::Number, β::Number)
5765
size(A, 2) == size(B, 1) || throw(DimensionMismatch())
5866
size(A, 1) == size(C, 1) || throw(DimensionMismatch())

test/runtests.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,51 @@ using Test
4949
@test norm(ref-out) == 0
5050
end
5151

52+
53+
# These test below are here to ensure we don't hit ambiguity warnings.
54+
# The implementations are not (currently) threaded.
55+
sx = sprand(Bool,n,0.05)
56+
@testset "$(Mat)_L_sparsevec" for Mat in [ThreadedSparseMatrixCSC]
57+
Ct = Mat(C)
58+
59+
out = similar(sx, T, N)
60+
LinearAlgebra.mul!(out, Ct, sx)
61+
ref = C*sx
62+
@test norm(ref-out) == 0
63+
@test typeof(ref)==typeof(out)
64+
end
65+
66+
sx = sprand(Bool,N,0.05)
67+
@testset "$(Mat)_L_$(op)_sparsevec" for op in [adjoint,transpose], Mat in [ThreadedSparseMatrixCSC]
68+
Ct = Mat(C)
69+
70+
out = similar(sx, T, n)
71+
LinearAlgebra.mul!(out, op(Ct), sx)
72+
ref = op(C)*sx
73+
@test norm(ref-out) == 0
74+
@test typeof(ref)==typeof(out)
75+
end
76+
77+
sx = sparse(rand(1:n,10),1:10,true,n,10)
78+
@testset "$(Mat)_L_sparse" for Mat in [ThreadedSparseMatrixCSC]
79+
Ct = Mat(C)
80+
81+
out = similar(sx, T, N, 10)
82+
LinearAlgebra.mul!(out, Ct, sx)
83+
ref = C*sx
84+
@test norm(ref-out) == 0
85+
@test typeof(ref)==typeof(out)
86+
end
87+
88+
sx = sparse(rand(1:N,10),1:10,true,N,10)
89+
@testset "$(Mat)_L_$(op)_sparse" for op in [adjoint,transpose], Mat in [ThreadedSparseMatrixCSC]
90+
Ct = Mat(C)
91+
92+
out = similar(sx, T, n, 10)
93+
LinearAlgebra.mul!(out, op(Ct), sx)
94+
ref = op(C)*sx
95+
@test norm(ref-out) == 0
96+
@test typeof(ref)==typeof(out)
97+
end
98+
5299
end

0 commit comments

Comments
 (0)