Skip to content

Commit cb602d7

Browse files
authored
Add generic matmatmul for inplace sparse x sparse (#486)
1 parent 95575c0 commit cb602d7

File tree

4 files changed

+217
-15
lines changed

4 files changed

+217
-15
lines changed

src/SparseArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import Base: +, -, *, \, /, &, |, xor, ==, zero, @propagate_inbounds
1818
import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot,
1919
issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!, isbanded,
2020
cond, diagm, factorize, ishermitian, norm, opnorm, lmul!, rmul!, tril, triu,
21-
matprod_dest
21+
matprod_dest, generic_matvecmul!, generic_matmatmul!
2222

2323
import Base: adjoint, argmin, argmax, Array, broadcast, circshift!, complex, Complex,
2424
conj, conj!, convert, copy, copy!, copyto!, count, diff, findall, findmax, findmin,

src/linalg.jl

Lines changed: 194 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

33
using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLowerTriangular,
4-
RealHermSymComplexHerm, checksquare, sym_uplo
4+
RealHermSymComplexHerm, checksquare, sym_uplo, wrap
55
using Random: rand!
66

7+
const tilebufsize = 10800 # Approximately 32k/3
8+
79
# In matrix-vector multiplication, the correct orientation of the vector is assumed.
810
const DenseMatrixUnion = Union{StridedMatrix, BitMatrix}
911
const DenseTriangular = UpperOrLowerTriangular{<:Any,<:DenseMatrixUnion}
@@ -45,28 +47,28 @@ for op ∈ (:+, :-)
4547
end
4648
end
4749

48-
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) =
50+
generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) =
4951
spdensemul!(C, tA, tB, A, B, _add)
50-
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) =
52+
generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) =
5153
spdensemul!(C, tA, tB, A, B, _add)
52-
LinearAlgebra.generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) =
54+
generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) =
5355
spdensemul!(C, tA, 'N', A, B, _add)
5456

5557
Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, _add)
5658
if tA == 'N'
57-
_spmatmul!(C, A, LinearAlgebra.wrap(B, tB), _add.alpha, _add.beta)
59+
_spmatmul!(C, A, wrap(B, tB), _add.alpha, _add.beta)
5860
elseif tA == 'T'
59-
_At_or_Ac_mul_B!(transpose, C, A, LinearAlgebra.wrap(B, tB), _add.alpha, _add.beta)
61+
_At_or_Ac_mul_B!(transpose, C, A, wrap(B, tB), _add.alpha, _add.beta)
6062
elseif tA == 'C'
61-
_At_or_Ac_mul_B!(adjoint, C, A, LinearAlgebra.wrap(B, tB), _add.alpha, _add.beta)
63+
_At_or_Ac_mul_B!(adjoint, C, A, wrap(B, tB), _add.alpha, _add.beta)
6264
elseif tA in ('S', 's', 'H', 'h') && tB == 'N'
6365
rangefun = isuppercase(tA) ? nzrangeup : nzrangelo
6466
diagop = tA in ('S', 's') ? identity : real
6567
odiagop = tA in ('S', 's') ? transpose : adjoint
6668
T = eltype(C)
6769
_mul!(rangefun, diagop, odiagop, C, A, B, T(_add.alpha), T(_add.beta))
6870
else
69-
LinearAlgebra._generic_matmatmul!(C, 'N', 'N', LinearAlgebra.wrap(A, tA), LinearAlgebra.wrap(B, tB), _add)
71+
_generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
7072
end
7173
return C
7274
end
@@ -114,7 +116,7 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
114116
C
115117
end
116118

