Skip to content

Commit f10d4da

Browse files
authored
added specialized method for 3-argument dot with diagonal matrix (#565)
1 parent 70c06b1 commit f10d4da

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

src/linalg.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,62 @@ function dot(A::AbstractSparseMatrixCSC, B::Union{DenseMatrixUnion,WrapperMatrix
650650
return conj(dot(B, A))
651651
end
652652

653+
function dot(x::AbstractSparseVector, D::Diagonal, y::AbstractVector)
654+
d = D.diag
655+
if length(x) != length(y) || length(y) != length(d)
656+
throw(
657+
DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), D has side dimension $(length(d))")
658+
)
659+
end
660+
nzvals = nonzeros(x)
661+
nzinds = nonzeroinds(x)
662+
s = zero(typeof(dot(first(x), first(D), first(y))))
663+
@inbounds for nzidx in eachindex(nzvals)
664+
s += dot(nzvals[nzidx], d[nzinds[nzidx]], y[nzinds[nzidx]])
665+
end
666+
return s
667+
end
668+
669+
dot(x::AbstractVector, D::Diagonal, y::AbstractSparseVector) = adjoint(dot(y, D', x))
670+
671+
function dot(x::AbstractSparseVector, D::Diagonal, y::AbstractSparseVector)
672+
d = D.diag
673+
if length(y) != length(x) || length(y) != length(d)
674+
throw(
675+
DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), Q has side dimension $(length(d))")
676+
)
677+
end
678+
xnzind = nonzeroinds(x)
679+
ynzind = nonzeroinds(y)
680+
xnzval = nonzeros(x)
681+
ynzval = nonzeros(y)
682+
s = zero(typeof(dot(first(x), first(D), first(y))))
683+
if isempty(xnzind) || isempty(ynzind)
684+
return s
685+
end
686+
687+
x_idx = 1
688+
y_idx = 1
689+
x_idx_last = length(xnzind)
690+
y_idx_last = length(ynzind)
691+
692+
# go through the nonzero indices of a and b simultaneously
693+
@inbounds while x_idx <= x_idx_last && y_idx <= y_idx_last
694+
ix = xnzind[x_idx]
695+
iy = ynzind[y_idx]
696+
if ix == iy
697+
s += dot(xnzval[x_idx], d[ix], ynzval[y_idx])
698+
x_idx += 1
699+
y_idx += 1
700+
elseif ix < iy
701+
x_idx += 1
702+
else
703+
y_idx += 1
704+
end
705+
end
706+
return s
707+
end
708+
653709
## triangular sparse handling
654710
## triangular multiplication
655711
function LinearAlgebra.generic_trimatmul!(C::StridedVecOrMat, uploc, isunitc, tfun::Function, A::SparseMatrixCSCUnion, B::AbstractVecOrMat)

test/linalg.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,20 @@ end
850850
@test dot(TA,WB) dot(Matrix(TA), WB)
851851
@test dot(TA,WC) dot(Matrix(TA), WC)
852852
end
853+
for M in (A, B, C)
854+
D = Diagonal(M * M')
855+
a = spzeros(Complex{Float64}, size(D, 1))
856+
a[1:3] = rand(Complex{Float64}, 3)
857+
b = spzeros(Complex{Float64}, size(D, 1))
858+
b[1:3] = rand(Complex{Float64}, 3)
859+
@test dot(a, D, b) dot(a, sparse(D), b)
860+
@test dot(b, D, a) dot(b, sparse(D), a)
861+
@test dot(b, D, a) dot(b, D, collect(a))
862+
@test dot(b, D, a) dot(collect(b), D, a)
863+
@test_throws DimensionMismatch dot(b, D, [a; 1])
864+
@test_throws DimensionMismatch dot([b; 1], D, a)
865+
@test_throws DimensionMismatch dot([b; 1], D, [a; 1])
866+
end
853867
end
854868
@test_throws DimensionMismatch dot(sprand(5,5,0.2),sprand(5,6,0.2))
855869
@test_throws DimensionMismatch dot(rand(5,5),sprand(5,6,0.2))

0 commit comments

Comments
 (0)