Skip to content

Commit bb5ecc0

Browse files
authored
fast quadratic form for dense matrix, sparse vectors (#640)
1 parent 34ece87 commit bb5ecc0

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

src/linalg.jl

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ function dot(x::AbstractVector{T1}, A::AbstractSparseMatrixCSC{T2}, y::AbstractV
605605
end
606606
return s
607607
end
608-
function dot(x::SparseVector, A::AbstractSparseMatrixCSC, y::SparseVector)
608+
function dot(x::AbstractSparseVector, A::AbstractSparseMatrixCSC, y::AbstractSparseVector)
609609
m, n = size(A)
610610
length(x) == m && n == length(y) ||
611611
throw(DimensionMismatch("x has length $(length(x)), A has size ($m, $n), y has length $(length(y))"))
@@ -723,6 +723,71 @@ function dot(x::AbstractSparseVector, D::Diagonal, y::AbstractSparseVector)
723723
return s
724724
end
725725

726+
function dot(
727+
a::AbstractSparseVector,
728+
Q::Union{DenseMatrixUnion,WrapperMatrixTypes{<:Any,<:DenseMatrixUnion}},
729+
b::AbstractSparseVector,
730+
)
731+
return _dot_quadratic_form(a, Q, b)
732+
end
733+
734+
function dot(
735+
a::AbstractSparseVector,
736+
Q::LinearAlgebra.Transpose{<:Real,<:DenseMatrixUnion},
737+
b::AbstractSparseVector,
738+
)
739+
return _dot_quadratic_form(a, Q, b)
740+
end
741+
742+
function dot(
743+
a::AbstractSparseVector,
744+
Q::LinearAlgebra.Transpose{<:Real,<:WrapperMatrixTypes{<:Real,<:DenseMatrixUnion}},
745+
b::AbstractSparseVector,
746+
)
747+
return _dot_quadratic_form(a, Q, b)
748+
end
749+
750+
function dot(
751+
a::AbstractSparseVector,
752+
Q::LinearAlgebra.RealHermSymComplexHerm{<:Real,<:DenseMatrixUnion},
753+
b::AbstractSparseVector)
754+
return _dot_quadratic_form(a, Q, b)
755+
end
756+
757+
function dot(
758+
a::AbstractSparseVector,
759+
Q::Union{
760+
LinearAlgebra.Hermitian{<:Real,<:DenseMatrixUnion}, LinearAlgebra.Symmetric{<:Real,<:DenseMatrixUnion}
761+
},
762+
b::AbstractSparseVector)
763+
return _dot_quadratic_form(a, Q, b)
764+
end
765+
766+
# actual function implementation called by the method dispatch
767+
function _dot_quadratic_form(a, Q, b)
768+
n = length(a)
769+
m = length(b)
770+
if size(Q) != (n, m)
771+
throw(DimensionMismatch("Matrix has a size $(size(Q)) but vectors have length $n, $m"))
772+
end
773+
anzind = nonzeroinds(a)
774+
bnzind = nonzeroinds(b)
775+
anzval = nonzeros(a)
776+
bnzval = nonzeros(b)
777+
s = zero(Base.promote_eltype(a, Q, b))
778+
if isempty(anzind) || isempty(bnzind)
779+
return s
780+
end
781+
@inbounds for a_idx in eachindex(anzind)
782+
for b_idx in eachindex(bnzind)
783+
ia = anzind[a_idx]
784+
ib = bnzind[b_idx]
785+
s += dot(anzval[a_idx], Q[ia, ib], bnzval[b_idx])
786+
end
787+
end
788+
return s
789+
end
790+
726791
## triangular sparse handling
727792
## triangular multiplication
728793
function LinearAlgebra.generic_trimatmul!(C::StridedVecOrMat, uploc, isunitc, tfun::Function, A::SparseMatrixCSCUnion, B::AbstractVecOrMat)

test/linalg.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -952,6 +952,16 @@ end
952952
y = sprand(ComplexF64, 15, 0.5)
953953
@test dot(x, A, y) dot(Vector(x), A, Vector(y)) (Vector(x)' * Matrix(A)) * Vector(y)
954954
@test dot(x, A, y) dot(x, Av, y)
955+
@test dot(x, collect(A), y) dot(x, A, y)
956+
@test dot(y, collect(A)', x) dot(y, A', x)
957+
@test dot(y, transpose(collect(A)), x) dot(y, transpose(A), x)
958+
@test dot(y, Hermitian(collect(A)' * collect(A)), y) dot(y, Hermitian(A' * A), y)
959+
@test dot(y, Symmetric(collect(A)' * collect(A)), y) dot(y, Symmetric(A' * A), y)
960+
B = BitMatrix(rand(Bool, 10, 15))
961+
@test dot(x, A, y) dot(x, Matrix(A), y)
962+
@test_throws DimensionMismatch dot([x, x], A, y)
963+
@test_throws DimensionMismatch dot(x, A, [y, y])
964+
@test iszero(dot(spzeros(length(x)), A, y))
955965
end
956966

957967
for T in (Float64, ComplexF64, Quaternion{Float64}), trans in (Symmetric, Hermitian), uplo in (:U, :L)

0 commit comments

Comments
 (0)