Skip to content

Commit a3116b9

Browse files
authored
Restrict AbstractQ multiplication to LinearAlgebra types (#321)
1 parent 31b491e commit a3116b9

File tree

4 files changed

+25
-9
lines changed

4 files changed

+25
-9
lines changed

src/SparseArrays.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ using Base: ReshapedArray, promote_op, setindex_shape_check, to_shape, tail,
99
require_one_based_indexing, promote_eltype
1010
using Base.Order: Forward
1111
using LinearAlgebra
12-
using LinearAlgebra: AdjOrTrans, matprod, AbstractQ
12+
using LinearAlgebra: AdjOrTrans, matprod, AbstractQ, HessenbergQ, QRCompactWYQ, QRPackedQ,
13+
LQPackedQ
1314

1415

1516
import Base: +, -, *, \, /, &, |, xor, ==, zero, @propagate_inbounds
@@ -31,6 +32,10 @@ export AbstractSparseArray, AbstractSparseMatrix, AbstractSparseVector,
3132
sprand, sprandn, spzeros, nnz, permute, findnz, fkeep!, ftranspose!,
3233
sparse_hcat, sparse_vcat, sparse_hvcat
3334

35+
const AdjQType = isdefined(LinearAlgebra, :AdjointQ) ? LinearAlgebra.AdjointQ : Adjoint
36+
37+
const LinAlgLeftQs = Union{HessenbergQ,QRCompactWYQ,QRPackedQ}
38+
3439
# helper function needed in sparsematrix, sparsevector and higherorderfns
3540
# `iszero` and `!iszero` don't guarantee to return a boolean but we need one that does
3641
# to remove the handle the structure of the array.

src/solvers/spqr.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ using LinearAlgebra
88
using LinearAlgebra: AbstractQ, copy_similar
99
using ..LibSuiteSparse: SuiteSparseQR_C
1010

11-
const AdjQType = isdefined(LinearAlgebra, :AdjointQ) ? LinearAlgebra.AdjointQ : Adjoint
1211
const AbstractQType = isdefined(LinearAlgebra, :AdjointQ) ? AbstractQ : AbstractMatrix
1312

1413
# ordering options */
@@ -32,7 +31,7 @@ const ORDERINGS = [ORDERING_FIXED, ORDERING_NATURAL, ORDERING_COLAMD, ORDERING_C
3231
# the best of AMD and METIS. METIS is not tried if it isn't installed.
3332

3433
using ..SparseArrays
35-
using ..SparseArrays: getcolptr, FixedSparseCSC, AbstractSparseMatrixCSC, _unsafe_unfix
34+
using ..SparseArrays: getcolptr, FixedSparseCSC, AbstractSparseMatrixCSC, _unsafe_unfix, AdjQType
3635
using ..CHOLMOD
3736
using ..CHOLMOD: change_stype!, free!
3837

src/sparsematrix.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -469,10 +469,17 @@ function _show_with_braille_patterns(io::IO, S::AbstractSparseMatrixCSCInclAdjoi
469469
foreach(c -> print(io, Char(c)), @view brailleGrid[1:end-1])
470470
end
471471

472-
(*)(Q::AbstractQ, B::AbstractSparseMatrixCSC) = Q * Matrix(B)
473-
(*)(Q::AbstractQ, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = Q * copy(B)
474-
(*)(A::AbstractSparseMatrixCSC, Q::AbstractQ) = Matrix(A) * Q
475-
(*)(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, Q::AbstractQ) = copy(A) * Q
472+
for QT in (:LinAlgLeftQs, :LQPackedQ)
473+
@eval (*)(Q::$QT, B::AbstractSparseMatrixCSC) = Q * Matrix(B)
474+
@eval (*)(Q::$QT, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = Q * copy(B)
475+
@eval (*)(A::AbstractSparseMatrixCSC, Q::$QT) = Matrix(A) * Q
476+
@eval (*)(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, Q::$QT) = copy(A) * Q
477+
478+
@eval (*)(Q::AdjQType{<:Any,<:$QT}, B::AbstractSparseMatrixCSC) = Q * Matrix(B)
479+
@eval (*)(Q::AdjQType{<:Any,<:$QT}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = Q * copy(B)
480+
@eval (*)(A::AbstractSparseMatrixCSC, Q::AdjQType{<:Any,<:$QT}) = Matrix(A) * Q
481+
@eval (*)(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, Q::AdjQType{<:Any,<:$QT}) = copy(A) * Q
482+
end
476483

477484
## Reshape
478485

src/sparsevector.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,8 +1292,13 @@ end
12921292
# zero-preserving functions (z->z, nz->nz)
12931293
-(x::SparseVector) = SparseVector(length(x), copy(nonzeroinds(x)), -nonzeros(x))
12941294

1295-
(*)(Q::AbstractQ, B::AbstractSparseVector) = Q * Vector(B)
1296-
(*)(A::AbstractSparseVector, Q::AbstractQ) = Vector(A) * Q
1295+
for QT in (:LinAlgLeftQs, :LQPackedQ)
1296+
@eval (*)(Q::$QT, B::AbstractSparseVector) = Q * Vector(B)
1297+
@eval (*)(Q::AdjQType{<:Any,<:$QT}, B::AbstractSparseVector) = Q * Vector(B)
1298+
1299+
@eval (*)(A::AbstractSparseVector, Q::$QT) = Vector(A) * Q
1300+
@eval (*)(A::AbstractSparseVector, Q::AdjQType{<:Any,<:$QT}) = Vector(A) * Q
1301+
end
12971302

12981303
# functions f, such that
12991304
# f(x) can be zero or non-zero when x != 0

0 commit comments

Comments
 (0)