Skip to content

Commit 348dad9

Browse files
committed
Merge branch 'main' into export_abstractmatrixcsc
2 parents 4fc922a + f10d4da commit 348dad9

File tree

3 files changed

+112
-1
lines changed

3 files changed

+112
-1
lines changed

src/linalg.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,27 @@ const SparseOrTri{Tv,Ti} = Union{SparseMatrixCSCUnion{Tv,Ti},SparseTriangular{Tv
205205
*(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::SparseOrTri) = spmatmul(copy(A), B)
206206
*(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = spmatmul(copy(A), copy(B))
207207

208+
(*)(Da::Diagonal, A::Union{SparseMatrixCSCUnion, AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}}, Db::Diagonal) = Da * (A * Db)
209+
function (*)(Da::Diagonal, A::SparseMatrixCSC, Db::Diagonal)
210+
(size(Da, 2) == size(A,1) && size(A,2) == size(Db,1)) ||
211+
throw(DimensionMismatch("incompatible sizes"))
212+
T = promote_op(matprod, eltype(Da), promote_op(matprod, eltype(A), eltype(Db)))
213+
dest = similar(A, T)
214+
vals_dest = nonzeros(dest)
215+
rows = rowvals(A)
216+
vals = nonzeros(A)
217+
da, db = map(parent, (Da, Db))
218+
for col in axes(A,2)
219+
dbcol = db[col]
220+
for i in nzrange(A, col)
221+
row = rows[i]
222+
val = vals[i]
223+
vals_dest[i] = da[row] * val * dbcol
224+
end
225+
end
226+
dest
227+
end
228+
208229
# Gustavson's matrix multiplication algorithm revisited.
209230
# The result rowval vector is already sorted by construction.
210231
# The auxiliary Vector{Ti} xb is replaced by a Vector{Bool} of same length.
@@ -629,6 +650,62 @@ function dot(A::AbstractSparseMatrixCSC, B::Union{DenseMatrixUnion,WrapperMatrix
629650
return conj(dot(B, A))
630651
end
631652

653+
function dot(x::AbstractSparseVector, D::Diagonal, y::AbstractVector)
654+
d = D.diag
655+
if length(x) != length(y) || length(y) != length(d)
656+
throw(
657+
DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), D has side dimension $(length(d))")
658+
)
659+
end
660+
nzvals = nonzeros(x)
661+
nzinds = nonzeroinds(x)
662+
s = zero(typeof(dot(first(x), first(D), first(y))))
663+
@inbounds for nzidx in eachindex(nzvals)
664+
s += dot(nzvals[nzidx], d[nzinds[nzidx]], y[nzinds[nzidx]])
665+
end
666+
return s
667+
end
668+
669+
dot(x::AbstractVector, D::Diagonal, y::AbstractSparseVector) = adjoint(dot(y, D', x))
670+
671+
function dot(x::AbstractSparseVector, D::Diagonal, y::AbstractSparseVector)
672+
d = D.diag
673+
if length(y) != length(x) || length(y) != length(d)
674+
throw(
675+
DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), Q has side dimension $(length(d))")
676+
)
677+
end
678+
xnzind = nonzeroinds(x)
679+
ynzind = nonzeroinds(y)
680+
xnzval = nonzeros(x)
681+
ynzval = nonzeros(y)
682+
s = zero(typeof(dot(first(x), first(D), first(y))))
683+
if isempty(xnzind) || isempty(ynzind)
684+
return s
685+
end
686+
687+
x_idx = 1
688+
y_idx = 1
689+
x_idx_last = length(xnzind)
690+
y_idx_last = length(ynzind)
691+
692+
# go through the nonzero indices of a and b simultaneously
693+
@inbounds while x_idx <= x_idx_last && y_idx <= y_idx_last
694+
ix = xnzind[x_idx]
695+
iy = ynzind[y_idx]
696+
if ix == iy
697+
s += dot(xnzval[x_idx], d[ix], ynzval[y_idx])
698+
x_idx += 1
699+
y_idx += 1
700+
elseif ix < iy
701+
x_idx += 1
702+
else
703+
y_idx += 1
704+
end
705+
end
706+
return s
707+
end
708+
632709
## triangular sparse handling
633710
## triangular multiplication
634711
function LinearAlgebra.generic_trimatmul!(C::StridedVecOrMat, uploc, isunitc, tfun::Function, A::SparseMatrixCSCUnion, B::AbstractVecOrMat)

src/solvers/spqr.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ Matrix{T}(Q::QRSparseQ) where {T} = lmul!(Q, Matrix{T}(I, size(Q, 1), min(size(Q
146146

147147
# From SPQR manual p. 6
148148
_default_tol(A::AbstractSparseMatrixCSC) =
149-
20*sum(size(A))*eps(real(eltype(A)))*maximum(norm(view(A, :, i)) for i in 1:size(A, 2))
149+
20*sum(size(A))*eps()*maximum(norm(view(A, :, i)) for i in 1:size(A, 2))
150150

151151
"""
152152
qr(A::SparseMatrixCSC; tol=_default_tol(A), ordering=ORDERING_DEFAULT) -> QRSparse

test/linalg.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,20 @@ end
850850
@test dot(TA,WB) dot(Matrix(TA), WB)
851851
@test dot(TA,WC) dot(Matrix(TA), WC)
852852
end
853+
for M in (A, B, C)
854+
D = Diagonal(M * M')
855+
a = spzeros(Complex{Float64}, size(D, 1))
856+
a[1:3] = rand(Complex{Float64}, 3)
857+
b = spzeros(Complex{Float64}, size(D, 1))
858+
b[1:3] = rand(Complex{Float64}, 3)
859+
@test dot(a, D, b) dot(a, sparse(D), b)
860+
@test dot(b, D, a) dot(b, sparse(D), a)
861+
@test dot(b, D, a) dot(b, D, collect(a))
862+
@test dot(b, D, a) dot(collect(b), D, a)
863+
@test_throws DimensionMismatch dot(b, D, [a; 1])
864+
@test_throws DimensionMismatch dot([b; 1], D, a)
865+
@test_throws DimensionMismatch dot([b; 1], D, [a; 1])
866+
end
853867
end
854868
@test_throws DimensionMismatch dot(sprand(5,5,0.2),sprand(5,6,0.2))
855869
@test_throws DimensionMismatch dot(rand(5,5),sprand(5,6,0.2))
@@ -912,4 +926,24 @@ end
912926
@test sparse(3I, 4, 5) == sparse(1:4, 1:4, 3, 4, 5)
913927
@test sparse(3I, 5, 4) == sparse(1:4, 1:4, 3, 5, 4)
914928
end
929+
930+
@testset "diagonal-sandwiched triple multiplication" begin
931+
S = sprand(4, 6, 0.2)
932+
D1 = Diagonal(axes(S,1))
933+
D2 = Diagonal(axes(S,2) .+ 4)
934+
A = Array(S)
935+
C = D1 * S * D2
936+
@test C isa SparseMatrixCSC
937+
@test C D1 * A * D2
938+
C = D2 * S' * D1
939+
@test C isa SparseMatrixCSC
940+
@test C D2 * A' * D1
941+
C = D1 * view(S, :, :) * D2
942+
@test C isa SparseMatrixCSC
943+
@test C D1 * A * D2
944+
945+
@test_throws DimensionMismatch D2 * S * D2
946+
@test_throws DimensionMismatch D1 * S * D1
947+
end
948+
915949
end

0 commit comments

Comments
 (0)