From 69ee70b3c7b01c186fa2d716af958cab0f91fef5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Fri, 11 Oct 2024 20:24:01 +0200 Subject: [PATCH 1/9] added specialized method --- src/linalg.jl | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++ test/linalg.jl | 9 +++++++++ 2 files changed, 64 insertions(+) diff --git a/src/linalg.jl b/src/linalg.jl index 131a21bc..21eaa904 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -629,6 +629,61 @@ function dot(A::AbstractSparseMatrixCSC, B::Union{DenseMatrixUnion,WrapperMatrix return conj(dot(B, A)) end +function dot(x::AbstractSparseVector, Q::Diagonal, y::AbstractVector) + if length(x) != length(y) + throw( + DimensionMismatch("Vector x has a length $(length(x)) but y has a length $(length(y))") + ) + end + d = Q.diag + nzvals = nonzeros(x) + nzinds = nonzeroinds(x) + s = zero(Base.promote_eltype(x, Q, y)) + @inbounds for nzidx in eachindex(nzvals) + s += dot(nzvals[nzidx], d[nzinds[nzidx]], y[nzinds[nzidx]]) + end + return s +end + +function dot(a::AbstractSparseVector, Q::Diagonal, b::AbstractSparseVector) + n = length(a) + if length(b) != n + throw( + DimensionMismatch("Vector a has a length $n but b has a length $(length(b))") + ) + end + anzind = nonzeroinds(a) + bnzind = nonzeroinds(b) + anzval = nonzeros(a) + bnzval = nonzeros(b) + s = zero(Base.promote_eltype(a, Q, b)) + + if isempty(anzind) || isempty(bnzind) + return s + end + + a_idx = 1 + b_idx = 1 + a_idx_last = length(anzind) + b_idx_last = length(bnzind) + + # go through the nonzero indices of a and b simultaneously + @inbounds while a_idx <= a_idx_last && b_idx <= b_idx_last + ia = anzind[a_idx] + ib = bnzind[b_idx] + if ia == ib + s += dot(anzval[a_idx], Q.diag[ia], bnzval[b_idx]) + a_idx += 1 + b_idx += 1 + elseif ia < ib + a_idx += 1 + else + b_idx += 1 + end + end + return s +end + ## triangular sparse handling ## triangular multiplication function LinearAlgebra.generic_trimatmul!(C::StridedVecOrMat, uploc, isunitc, tfun::Function, A::SparseMatrixCSCUnion, B::AbstractVecOrMat) diff --git a/test/linalg.jl b/test/linalg.jl index 45d42d9f..9e6c3876 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -850,6 +850,15 @@ end @test dot(TA,WB) ≈ dot(Matrix(TA), WB) @test dot(TA,WC) ≈ dot(Matrix(TA), WC) end + for M in (A, B, C) + D = Diagonal(M * M') + a = spzeros(size(D, 1)) + a[1:3] = rand(Complex{Float64}, 3) + b = spzeros(size(D, 1)) + b[1:3] = rand(Complex{Float64}, 3) + @test dot(a, D, b) ≈ dot(a, sparse(D), b) + @test dot(b, D, a) ≈ dot(b, sparse(D), a) + end end @test_throws DimensionMismatch dot(sprand(5,5,0.2),sprand(5,6,0.2)) @test_throws DimensionMismatch dot(rand(5,5),sprand(5,6,0.2)) From 8c87396ed1fb4054e3aaaaad01a7fa75560ebcb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Fri, 11 Oct 2024 21:09:14 +0200 Subject: [PATCH 2/9] constructor needs complex too --- test/linalg.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/linalg.jl b/test/linalg.jl index 9e6c3876..0b8c682d 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -852,9 +852,9 @@ end end for M in (A, B, C) D = Diagonal(M * M') - a = spzeros(size(D, 1)) + a = spzeros(Complex{Float64}, size(D, 1)) a[1:3] = rand(Complex{Float64}, 3) - b = spzeros(size(D, 1)) + b = spzeros(Complex{Float64}, size(D, 1)) b[1:3] = rand(Complex{Float64}, 3) @test dot(a, D, b) ≈ dot(a, sparse(D), b) @test dot(b, D, a) ≈ dot(b, sparse(D), a) From 7833f1483fa53305d06b7b55cbfff438bc2e8a21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 13 Oct 2024 11:55:12 +0200 Subject: [PATCH 3/9] more tests --- test/linalg.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/linalg.jl b/test/linalg.jl index 0b8c682d..acbe723b 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -858,6 +858,8 @@ end b[1:3] = rand(Complex{Float64}, 3) @test dot(a, D, b) ≈ dot(a, sparse(D), b) @test dot(b, D, a) ≈ dot(b, sparse(D), a) + @test dot(b, D, a) ≈ dot(b, D, collect(a)) + @test dot(b, D, a) ≈ dot(collect(b), D, a) end end @test_throws DimensionMismatch dot(sprand(5,5,0.2),sprand(5,6,0.2)) From bfcaad63212ffd48461e4b08677f0d8ac201f77f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Tue, 15 Oct 2024 15:36:08 +0200 Subject: [PATCH 4/9] test dimension --- src/linalg.jl | 54 +++++++++++++++++++++++++------------------------- test/linalg.jl | 3 +++ 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 21eaa904..ec7f0627 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -630,12 +630,12 @@ function dot(A::AbstractSparseMatrixCSC, B::Union{DenseMatrixUnion,WrapperMatrix end function dot(x::AbstractSparseVector, Q::Diagonal, y::AbstractVector) - if length(x) != length(y) + d = Q.diag + if length(x) != length(y) || length(y) != length(d) throw( - DimensionMismatch("Vector x has a length $(length(x)) but y has a length $(length(y))") + DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), Q has side dimension $(length(d))") ) end - d = Q.diag nzvals = nonzeros(x) nzinds = nonzeroinds(x) s = zero(Base.promote_eltype(x, Q, y)) @@ -645,40 +645,40 @@ function dot(x::AbstractSparseVector, Q::Diagonal, y::AbstractVector) return s end -function dot(a::AbstractSparseVector, Q::Diagonal, b::AbstractSparseVector) - n = length(a) - if length(b) != n +function dot(x::AbstractSparseVector, Q::Diagonal, y::AbstractSparseVector) + n = length(x) + if length(y) != n || n != size(Q, 1) throw( - DimensionMismatch("Vector a has a length $n but b has a length $(length(b))") + DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), Q has side dimension $(length(d))") ) end - anzind = nonzeroinds(a) - bnzind = nonzeroinds(b) - anzval = nonzeros(a) - bnzval = nonzeros(b) - s = zero(Base.promote_eltype(a, Q, b)) + xnzind = nonzeroinds(x) + ynzind = nonzeroinds(y) + xnzval = nonzeros(x) + ynzval = nonzeros(y) + s = zero(Base.promote_eltype(x, Q, y)) - if isempty(anzind) || isempty(bnzind) + if isempty(xnzind) || isempty(ynzind) return s end - a_idx = 1 - b_idx = 1 - a_idx_last = length(anzind) - b_idx_last = length(bnzind) + x_idx = 1 + y_idx = 1 + x_idx_last = length(xnzind) + y_idx_last = length(ynzind) # go through the nonzero indices of a and b simultaneously - @inbounds while a_idx <= a_idx_last && b_idx <= b_idx_last - ia = anzind[a_idx] - ib = bnzind[b_idx] - if ia == ib - s += dot(anzval[a_idx], Q.diag[ia], bnzval[b_idx]) - a_idx += 1 - b_idx += 1 - elseif ia < ib - a_idx += 1 + @inbounds while x_idx <= x_idx_last && y_idx <= y_idx_last + ix = xnzind[x_idx] + iy = ynzind[y_idx] + if ix == iy + s += dot(xnzval[x_idx], Q.diag[ix], ynzval[y_idx]) + x_idx += 1 + y_idx += 1 + elseif ix < iy + x_idx += 1 else - b_idx += 1 + y_idx += 1 end end return s diff --git a/test/linalg.jl b/test/linalg.jl index acbe723b..bdced63b 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -860,6 +860,9 @@ end @test dot(b, D, a) ≈ dot(b, sparse(D), a) @test dot(b, D, a) ≈ dot(b, D, collect(a)) @test dot(b, D, a) ≈ dot(collect(b), D, a) + @test_throws DimensionMismatch dot(b, D, [a; 1]) + @test_throws DimensionMismatch dot([b; 1], D, a) + @test_throws DimensionMismatch dot([b; 1], D, [a; 1]) end end @test_throws DimensionMismatch dot(sprand(5,5,0.2),sprand(5,6,0.2)) From eb346bcb03c43de484f45703578e4182afe11ec3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Tue, 15 Oct 2024 21:23:23 +0200 Subject: [PATCH 5/9] fix error --- src/linalg.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index ec7f0627..8d53c1dd 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -633,7 +633,7 @@ function dot(x::AbstractSparseVector, Q::Diagonal, y::AbstractVector) d = Q.diag if length(x) != length(y) || length(y) != length(d) throw( - DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), Q has side dimension $(length(d))") + DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), Q has side dimension $(size(Q, 1))") ) end nzvals = nonzeros(x) @@ -649,7 +649,7 @@ function dot(x::AbstractSparseVector, Q::Diagonal, y::AbstractSparseVector) n = length(x) if length(y) != n || n != size(Q, 1) throw( - DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), Q has side dimension $(length(d))") + DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), Q has side dimension $(size(Q, 1))") ) end xnzind = nonzeroinds(x) From a10cb5bb83fc610071a18eff9b12721d79ba42c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Wed, 16 Oct 2024 14:16:24 +0200 Subject: [PATCH 6/9] homogenize notation --- src/linalg.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 8d53c1dd..a7804a1c 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -645,18 +645,18 @@ function dot(x::AbstractSparseVector, Q::Diagonal, y::AbstractVector) return s end -function dot(x::AbstractSparseVector, Q::Diagonal, y::AbstractSparseVector) - n = length(x) - if length(y) != n || n != size(Q, 1) +function dot(x::AbstractSparseVector, D::Diagonal, y::AbstractSparseVector) + d = D.diag + if length(y) != length(x) || length(y) != length(d) throw( - DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), Q has side dimension $(size(Q, 1))") + DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), Q has side dimension $(length(d))") ) end xnzind = nonzeroinds(x) ynzind = nonzeroinds(y) xnzval = nonzeros(x) ynzval = nonzeros(y) - s = zero(Base.promote_eltype(x, Q, y)) + s = zero(Base.promote_eltype(x, D, y)) if isempty(xnzind) || isempty(ynzind) return s @@ -672,7 +672,7 @@ function dot(x::AbstractSparseVector, Q::Diagonal, y::AbstractSparseVector) ix = xnzind[x_idx] iy = ynzind[y_idx] if ix == iy - s += dot(xnzval[x_idx], Q.diag[ix], ynzval[y_idx]) + s += dot(xnzval[x_idx], d[ix], ynzval[y_idx]) x_idx += 1 y_idx += 1 elseif ix < iy From bd2b0a421971b9cfd2f2b699cca8cb3776e44915 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Wed, 16 Oct 2024 14:17:31 +0200 Subject: [PATCH 7/9] generic recursive zero value --- src/linalg.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index a7804a1c..87ae55af 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -656,8 +656,7 @@ function dot(x::AbstractSparseVector, D::Diagonal, y::AbstractSparseVector) ynzind = nonzeroinds(y) xnzval = nonzeros(x) ynzval = nonzeros(y) - s = zero(Base.promote_eltype(x, D, y)) - + s = zero(typeof(dot(first(x), first(D), first(y)))) if isempty(xnzind) || isempty(ynzind) return s end From e6cc8b6feb9abd80f942986eefa4ed494fcd5d2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Wed, 16 Oct 2024 14:19:15 +0200 Subject: [PATCH 8/9] apply changes to other function --- src/linalg.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 87ae55af..5fe15c35 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -629,16 +629,16 @@ function dot(A::AbstractSparseMatrixCSC, B::Union{DenseMatrixUnion,WrapperMatrix return conj(dot(B, A)) end -function dot(x::AbstractSparseVector, Q::Diagonal, y::AbstractVector) - d = Q.diag +function dot(x::AbstractSparseVector, D::Diagonal, y::AbstractVector) + d = D.diag if length(x) != length(y) || length(y) != length(d) throw( - DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), Q has side dimension $(size(Q, 1))") + DimensionMismatch("Vectors and matrix have different dimensions, x has a length $(length(x)), y has a length $(length(y)), D has side dimension $(length(d))") ) end nzvals = nonzeros(x) nzinds = nonzeroinds(x) - s = zero(Base.promote_eltype(x, Q, y)) + s = zero(typeof(dot(first(x), first(D), first(y)))) @inbounds for nzidx in eachindex(nzvals) s += dot(nzvals[nzidx], d[nzinds[nzidx]], y[nzinds[nzidx]]) end From 031c5fdc962e9c59ed474d535e9f50cc86c72f75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 21 Oct 2024 16:34:39 +0200 Subject: [PATCH 9/9] added second vec sparse --- src/linalg.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/linalg.jl b/src/linalg.jl index 5fe15c35..8a3174d0 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -645,6 +645,8 @@ function dot(x::AbstractSparseVector, D::Diagonal, y::AbstractVector) return s end +dot(x::AbstractVector, D::Diagonal, y::AbstractSparseVector) = adjoint(dot(y, D', x)) + function dot(x::AbstractSparseVector, D::Diagonal, y::AbstractSparseVector) d = D.diag if length(y) != length(x) || length(y) != length(d)