117-
Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul)
119+
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul)
118120
transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint
119121
if tB == 'N'
120122
_spmul!(C, transA(A), B, _add.alpha, _add.beta)
@@ -316,6 +318,189 @@ function estimate_mulsize(m::Integer, nnzA::Integer, n::Integer, nnzB::Integer,
316318
p >= 1 ? m*k : p > 0 ? Int(ceil(-expm1(log1p(-p) * n)*m*k)) : 0 # (1-(1-p)^n)*m*k
317319
end
318320

321+
# In-place product with a sparse destination `C` and sparse operands `A` and `B`.
#
# `tA`/`tB` are the character flags describing how each operand enters the product.
# The flags 'H', 'h', 'S', 's' are folded into the operand itself by `wrap`, after which
# the operand is passed on with flag 'N'; every other flag is forwarded unchanged.
# `_add` is handed through untouched to `_generic_matmatmul!`, whose result is returned.
Base.@constprop :aggressive function generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB,
                                                        A::SparseMatrixCSCUnion2,
                                                        B::SparseMatrixCSCUnion2,
                                                        _add::MulAddMul)
    # Use fresh locals instead of rebinding `A`/`B`: the wrapped operand has a different
    # type than the argument, and reusing the argument name would make it type-unstable.
    Aw, tAw = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA)
    Bw, tBw = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB)
    return _generic_matmatmul!(C, tAw, tBw, Aw, Bw, _add)
end
327+
# Element-wise fallback for multiplying into a sparse destination `C`.
#
# `tA` and `tB` must each be one of 'N', 'T', 'C' and select whether `A`/`B` enter the
# product as-is, transposed or adjointed. `_add` carries the `alpha`/`beta` scaling that
# `LinearAlgebra._modify!` applies when a computed entry is combined with `C`.
#
# Throws `DimensionMismatch` when the operand sizes (after applying `tA`/`tB`) do not
# agree with each other or with `C`. Returns `C`.
function _generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB, A::AbstractVecOrMat,
                             B::AbstractVecOrMat, _add::MulAddMul)
    @assert tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C')
    require_one_based_indexing(C, A, B)
    R = eltype(C)
    T = eltype(A)
    S = eltype(B)

    mA, nA = LinearAlgebra.lapack_size(tA, A)
    mB, nB = LinearAlgebra.lapack_size(tB, B)
    if mB != nA
        throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), matrix B has dimensions ($mB,$nB)"))
    end
    if size(C,1) != mA || size(C,2) != nB
        throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs ($mA,$nB)"))
    end

    # Nothing to accumulate: only the `beta` scaling of `C` remains.
    if iszero(_add.alpha) || isempty(A) || isempty(B)
        return LinearAlgebra._rmul_or_fill!(C, _add.beta)
    end

    # The blocked kernel is only used for plain-data element types, and not for the
    # combination "A transposed/adjointed, B untouched".
    tile_size = 0
    if isbitstype(R) && isbitstype(T) && isbitstype(S) && (tA == 'N' || tB != 'N')
        tile_size = floor(Int, sqrt(tilebufsize / max(sizeof(R), sizeof(S), sizeof(T), 1)))
    end

    if tile_size > 0
        _spgeneric_tiled_mul!(C, tA, tB, A, B, _add, tile_size)
    else
        # Multiplication for non-plain-data uses the naive algorithm.
        # `wrap` turns the flags into lazy transpose/adjoint wrappers; the call below is a
        # function barrier so the kernel is compiled for the concrete wrapper types.
        _spgeneric_naive_mul!(C, wrap(A, tA), wrap(B, tB), _add)
    end
    return C
end

