Skip to content

Commit 70c06b1

Browse files
authored
Diagonal-sandwiched triple product for SparseMatrixCSC (#562)
* Diagonal-sandwiched triple product for SparseMatrixCSC * Specialize for SparseMatrixCSC * Test for non-square matrix
1 parent 313a04f commit 70c06b1

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

src/linalg.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,27 @@ const SparseOrTri{Tv,Ti} = Union{SparseMatrixCSCUnion{Tv,Ti},SparseTriangular{Tv
205205
*(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::SparseOrTri) = spmatmul(copy(A), B)
206206
*(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = spmatmul(copy(A), copy(B))
207207

208+
(*)(Da::Diagonal, A::Union{SparseMatrixCSCUnion, AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}}, Db::Diagonal) = Da * (A * Db)
209+
function (*)(Da::Diagonal, A::SparseMatrixCSC, Db::Diagonal)
210+
(size(Da, 2) == size(A,1) && size(A,2) == size(Db,1)) ||
211+
throw(DimensionMismatch("incompatible sizes"))
212+
T = promote_op(matprod, eltype(Da), promote_op(matprod, eltype(A), eltype(Db)))
213+
dest = similar(A, T)
214+
vals_dest = nonzeros(dest)
215+
rows = rowvals(A)
216+
vals = nonzeros(A)
217+
da, db = map(parent, (Da, Db))
218+
for col in axes(A,2)
219+
dbcol = db[col]
220+
for i in nzrange(A, col)
221+
row = rows[i]
222+
val = vals[i]
223+
vals_dest[i] = da[row] * val * dbcol
224+
end
225+
end
226+
dest
227+
end
228+
208229
# Gustavson's matrix multiplication algorithm revisited.
209230
# The result rowval vector is already sorted by construction.
210231
# The auxiliary Vector{Ti} xb is replaced by a Vector{Bool} of same length.

test/linalg.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,4 +912,24 @@ end
912912
@test sparse(3I, 4, 5) == sparse(1:4, 1:4, 3, 4, 5)
913913
@test sparse(3I, 5, 4) == sparse(1:4, 1:4, 3, 5, 4)
914914
end
915+
916+
@testset "diagonal-sandwiched triple multiplication" begin
917+
S = sprand(4, 6, 0.2)
918+
D1 = Diagonal(axes(S,1))
919+
D2 = Diagonal(axes(S,2) .+ 4)
920+
A = Array(S)
921+
C = D1 * S * D2
922+
@test C isa SparseMatrixCSC
923+
@test C D1 * A * D2
924+
C = D2 * S' * D1
925+
@test C isa SparseMatrixCSC
926+
@test C D2 * A' * D1
927+
C = D1 * view(S, :, :) * D2
928+
@test C isa SparseMatrixCSC
929+
@test C D1 * A * D2
930+
931+
@test_throws DimensionMismatch D2 * S * D2
932+
@test_throws DimensionMismatch D1 * S * D1
933+
end
934+
915935
end

0 commit comments

Comments
 (0)