Skip to content

Commit 34ece87

Browse files
authored
Extend 3-arg dot to generic HermOrSym sparse matrices (#643)
1 parent 095b685 commit 34ece87

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

src/linalg.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

33
using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLowerTriangular,
4-
RealHermSymComplexHerm, checksquare, sym_uplo, wrap
4+
RealHermSymComplexHerm, HermOrSym, checksquare, sym_uplo, wrap
55
using Random: rand!
66

77
const tilebufsize = 10800 # Approximately 32k/3
@@ -1210,6 +1210,9 @@ function nzrangelo(A, i, excl=false)
12101210
@inbounds r2 < r1 || rv[r1] >= i + excl ? r : (searchsortedfirst(view(rv, r1:r2), i + excl) + r1-1):r2
12111211
end
12121212

1213+
dot(x::AbstractVector, A::HermOrSym{<:Any,<:AbstractSparseMatrixCSC}, y::AbstractVector) =
1214+
_dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real, A isa Symmetric ? transpose : adjoint)
1215+
# disambiguation
12131216
dot(x::AbstractVector, A::RealHermSymComplexHerm{<:Real,<:AbstractSparseMatrixCSC}, y::AbstractVector) =
12141217
_dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real, A isa Symmetric ? transpose : adjoint)
12151218
function _dot(x::AbstractVector, A::AbstractSparseMatrixCSC, y::AbstractVector, rangefun::Function, diagop::Function, odiagop::Function)
@@ -1242,9 +1245,12 @@ function _dot(x::AbstractVector, A::AbstractSparseMatrixCSC, y::AbstractVector,
12421245
end
12431246
return r
12441247
end
1245-
dot(x::SparseVector, A::RealHermSymComplexHerm{<:Real,<:AbstractSparseMatrixCSC}, y::SparseVector) =
1246-
_dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real)
1247-
function _dot(x::SparseVector, A::AbstractSparseMatrixCSC, y::SparseVector, rangefun::Function, diagop::Function)
1248+
dot(x::AbstractSparseVector, A::HermOrSym{<:Any,<:AbstractSparseMatrixCSC}, y::AbstractSparseVector) =
1249+
_dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real, A isa Symmetric ? transpose : adjoint)
1250+
# disambiguation
1251+
dot(x::AbstractSparseVector, A::RealHermSymComplexHerm{<:Real,<:AbstractSparseMatrixCSC}, y::AbstractSparseVector) =
1252+
_dot(x, parent(A), y, A.uplo == 'U' ? nzrangeup : nzrangelo, A isa Symmetric ? identity : real, A isa Symmetric ? transpose : adjoint)
1253+
function _dot(x::AbstractSparseVector, A::AbstractSparseMatrixCSC, y::AbstractSparseVector, rangefun::Function, diagop::Function, odiagop::Function)
12481254
m, n = size(A)
12491255
length(x) == m && n == length(y) ||
12501256
throw(DimensionMismatch("x has length $(length(x)), A has size ($m, $n), y has length $(length(y))"))
@@ -1275,7 +1281,7 @@ function _dot(x::SparseVector, A::AbstractSparseMatrixCSC, y::SparseVector, rang
12751281
A_ptr_lo = first(rangefun(A, xi, true))
12761282
A_ptr_hi = last(rangefun(A, xi, true))
12771283
if A_ptr_lo <= A_ptr_hi
1278-
r += dot(xv, _spdot((a, y) -> a'y, A_ptr_lo, A_ptr_hi, Arowval, Anzval,
1284+
r += dot(xv, _spdot((a, y) -> odiagop(a)*y, A_ptr_lo, A_ptr_hi, Arowval, Anzval,
12791285
1, length(ynzind), ynzind, ynzval))
12801286
end
12811287
end
@@ -2241,7 +2247,7 @@ end
22412247
# return F
22422248
# end
22432249
# end
2244-
function factorize(A::LinearAlgebra.RealHermSymComplexHerm{Float64,<:AbstractSparseMatrixCSC})
2250+
function factorize(A::RealHermSymComplexHerm{Float64,<:AbstractSparseMatrixCSC})
22452251
F = cholesky(A; check = false)
22462252
if LinearAlgebra.issuccess(F)
22472253
return F

test/linalg.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -954,12 +954,13 @@ end
954954
@test dot(x, A, y) dot(x, Av, y)
955955
end
956956

957-
for (T, trans) in ((Float64, Symmetric), (ComplexF64, Symmetric), (ComplexF64, Hermitian)), uplo in (:U, :L)
957+
for T in (Float64, ComplexF64, Quaternion{Float64}), trans in (Symmetric, Hermitian), uplo in (:U, :L)
958958
B = sprandn(T, 10, 10, 0.2)
959959
x = sprandn(T, 10, 0.4)
960+
xd = Vector(x)
960961
S = trans(B'B, uplo)
961962
Sd = trans(Matrix(B'B), uplo)
962-
@test dot(x, S, x) dot(x, Sd, x) dot(Vector(x), S, Vector(x)) dot(Vector(x), Sd, Vector(x))
963+
@test dot(x, S, x) dot(x, Sd, x) dot(xd, S, xd) dot(xd, Sd, xd)
963964
end
964965
end
965966

0 commit comments

Comments
 (0)