Skip to content

Commit 2c8b8b1

Browse files
authored
Clean up some multiplication code (#388)
1 parent 7df375c commit 2c8b8b1

File tree

8 files changed

+342
-329
lines changed

8 files changed

+342
-329
lines changed

src/SparseArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using Base: ReshapedArray, promote_op, setindex_shape_check, to_shape, tail,
1010
using Base.Order: Forward
1111
using LinearAlgebra
1212
using LinearAlgebra: AdjOrTrans, AdjointFactorization, TransposeFactorization, matprod,
13-
AbstractQ, AdjointQ, HessenbergQ, QRCompactWYQ, QRPackedQ, LQPackedQ,
13+
AbstractQ, AdjointQ, HessenbergQ, QRCompactWYQ, QRPackedQ, LQPackedQ, MulAddMul,
1414
UpperOrLowerTriangular
1515

1616

src/linalg.jl

Lines changed: 218 additions & 240 deletions
Large diffs are not rendered by default.

src/sparseconvert.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ _sparsem(A::AbstractSparseMatrix) = A
9595
_sparsem(A::AbstractSparseVector) = A
9696

9797
# Transpose/Adjoint of sparse vector (returning sparse matrix)
98-
function _sparsem(A::Union{Transpose{<:Any,<:AbstractSparseVector},Adjoint{<:Any,<:AbstractSparseVector}})
98+
function _sparsem(A::AdjOrTrans{<:Any,<:AbstractSparseVector})
9999
B = parent(A)
100100
n = length(B)
101101
Ti = eltype(nonzeroinds(B))
@@ -217,8 +217,7 @@ function _sparsem(A::AbstractTriangularSparse{Tv}) where Tv
217217
end
218218

219219
# 8 cases: (Transpose|Adjoint){Tv,[Unit](Upper|Lower)Triangular}
220-
function _sparsem(taA::Union{Transpose{Tv,<:AbstractTriangularSparse},
221-
Adjoint{Tv,<:AbstractTriangularSparse}}) where {Tv}
220+
function _sparsem(taA::AdjOrTrans{Tv,<:AbstractTriangularSparse}) where {Tv}
222221

223222
sA = taA.parent
224223
A = sA.data

src/sparsematrix.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,7 @@ julia> nnz(A)
209209
"""
210210
nnz(S::AbstractSparseMatrixCSC) = Int(getcolptr(S)[size(S, 2) + 1]) - 1
211211
nnz(S::ReshapedArray{<:Any,1,<:AbstractSparseMatrixCSC}) = nnz(parent(S))
212-
nnz(S::Adjoint{<:Any,<:AbstractSparseMatrixCSC}) = nnz(parent(S))
213-
nnz(S::Transpose{<:Any,<:AbstractSparseMatrixCSC}) = nnz(parent(S))
212+
nnz(S::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = nnz(parent(S))
214213
nnz(S::UpperTriangular{<:Any,<:AbstractSparseMatrixCSC}) = nnz1(S)
215214
nnz(S::LowerTriangular{<:Any,<:AbstractSparseMatrixCSC}) = nnz1(S)
216215
nnz(S::SparseMatrixCSCView) = nnz1(S)

src/sparsevector.jl

Lines changed: 71 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import Base: sort!, findall, copy!
66
import LinearAlgebra: promote_to_array_type, promote_to_arrays_
7+
using LinearAlgebra: adj_or_trans
78

89
### The SparseVector
910

@@ -1768,23 +1769,52 @@ end
17681769

17691770
const _StridedOrTriangularMatrix{T} = Union{StridedMatrix{T}, LowerTriangular{T}, UnitLowerTriangular{T}, UpperTriangular{T}, UnitUpperTriangular{T}}
17701771

1772+
_fliptri(A::UpperTriangular) = LowerTriangular(parent(parent(A)))
1773+
_fliptri(A::UnitUpperTriangular) = UnitLowerTriangular(parent(parent(A)))
1774+
_fliptri(A::LowerTriangular) = UpperTriangular(parent(parent(A)))
1775+
_fliptri(A::UnitLowerTriangular) = UnitUpperTriangular(parent(parent(A)))
1776+
17711777
function (*)(A::_StridedOrTriangularMatrix{Ta}, x::AbstractSparseVector{Tx}) where {Ta,Tx}
17721778
require_one_based_indexing(A, x)
17731779
m, n = size(A)
17741780
length(x) == n || throw(DimensionMismatch())
17751781
Ty = promote_op(matprod, eltype(A), eltype(x))
17761782
y = Vector{Ty}(undef, m)
1777-
mul!(y, A, x)
1783+
mul!(y, A, x, true, false)
17781784
end
17791785

1780-
function mul!(y::AbstractVector, A::_StridedOrTriangularMatrix, x::AbstractSparseVector, α::Number, β::Number)
1786+
function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::StridedMatrix, x::AbstractSparseVector,
1787+
_add::MulAddMul = MulAddMul())
1788+
if tA == 'N'
1789+
_spmul!(y, A, x, _add.alpha, _add.beta)
1790+
elseif tA == 'T'
1791+
_At_or_Ac_mul_B!(transpose, y, A, x, _add.alpha, _add.beta)
1792+
elseif tA == 'C'
1793+
_At_or_Ac_mul_B!(adjoint, y, A, x, _add.alpha, _add.beta)
1794+
else
1795+
_spmul!(y, LinearAlgebra.wrap(A, tA), x, _add.alpha, _add.beta)
1796+
end
1797+
return y
1798+
end
1799+
function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::UpperOrLowerTriangular, x::AbstractSparseVector,
1800+
_add::MulAddMul = MulAddMul())
1801+
@assert tA == 'N'
1802+
Adata = parent(A)
1803+
if Adata isa Transpose
1804+
_At_or_Ac_mul_B!(transpose, y, _fliptri(A), x, _add.alpha, _add.beta)
1805+
elseif Adata isa Adjoint
1806+
_At_or_Ac_mul_B!(adjoint, y, _fliptri(A), x, _add.alpha, _add.beta)
1807+
else # Adata is plain
1808+
_spmul!(y, A, x, _add.alpha, _add.beta)
1809+
end
1810+
return y
1811+
end
1812+
function _spmul!(y::AbstractVector, A::AbstractMatrix, x::AbstractSparseVector, α::Number, β::Number)
17811813
require_one_based_indexing(y, A, x)
17821814
m, n = size(A)
17831815
length(x) == n && length(y) == m || throw(DimensionMismatch())
17841816
m == 0 && return y
1785-
if β != one(β)
1786-
β == zero(β) ? fill!(y, zero(eltype(y))) : rmul!(y, β)
1787-
end
1817+
β != one(β) && LinearAlgebra._rmul_or_fill!(y, β)
17881818
α == zero(α) && return y
17891819

17901820
xnzind = nonzeroinds(x)
@@ -1802,80 +1832,47 @@ function mul!(y::AbstractVector, A::_StridedOrTriangularMatrix, x::AbstractSpars
18021832
return y
18031833
end
18041834

1805-
# * and mul!(C, transpose(A), B)
1806-
1807-
function *(tA::Transpose{<:Any,<:_StridedOrTriangularMatrix{Ta}}, x::AbstractSparseVector{Tx}) where {Ta,Tx}
1808-
require_one_based_indexing(tA, x)
1809-
m, n = size(tA)
1810-
length(x) == n || throw(DimensionMismatch())
1811-
Ty = promote_op(matprod, eltype(tA), eltype(x))
1812-
y = Vector{Ty}(undef, m)
1813-
mul!(y, tA, x)
1814-
end
1815-
1816-
function mul!(y::AbstractVector, tA::Transpose{<:Any,<:_StridedOrTriangularMatrix}, x::AbstractSparseVector, α::Number, β::Number)
1817-
require_one_based_indexing(y, tA, x)
1818-
m, n = size(tA)
1835+
function _At_or_Ac_mul_B!(tfun::Function,
1836+
y::AbstractVector, A::_StridedOrTriangularMatrix, x::AbstractSparseVector,
1837+
α::Number, β::Number)
1838+
require_one_based_indexing(y, A, x)
1839+
n, m = size(A)
18191840
length(x) == n && length(y) == m || throw(DimensionMismatch())
18201841
m == 0 && return y
1821-
if β != one(β)
1822-
β == zero(β) ? fill!(y, zero(eltype(y))) : rmul!(y, β)
1823-
end
1842+
β != one(β) && LinearAlgebra._rmul_or_fill!(y, β)
18241843
α == zero(α) && return y
18251844

18261845
xnzind = nonzeroinds(x)
18271846
xnzval = nonzeros(x)
18281847
_nnz = length(xnzind)
18291848
_nnz == 0 && return y
18301849

1831-
A = tA.parent
18321850
Ty = promote_op(matprod, eltype(A), eltype(x))
18331851
@inbounds for j = 1:m
18341852
s = zero(Ty)
18351853
for i = 1:_nnz
1836-
s += transpose(A[xnzind[i], j]) * xnzval[i]
1854+
s += tfun(A[xnzind[i], j]) * xnzval[i]
18371855
end
18381856
y[j] += s * α
18391857
end
18401858
return y
18411859
end
18421860

1843-
# * and mul!(C, adjoint(A), B)
1844-
1845-
function *(adjA::Adjoint{<:Any,<:_StridedOrTriangularMatrix{Ta}}, x::AbstractSparseVector{Tx}) where {Ta,Tx}
1846-
require_one_based_indexing(adjA, x)
1847-
m, n = size(adjA)
1861+
function *(A::AdjOrTrans{<:Any,<:StridedMatrix}, x::AbstractSparseVector)
1862+
require_one_based_indexing(A, x)
1863+
m, n = size(A)
18481864
length(x) == n || throw(DimensionMismatch())
1849-
Ty = promote_op(matprod, eltype(adjA), eltype(x))
1865+
Ty = promote_op(matprod, eltype(A), eltype(x))
18501866
y = Vector{Ty}(undef, m)
1851-
mul!(y, adjA, x)
1867+
mul!(y, A, x, true, false)
18521868
end
1853-
1854-
function mul!(y::AbstractVector, adjA::Adjoint{<:Any,<:_StridedOrTriangularMatrix}, x::AbstractSparseVector, α::Number, β::Number)
1855-
require_one_based_indexing(y, adjA, x)
1856-
m, n = size(adjA)
1857-
length(x) == n && length(y) == m || throw(DimensionMismatch())
1858-
m == 0 && return y
1859-
if β != one(β)
1860-
β == zero(β) ? fill!(y, zero(eltype(y))) : rmul!(y, β)
1861-
end
1862-
α == zero(α) && return y
1863-
1864-
xnzind = nonzeroinds(x)
1865-
xnzval = nonzeros(x)
1866-
_nnz = length(xnzind)
1867-
_nnz == 0 && return y
1868-
1869-
A = adjA.parent
1869+
function *(A::LinearAlgebra.HermOrSym{<:Any,<:StridedMatrix}, x::AbstractSparseVector)
1870+
require_one_based_indexing(A, x)
1871+
m, n = size(A)
1872+
length(x) == n || throw(DimensionMismatch())
18701873
Ty = promote_op(matprod, eltype(A), eltype(x))
1871-
@inbounds for j = 1:m
1872-
s = zero(Ty)
1873-
for i = 1:_nnz
1874-
s += adjoint(A[xnzind[i], j]) * xnzval[i]
1875-
end
1876-
y[j] += s * α
1877-
end
1878-
return y
1874+
y = Vector{Ty}(undef, m)
1875+
mul!(y, A, x, true, false)
18791876
end
18801877

18811878

@@ -1906,15 +1903,26 @@ function densemv(A::AbstractSparseMatrixCSC, x::AbstractSparseVector; trans::Abs
19061903
end
19071904

19081905
# * and mul!
1906+
function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::AbstractSparseMatrixCSC, x::AbstractSparseVector,
1907+
_add::MulAddMul = MulAddMul())
1908+
if tA == 'N'
1909+
_spmul!(y, A, x, _add.alpha, _add.beta)
1910+
elseif tA == 'T'
1911+
_At_or_Ac_mul_B!((a,b) -> transpose(a) * b, y, A, x, _add.alpha, _add.beta)
1912+
elseif tA == 'C'
1913+
_At_or_Ac_mul_B!((a,b) -> adjoint(a) * b, y, A, x, _add.alpha, _add.beta)
1914+
else
1915+
LinearAlgebra._generic_matvecmul!(y, 'N', LinearAlgebra.wrap(A, tA), x, _add)
1916+
end
1917+
return y
1918+
end
19091919

1910-
function mul!(y::AbstractVector, A::AbstractSparseMatrixCSC, x::AbstractSparseVector, α::Number, β::Number)
1920+
function _spmul!(y::AbstractVector, A::AbstractSparseMatrixCSC, x::AbstractSparseVector, α::Number, β::Number)
19111921
require_one_based_indexing(y, A, x)
19121922
m, n = size(A)
19131923
length(x) == n && length(y) == m || throw(DimensionMismatch())
19141924
m == 0 && return y
1915-
if β != one(β)
1916-
β == zero(β) ? fill!(y, zero(eltype(y))) : rmul!(y, β)
1917-
end
1925+
β != one(β) && LinearAlgebra._rmul_or_fill!(y, β)
19181926
α == zero(α) && return y
19191927

19201928
xnzind = nonzeroinds(x)
@@ -1936,23 +1944,14 @@ function mul!(y::AbstractVector, A::AbstractSparseMatrixCSC, x::AbstractSparseVe
19361944
return y
19371945
end
19381946

1939-
# * and *(Transpose(A), B)
1940-
mul!(y::AbstractVector, tA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, x::AbstractSparseVector, α::Number, β::Number) =
1941-
_At_or_Ac_mul_B!((a,b) -> transpose(a) * b, y, tA.parent, x, α, β)
1942-
1943-
mul!(y::AbstractVector, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, x::AbstractSparseVector, α::Number, β::Number) =
1944-
_At_or_Ac_mul_B!((a,b) -> adjoint(a) * b, y, adjA.parent, x, α, β)
1945-
19461947
function _At_or_Ac_mul_B!(tfun::Function,
19471948
y::AbstractVector, A::AbstractSparseMatrixCSC, x::AbstractSparseVector,
19481949
α::Number, β::Number)
19491950
require_one_based_indexing(y, A, x)
19501951
m, n = size(A)
19511952
length(x) == m && length(y) == n || throw(DimensionMismatch())
19521953
n == 0 && return y
1953-
if β != one(β)
1954-
β == zero(β) ? fill!(y, zero(eltype(y))) : rmul!(y, β)
1955-
end
1954+
β != one(β) && LinearAlgebra._rmul_or_fill!(y, β)
19561955
α == zero(α) && return y
19571956

19581957
xnzind = nonzeroinds(x)
@@ -1981,11 +1980,8 @@ function *(A::AbstractSparseMatrixCSC, x::AbstractSparseVector)
19811980
_dense2sparsevec(y, initcap)
19821981
end
19831982

1984-
*(tA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, x::AbstractSparseVector) =
1985-
_At_or_Ac_mul_B((a,b) -> transpose(a) * b, tA.parent, x, promote_op(matprod, eltype(tA), eltype(x)))
1986-
1987-
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, x::AbstractSparseVector) =
1988-
_At_or_Ac_mul_B((a,b) -> adjoint(a) * b, adjA.parent, x, promote_op(matprod, eltype(adjA), eltype(x)))
1983+
*(xA::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, x::AbstractSparseVector) =
1984+
_At_or_Ac_mul_B((a,b) -> adj_or_trans(xA)(a) * b, xA.parent, x, promote_op(matprod, eltype(xA), eltype(x)))
19891985

19901986
function _At_or_Ac_mul_B(tfun::Function, A::AbstractSparseMatrixCSC{TvA,TiA}, x::AbstractSparseVector{TvX,TiX},
19911987
Tv = promote_op(matprod, TvA, TvX)) where {TvA,TiA,TvX,TiX}

test/linalg.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ end
190190
@testset "symmetric/Hermitian sparse multiply with $S($U)" for S in (Symmetric, Hermitian), U in (:U, :L), (A, B) in ((Areal,Breal), (Acomplex,Bcomplex))
191191
Asym = S(A, U)
192192
As = sparse(Asym) # takes most time
193-
@test which(mul!, (typeof(B), typeof(Asym), typeof(B))).module == SparseArrays
193+
# @test which(mul!, (typeof(B), typeof(Asym), typeof(B))).module == SparseArrays
194194
@test norm(Asym * B - As * B, Inf) <= eps() * n * p * 10
195195
end
196196
end
@@ -207,7 +207,7 @@ end
207207
@testset "symmetric/Hermitian sparseview multiply with $S($U)" for S in (Symmetric, Hermitian), U in (:U, :L), (A, B) in ((Areal,Breal), (Acomplex,Bcomplex))
208208
Asym = S(A, U)
209209
As = sparse(Asym) # takes most time
210-
@test which(mul!, (typeof(B), typeof(Asym), typeof(B))).module == SparseArrays
210+
# @test which(mul!, (typeof(B), typeof(Asym), typeof(B))).module == SparseArrays
211211
@test norm(Asym * B - As * B, Inf) <= eps() * n * p * 10
212212
end
213213
end
@@ -662,11 +662,14 @@ end
662662
@test Array(f*b) == f*Array(b)
663663
A = rand(2n, 2n)
664664
sA = view(A, 1:2:2n, 1:2:2n)
665-
@test Array(sA*b) Array(sA)*Array(b)
666-
@test Array(a*sA) Array(a)*Array(sA)
665+
@test Array((sA*b)::Matrix) Array(sA)*Array(b)
666+
@test Array((a*sA)::Matrix) Array(a)*Array(sA)
667+
@test Array((sA'b)::Matrix) Array(sA')*Array(b)
667668
c = sprandn(ComplexF32, n, n, q)
668-
@test Array(sA*c') Array(sA)*Array(c)'
669-
@test Array(c'*sA) Array(c)'*Array(sA)
669+
@test Array((sA*c')::Matrix) Array(sA)*Array(c)'
670+
@test Array((c'*sA)::Matrix) Array(c)'*Array(sA)
671+
@test Array((sA'c)::Matrix) Array(sA')*Array(c)
672+
@test Array((sA'c')::Matrix) Array(sA')*Array(c)'
670673
end
671674
end
672675

test/sparsematrix_ops.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,16 +175,15 @@ dA = Array(sA)
175175
p28227 = sparse(Real[0 0.5])
176176

177177
for arr in (se33, sA, pA, p28227, spzeros(3, 3))
178+
farr = Array(arr)
178179
for f in (sum, prod, minimum, maximum)
179-
farr = Array(arr)
180180
@test f(arr) f(farr)
181181
@test f(arr, dims=1) f(farr, dims=1)
182182
@test f(arr, dims=2) f(farr, dims=2)
183183
@test f(arr, dims=(1, 2)) [f(farr)]
184184
@test isequal(f(arr, dims=3), f(farr, dims=3))
185185
end
186186
for f in (+, *, min, max)
187-
farr = Array(arr)
188187
@test mapreduce(identity, f, arr) mapreduce(identity, f, farr)
189188
@test mapreduce(x -> x + 1, f, arr) mapreduce(x -> x + 1, f, farr)
190189
end

test/sparsevector.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,22 @@ end
10981098
@test isa(y, Vector{T})
10991099
@test y *(adjoint(A), xf)
11001100
end
1101+
1102+
let A = randn(TA, 16, 16), x = sprand(Tx, 16, 0.7)
1103+
xf = Array(x)
1104+
for wrap in (M -> Symmetric(M, :U), M -> Symmetric(M, :L),
1105+
M -> Hermitian(M, :U), M -> Hermitian(M, :L))
1106+
for α in (0.0, 1.0, 2.0), β in (0.0, 0.5, 1.0)
1107+
y = rand(T, 16)
1108+
rr = α*wrap(A)*xf + β*y
1109+
@test mul!(y, wrap(A), x, α, β) === y
1110+
@test y rr
1111+
end
1112+
y = *(wrap(A), x)
1113+
@test isa(y, Vector{T})
1114+
@test y *(wrap(A), xf)
1115+
end
1116+
end
11011117
end
11021118
end
11031119
@testset "sparse A * sparse x -> dense y" begin
@@ -1129,6 +1145,29 @@ end
11291145
@test y *(transpose(Af), xf)
11301146
end
11311147

1148+
let A = sprandn(16, 16, 0.5), x = sprand(16, 0.7)
1149+
Af = Array(A)
1150+
xf = Array(x)
1151+
for wrap in (M -> Symmetric(M, :U), M -> Symmetric(M, :L),
1152+
M -> Hermitian(M, :U), M -> Hermitian(M, :L),
1153+
M -> UpperTriangular(M), M -> UnitUpperTriangular(M),
1154+
M -> LowerTriangular(M), M -> UnitLowerTriangular(M),
1155+
M -> UpperTriangular(transpose(M)), M -> UnitUpperTriangular(transpose(M)),
1156+
M -> LowerTriangular(transpose(M)), M -> UnitLowerTriangular(transpose(M)),
1157+
M -> UpperTriangular(adjoint(M)), M -> UnitUpperTriangular(adjoint(M)),
1158+
M -> LowerTriangular(adjoint(M)), M -> UnitLowerTriangular(adjoint(M)),
1159+
M -> UpperTriangular(Symmetric(M)))
1160+
for α in (0.0, 1.0, 2.0), β in (0.0, 0.5, 1.0)
1161+
y = rand(16)
1162+
rr = α*wrap(Af)*xf + β*y
1163+
@test mul!(y, wrap(A), x, α, β) === y
1164+
@test y rr
1165+
end
1166+
y = wrap(A) * x
1167+
@test y *(wrap(Af), xf)
1168+
end
1169+
end
1170+
11321171
let A = complex.(sprandn(7, 8, 0.5), sprandn(7, 8, 0.5)),
11331172
x = complex.(sprandn(8, 0.6), sprandn(8, 0.6)),
11341173
x2 = complex.(sprandn(7, 0.75), sprandn(7, 0.75))

0 commit comments

Comments
 (0)