Skip to content

Commit 2a36728

Browse files
committed
Specialize for SparseMatrixCSC
1 parent acc046f commit 2a36728

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

src/linalg.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,24 @@ const SparseOrTri{Tv,Ti} = Union{SparseMatrixCSCUnion{Tv,Ti},SparseTriangular{Tv
205205
*(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::SparseOrTri) = spmatmul(copy(A), B)
206206
*(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = spmatmul(copy(A), copy(B))
207207

208-
(*)(Da::Diagonal, A::Union{SparseMatrixCSCUnion, AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}}, Db::Diagonal) =
209-
Da * (A * Db)
208+
(*)(Da::Diagonal, A::Union{SparseMatrixCSCUnion, AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}}, Db::Diagonal) = Da * (A * Db)
209+
function (*)(Da::Diagonal, A::SparseMatrixCSC, Db::Diagonal)
210+
T = promote_op(matprod, eltype(Da), promote_op(matprod, eltype(A), eltype(Db)))
211+
dest = similar(A, T)
212+
vals_dest = nonzeros(dest)
213+
rows = rowvals(A)
214+
vals = nonzeros(A)
215+
da, db = map(parent, (Da, Db))
216+
for col in axes(A,2)
217+
dbcol = db[col]
218+
for i in nzrange(A, col)
219+
row = rows[i]
220+
val = vals[i]
221+
vals_dest[i] = da[row] * val * dbcol
222+
end
223+
end
224+
dest
225+
end
210226

211227
# Gustavson's matrix multiplication algorithm revisited.
212228
# The result rowval vector is already sorted by construction.

test/linalg.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -914,18 +914,19 @@ end
914914
end
915915

916916
@testset "diagonal-sandwiched triple multiplication" begin
917-
D = Diagonal(1:4)
918-
S = sprand(Int, 4, 4, 0.2)
917+
D1 = Diagonal(1:4)
918+
D2 = Diagonal(2:2:8)
919+
S = sprand(4, 4, 0.2)
919920
A = Array(S)
920-
C = D * S * D
921+
C = D1 * S * D2
921922
@test C isa SparseMatrixCSC
922-
@test C D * A * D
923-
C = D * S' * D
923+
@test C D1 * A * D2
924+
C = D1 * S' * D2
924925
@test C isa SparseMatrixCSC
925-
@test C D * A' * D
926-
C = D * view(S, :, :) * D
926+
@test C D1 * A' * D2
927+
C = D1 * view(S, :, :) * D2
927928
@test C isa SparseMatrixCSC
928-
@test C D * A * D
929+
@test C D1 * A * D2
929930
end
930931

931932
end

0 commit comments

Comments
 (0)