# Blocked kernel: copies tiles of the operands into dense buffers and multiplies those.
# Sizes are assumed to have been validated by the caller.
function _spgeneric_tiled_mul!(C, tA, tB, A, B, _add::MulAddMul, tile_size)
    R = eltype(C)
    mA, nA = LinearAlgebra.lapack_size(tA, A)
    mB, nB = LinearAlgebra.lapack_size(tB, B)
    sz = (tile_size, tile_size)
    Atile = Array{eltype(A)}(undef, sz)
    Btile = Array{eltype(B)}(undef, sz)

    @inbounds begin
        z1 = zero(A[1, 1]*B[1, 1] + A[1, 1]*B[1, 1])
        z = convert(promote_type(typeof(z1), R), z1)

        if mA < tile_size && nA < tile_size && nB < tile_size
            # Both operands fit into a single tile each.
            copy_transpose!(Atile, 1:nA, 1:mA, tA, A, 1:mA, 1:nA)
            copyto!(Btile, 1:mB, 1:nB, tB, B, 1:mB, 1:nB)
            for j = 1:nB
                boff = (j-1)*tile_size
                for i = 1:mA
                    # Column `i` of `Atile` and column `j` of `Btile` are both indexed
                    # by `k`, so the inner loop walks two contiguous columns.
                    aoff = (i-1)*tile_size
                    s = z
                    for k = 1:nA
                        s += Atile[aoff+k] * Btile[boff+k]
                    end
                    LinearAlgebra._modify!(_add, s, C, (i,j))
                end
            end
        else
            Ctile = Array{R}(undef, sz)
            for jb = 1:tile_size:nB
                jlim = min(jb+tile_size-1,nB)
                jlen = jlim-jb+1
                for ib = 1:tile_size:mA
                    ilim = min(ib+tile_size-1,mA)
                    ilen = ilim-ib+1
                    fill!(Ctile, z)
                    # Accumulate the contribution of every k-block into `Ctile`.
                    for kb = 1:tile_size:nA
                        klim = min(kb+tile_size-1,mB)
                        klen = klim-kb+1
                        copy_transpose!(Atile, 1:klen, 1:ilen, tA, A, ib:ilim, kb:klim)
                        copyto!(Btile, 1:klen, 1:jlen, tB, B, kb:klim, jb:jlim)
                        for j=1:jlen
                            bcoff = (j-1)*tile_size
                            for i = 1:ilen
                                aoff = (i-1)*tile_size
                                s = z
                                for k = 1:klen
                                    s += Atile[aoff+k] * Btile[bcoff+k]
                                end
                                Ctile[bcoff+i] += s
                            end
                        end
                    end
                    # Write the finished tile back, applying alpha/beta only when needed.
                    if isone(_add.alpha) && iszero(_add.beta)
                        copyto!(C, ib:ilim, jb:jlim, Ctile, 1:ilen, 1:jlen)
                    else
                        C[ib:ilim, jb:jlim] .= @views _add.(Ctile[1:ilen, 1:jlen], C[ib:ilim, jb:jlim])
                    end
                end
            end
        end
    end # @inbounds
    return C
end

# Naive triple loop. `A` and `B` are the operands exactly as they enter the product
# (already wrapped in a lazy transpose/adjoint where requested), so one loop covers
# every flag combination. Both operands must be non-empty.
function _spgeneric_naive_mul!(C, A, B, _add::MulAddMul)
    R = eltype(C)
    mA, nA = size(A, 1), size(A, 2)
    nB = size(B, 2)
    @inbounds for i = 1:mA, j = 1:nB
        # The accumulator type is derived from an actual product so that element types
        # whose products widen are handled.
        z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j])
        Ctmp = convert(promote_type(R, typeof(z2)), z2)
        for k = 1:nA
            Ctmp += A[i, k]*B[k, j]
        end
        LinearAlgebra._modify!(_add, Ctmp, C, (i,j))
    end
    return C
end
503+
319504
if VERSION < v"1.10.0-DEV.299"
320505
top_set_bit(x::Base.BitInteger) = 8 * sizeof(x) - leading_zeros(x)
321506
else

src/sparsevector.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1858,7 +1858,7 @@ function (*)(A::_StridedOrTriangularMatrix{Ta}, x::AbstractSparseVector{Tx}) whe
18581858
mul!(y, A, x)
18591859
end
18601860

