diff --git a/src/linalg.jl b/src/linalg.jl index e3207cd4..ad83fc51 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -605,7 +605,7 @@ function dot(x::AbstractVector{T1}, A::AbstractSparseMatrixCSC{T2}, y::AbstractV end return s end -function dot(x::SparseVector, A::AbstractSparseMatrixCSC, y::SparseVector) +function dot(x::AbstractSparseVector, A::AbstractSparseMatrixCSC, y::AbstractSparseVector) m, n = size(A) length(x) == m && n == length(y) || throw(DimensionMismatch("x has length $(length(x)), A has size ($m, $n), y has length $(length(y))")) @@ -723,6 +723,71 @@ function dot(x::AbstractSparseVector, D::Diagonal, y::AbstractSparseVector) return s end +function dot( + a::AbstractSparseVector, + Q::Union{DenseMatrixUnion,WrapperMatrixTypes{<:Any,<:DenseMatrixUnion}}, + b::AbstractSparseVector, +) + return _dot_quadratic_form(a, Q, b) +end + +function dot( + a::AbstractSparseVector, + Q::LinearAlgebra.Transpose{<:Real,<:DenseMatrixUnion}, + b::AbstractSparseVector, +) + return _dot_quadratic_form(a, Q, b) +end + +function dot( + a::AbstractSparseVector, + Q::LinearAlgebra.Transpose{<:Real,<:WrapperMatrixTypes{<:Real,<:DenseMatrixUnion}}, + b::AbstractSparseVector, +) + return _dot_quadratic_form(a, Q, b) +end + +function dot( + a::AbstractSparseVector, + Q::LinearAlgebra.RealHermSymComplexHerm{<:Real,<:DenseMatrixUnion}, + b::AbstractSparseVector) + return _dot_quadratic_form(a, Q, b) +end + +function dot( + a::AbstractSparseVector, + Q::Union{ + LinearAlgebra.Hermitian{<:Real,<:DenseMatrixUnion}, LinearAlgebra.Symmetric{<:Real,<:DenseMatrixUnion} + }, + b::AbstractSparseVector) + return _dot_quadratic_form(a, Q, b) +end + +# actual function implementation called by the method dispatch +function _dot_quadratic_form(a, Q, b) + n = length(a) + m = length(b) + if size(Q) != (n, m) + throw(DimensionMismatch("Matrix has a size $(size(Q)) but vectors have length $n, $m")) + end + anzind = nonzeroinds(a) + bnzind = nonzeroinds(b) + anzval = nonzeros(a) + bnzval = nonzeros(b) + s = zero(Base.promote_eltype(a, Q, b)) + if isempty(anzind) || isempty(bnzind) + return s + end + @inbounds for a_idx in eachindex(anzind) + for b_idx in eachindex(bnzind) + ia = anzind[a_idx] + ib = bnzind[b_idx] + s += dot(anzval[a_idx], Q[ia, ib], bnzval[b_idx]) + end + end + return s +end + ## triangular sparse handling ## triangular multiplication function LinearAlgebra.generic_trimatmul!(C::StridedVecOrMat, uploc, isunitc, tfun::Function, A::SparseMatrixCSCUnion, B::AbstractVecOrMat) diff --git a/test/linalg.jl b/test/linalg.jl index a48d17b6..ce4f4e31 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -952,6 +952,16 @@ end y = sprand(ComplexF64, 15, 0.5) @test dot(x, A, y) ≈ dot(Vector(x), A, Vector(y)) ≈ (Vector(x)' * Matrix(A)) * Vector(y) @test dot(x, A, y) ≈ dot(x, Av, y) + @test dot(x, collect(A), y) ≈ dot(x, A, y) + @test dot(y, collect(A)', x) ≈ dot(y, A', x) + @test dot(y, transpose(collect(A)), x) ≈ dot(y, transpose(A), x) + @test dot(y, Hermitian(collect(A)' * collect(A)), y) ≈ dot(y, Hermitian(A' * A), y) + @test dot(y, Symmetric(collect(A)' * collect(A)), y) ≈ dot(y, Symmetric(A' * A), y) + B = BitMatrix(rand(Bool, 10, 15)) + @test dot(x, A, y) ≈ dot(x, Matrix(A), y) + @test_throws DimensionMismatch dot([x, x], A, y) + @test_throws DimensionMismatch dot(x, A, [y, y]) + @test iszero(dot(spzeros(length(x)), A, y)) end for T in (Float64, ComplexF64, Quaternion{Float64}), trans in (Symmetric, Hermitian), uplo in (:U, :L)