1861-
Base.@constprop :aggressive function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::StridedMatrix, x::AbstractSparseVector,
1861+
Base.@constprop :aggressive function generic_matvecmul!(y::AbstractVector, tA, A::StridedMatrix, x::AbstractSparseVector,
18621862
_add::MulAddMul = MulAddMul())
18631863
if tA == 'N'
18641864
_spmul!(y, A, x, _add.alpha, _add.beta)
@@ -1867,11 +1867,11 @@ Base.@constprop :aggressive function LinearAlgebra.generic_matvecmul!(y::Abstrac
18671867
elseif tA == 'C'
18681868
_At_or_Ac_mul_B!(adjoint, y, A, x, _add.alpha, _add.beta)
18691869
else
1870-
_spmul!(y, LinearAlgebra.wrap(A, tA), x, _add.alpha, _add.beta)
1870+
_spmul!(y, wrap(A, tA), x, _add.alpha, _add.beta)
18711871
end
18721872
return y
18731873
end
1874-
function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::UpperOrLowerTriangular, x::AbstractSparseVector,
1874+
function generic_matvecmul!(y::AbstractVector, tA, A::UpperOrLowerTriangular, x::AbstractSparseVector,
18751875
_add::MulAddMul = MulAddMul())
18761876
@assert tA == 'N'
18771877
Adata = parent(A)
@@ -1989,7 +1989,7 @@ function densemv(A::AbstractSparseMatrixCSC, x::AbstractSparseVector; trans::Abs
19891989
end
19901990

19911991
# * and mul!
1992-
Base.@constprop :aggressive function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::AbstractSparseMatrixCSC, x::AbstractSparseVector,
1992+
Base.@constprop :aggressive function generic_matvecmul!(y::AbstractVector, tA, A::AbstractSparseMatrixCSC, x::AbstractSparseVector,
19931993
_add::MulAddMul = MulAddMul())
19941994
if tA == 'N'
19951995
_spmul!(y, A, x, _add.alpha, _add.beta)
@@ -1998,7 +1998,7 @@ Base.@constprop :aggressive function LinearAlgebra.generic_matvecmul!(y::Abstrac
19981998
elseif tA == 'C'
19991999
_At_or_Ac_mul_B!((a,b) -> adjoint(a) * b, y, A, x, _add.alpha, _add.beta)
20002000
else
2001-
LinearAlgebra._generic_matvecmul!(y, 'N', LinearAlgebra.wrap(A, tA), x, _add)
2001+
LinearAlgebra._generic_matvecmul!(y, 'N', wrap(A, tA), x, _add)
20022002
end
20032003
return y
20042004
end

test/linalg.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,23 @@ end
228228
end
229229
end
230230

231+
# Checks the in-place sparse * sparse product into a sparse destination against the
# dense reference result, for every identity/adjoint/transpose combination of the
# operands and for trivial (true/false) as well as random complex scaling factors.
@testset "in-place sparse-sparse mul!" begin
    for n in (20, 30)
        sA = sprandn(ComplexF64, n, n, 0.1); A = Array(sA)
        sB = sprandn(ComplexF64, n, n, 0.1); B = Array(sB)
        sC = sprandn(ComplexF64, n, n, 0.1); C = Array(sC)
        a = randn(ComplexF64); b = randn(ComplexF64)
        # Exercise both a plain sparse matrix and a column-range view of it.
        for (sA, A) in ((sA, A), (view(sA, :, 1:1:n), A[:,1:1:n]))
            for trA in (identity, adjoint, transpose), trB in (identity, adjoint, transpose)
                @test mul!(copy(sC), trA(sA), trB(sB)) ≈ trA(A) * trB(B)
                for α in (true, false, a), β in (true, false, b)
                    @test mul!(copy(sC), trA(sA), trB(sB), α, β) ≈ C*β + trA(A) * trB(B) * α
                end
            end
        end
    end
end
247+
231248
@testset "UniformScaling" begin
232249
local A = sprandn(10, 10, 0.5)
233250
MA = Array(A)

0 commit comments

Comments
 (0)