|
1 | 1 | # This file is a part of Julia. License is MIT: https://julialang.org/license |
2 | 2 |
|
3 | 3 | using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLowerTriangular, |
4 | | - RealHermSymComplexHerm, checksquare, sym_uplo |
| 4 | + RealHermSymComplexHerm, checksquare, sym_uplo, wrap |
5 | 5 | using Random: rand! |
6 | 6 |
|
| 7 | +const tilebufsize = 10800 # Approximately 32k/3 |
| 8 | + |
7 | 9 | # In matrix-vector multiplication, the correct orientation of the vector is assumed. |
8 | 10 | const DenseMatrixUnion = Union{StridedMatrix, BitMatrix} |
9 | 11 | const DenseTriangular = UpperOrLowerTriangular{<:Any,<:DenseMatrixUnion} |
@@ -45,28 +47,28 @@ for op ∈ (:+, :-) |
45 | 47 | end |
46 | 48 | end |
47 | 49 |
|
48 | | -LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) = |
| 50 | +generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) = |
49 | 51 | spdensemul!(C, tA, tB, A, B, _add) |
50 | | -LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) = |
| 52 | +generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) = |
51 | 53 | spdensemul!(C, tA, tB, A, B, _add) |
52 | | -LinearAlgebra.generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) = |
| 54 | +generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) = |
53 | 55 | spdensemul!(C, tA, 'N', A, B, _add) |
54 | 56 |
|
55 | 57 | Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, _add) |
56 | 58 | if tA == 'N' |
57 | | - _spmatmul!(C, A, LinearAlgebra.wrap(B, tB), _add.alpha, _add.beta) |
| 59 | + _spmatmul!(C, A, wrap(B, tB), _add.alpha, _add.beta) |
58 | 60 | elseif tA == 'T' |
59 | | - _At_or_Ac_mul_B!(transpose, C, A, LinearAlgebra.wrap(B, tB), _add.alpha, _add.beta) |
| 61 | + _At_or_Ac_mul_B!(transpose, C, A, wrap(B, tB), _add.alpha, _add.beta) |
60 | 62 | elseif tA == 'C' |
61 | | - _At_or_Ac_mul_B!(adjoint, C, A, LinearAlgebra.wrap(B, tB), _add.alpha, _add.beta) |
| 63 | + _At_or_Ac_mul_B!(adjoint, C, A, wrap(B, tB), _add.alpha, _add.beta) |
62 | 64 | elseif tA in ('S', 's', 'H', 'h') && tB == 'N' |
63 | 65 | rangefun = isuppercase(tA) ? nzrangeup : nzrangelo |
64 | 66 | diagop = tA in ('S', 's') ? identity : real |
65 | 67 | odiagop = tA in ('S', 's') ? transpose : adjoint |
66 | 68 | T = eltype(C) |
67 | 69 | _mul!(rangefun, diagop, odiagop, C, A, B, T(_add.alpha), T(_add.beta)) |
68 | 70 | else |
69 | | - LinearAlgebra._generic_matmatmul!(C, 'N', 'N', LinearAlgebra.wrap(A, tA), LinearAlgebra.wrap(B, tB), _add) |
| 71 | + _generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add) |
70 | 72 | end |
71 | 73 | return C |
72 | 74 | end |
@@ -114,7 +116,7 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β) |
114 | 116 | C |
115 | 117 | end |
116 | 118 |
|
117 | | -Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul) |
| 119 | +Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul) |
118 | 120 | transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint |
119 | 121 | if tB == 'N' |
120 | 122 | _spmul!(C, transA(A), B, _add.alpha, _add.beta) |
@@ -316,6 +318,189 @@ function estimate_mulsize(m::Integer, nnzA::Integer, n::Integer, nnzB::Integer, |
316 | 318 | p >= 1 ? m*k : p > 0 ? Int(ceil(-expm1(log1p(-p) * n)*m*k)) : 0 # (1-(1-p)^n)*m*k |
317 | 319 | end |
318 | 320 |
|
| 321 | +Base.@constprop :aggressive function generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB, A::SparseMatrixCSCUnion2, |
| 322 | + B::SparseMatrixCSCUnion2, _add::MulAddMul) |
| 323 | + A, tA = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA) |
| 324 | + B, tB = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB) |
| 325 | + _generic_matmatmul!(C, tA, tB, A, B, _add) |
| 326 | +end |
| 327 | +function _generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB, A::AbstractVecOrMat, |
| 328 | + B::AbstractVecOrMat, _add::MulAddMul) |
| 329 | + @assert tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C') |
| 330 | + require_one_based_indexing(C, A, B) |
| 331 | + R = eltype(C) |
| 332 | + T = eltype(A) |
| 333 | + S = eltype(B) |
| 334 | + |
| 335 | + mA, nA = LinearAlgebra.lapack_size(tA, A) |
| 336 | + mB, nB = LinearAlgebra.lapack_size(tB, B) |
| 337 | + if mB != nA |
| 338 | + throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), matrix B has dimensions ($mB,$nB)")) |
| 339 | + end |
| 340 | + if size(C,1) != mA || size(C,2) != nB |
| 341 | + throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs ($mA,$nB)")) |
| 342 | + end |
| 343 | + |
| 344 | + if iszero(_add.alpha) || isempty(A) || isempty(B) |
| 345 | + return LinearAlgebra._rmul_or_fill!(C, _add.beta) |
| 346 | + end |
| 347 | + |
| 348 | + tile_size = 0 |
| 349 | + if isbitstype(R) && isbitstype(T) && isbitstype(S) && (tA == 'N' || tB != 'N') |
| 350 | + tile_size = floor(Int, sqrt(tilebufsize / max(sizeof(R), sizeof(S), sizeof(T), 1))) |
| 351 | + end |
| 352 | + @inbounds begin |
| 353 | + if tile_size > 0 |
| 354 | + sz = (tile_size, tile_size) |
| 355 | + Atile = Array{T}(undef, sz) |
| 356 | + Btile = Array{S}(undef, sz) |
| 357 | + |
| 358 | + z1 = zero(A[1, 1]*B[1, 1] + A[1, 1]*B[1, 1]) |
| 359 | + z = convert(promote_type(typeof(z1), R), z1) |
| 360 | + |
| 361 | + if mA < tile_size && nA < tile_size && nB < tile_size |
| 362 | + copy_transpose!(Atile, 1:nA, 1:mA, tA, A, 1:mA, 1:nA) |
| 363 | + copyto!(Btile, 1:mB, 1:nB, tB, B, 1:mB, 1:nB) |
| 364 | + for j = 1:nB |
| 365 | + boff = (j-1)*tile_size |
| 366 | + for i = 1:mA |
| 367 | + aoff = (i-1)*tile_size |
| 368 | + s = z |
| 369 | + for k = 1:nA |
| 370 | + s += Atile[aoff+k] * Btile[boff+k] |
| 371 | + end |
| 372 | + LinearAlgebra._modify!(_add, s, C, (i,j)) |
| 373 | + end |
| 374 | + end |
| 375 | + else |
| 376 | + Ctile = Array{R}(undef, sz) |
| 377 | + for jb = 1:tile_size:nB |
| 378 | + jlim = min(jb+tile_size-1,nB) |
| 379 | + jlen = jlim-jb+1 |
| 380 | + for ib = 1:tile_size:mA |
| 381 | + ilim = min(ib+tile_size-1,mA) |
| 382 | + ilen = ilim-ib+1 |
| 383 | + fill!(Ctile, z) |
| 384 | + for kb = 1:tile_size:nA |
| 385 | + klim = min(kb+tile_size-1,mB) |
| 386 | + klen = klim-kb+1 |
| 387 | + copy_transpose!(Atile, 1:klen, 1:ilen, tA, A, ib:ilim, kb:klim) |
| 388 | + copyto!(Btile, 1:klen, 1:jlen, tB, B, kb:klim, jb:jlim) |
| 389 | + for j=1:jlen |
| 390 | + bcoff = (j-1)*tile_size |
| 391 | + for i = 1:ilen |
| 392 | + aoff = (i-1)*tile_size |
| 393 | + s = z |
| 394 | + for k = 1:klen |
| 395 | + s += Atile[aoff+k] * Btile[bcoff+k] |
| 396 | + end |
| 397 | + Ctile[bcoff+i] += s |
| 398 | + end |
| 399 | + end |
| 400 | + end |
| 401 | + if isone(_add.alpha) && iszero(_add.beta) |
| 402 | + copyto!(C, ib:ilim, jb:jlim, Ctile, 1:ilen, 1:jlen) |
| 403 | + else |
| 404 | + C[ib:ilim, jb:jlim] .= @views _add.(Ctile[1:ilen, 1:jlen], C[ib:ilim, jb:jlim]) |
| 405 | + end |
| 406 | + end |
| 407 | + end |
| 408 | + end |
| 409 | + else |
| 410 | + # Multiplication for non-plain-data uses the naive algorithm |
| 411 | + if tA == 'N' |
| 412 | + if tB == 'N' |
| 413 | + for i = 1:mA, j = 1:nB |
| 414 | + z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j]) |
| 415 | + Ctmp = convert(promote_type(R, typeof(z2)), z2) |
| 416 | + for k = 1:nA |
| 417 | + Ctmp += A[i, k]*B[k, j] |
| 418 | + end |
| 419 | + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) |
| 420 | + end |
| 421 | + elseif tB == 'T' |
| 422 | + for i = 1:mA, j = 1:nB |
| 423 | + z2 = zero(A[i, 1]*transpose(B[j, 1]) + A[i, 1]*transpose(B[j, 1])) |
| 424 | + Ctmp = convert(promote_type(R, typeof(z2)), z2) |
| 425 | + for k = 1:nA |
| 426 | + Ctmp += A[i, k] * transpose(B[j, k]) |
| 427 | + end |
| 428 | + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) |
| 429 | + end |
| 430 | + else |
| 431 | + for i = 1:mA, j = 1:nB |
| 432 | + z2 = zero(A[i, 1]*B[j, 1]' + A[i, 1]*B[j, 1]') |
| 433 | + Ctmp = convert(promote_type(R, typeof(z2)), z2) |
| 434 | + for k = 1:nA |
| 435 | + Ctmp += A[i, k]*B[j, k]' |
| 436 | + end |
| 437 | + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) |
| 438 | + end |
| 439 | + end |
| 440 | + elseif tA == 'T' |
| 441 | + if tB == 'N' |
| 442 | + for i = 1:mA, j = 1:nB |
| 443 | + z2 = zero(transpose(A[1, i])*B[1, j] + transpose(A[1, i])*B[1, j]) |
| 444 | + Ctmp = convert(promote_type(R, typeof(z2)), z2) |
| 445 | + for k = 1:nA |
| 446 | + Ctmp += transpose(A[k, i]) * B[k, j] |
| 447 | + end |
| 448 | + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) |
| 449 | + end |
| 450 | + elseif tB == 'T' |
| 451 | + for i = 1:mA, j = 1:nB |
| 452 | + z2 = zero(transpose(A[1, i])*transpose(B[j, 1]) + transpose(A[1, i])*transpose(B[j, 1])) |
| 453 | + Ctmp = convert(promote_type(R, typeof(z2)), z2) |
| 454 | + for k = 1:nA |
| 455 | + Ctmp += transpose(A[k, i]) * transpose(B[j, k]) |
| 456 | + end |
| 457 | + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) |
| 458 | + end |
| 459 | + else |
| 460 | + for i = 1:mA, j = 1:nB |
| 461 | + z2 = zero(transpose(A[1, i])*B[j, 1]' + transpose(A[1, i])*B[j, 1]') |
| 462 | + Ctmp = convert(promote_type(R, typeof(z2)), z2) |
| 463 | + for k = 1:nA |
| 464 | + Ctmp += transpose(A[k, i]) * adjoint(B[j, k]) |
| 465 | + end |
| 466 | + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) |
| 467 | + end |
| 468 | + end |
| 469 | + else |
| 470 | + if tB == 'N' |
| 471 | + for i = 1:mA, j = 1:nB |
| 472 | + z2 = zero(A[1, i]'*B[1, j] + A[1, i]'*B[1, j]) |
| 473 | + Ctmp = convert(promote_type(R, typeof(z2)), z2) |
| 474 | + for k = 1:nA |
| 475 | + Ctmp += A[k, i]'B[k, j] |
| 476 | + end |
| 477 | + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) |
| 478 | + end |
| 479 | + elseif tB == 'T' |
| 480 | + for i = 1:mA, j = 1:nB |
| 481 | + z2 = zero(A[1, i]'*transpose(B[j, 1]) + A[1, i]'*transpose(B[j, 1])) |
| 482 | + Ctmp = convert(promote_type(R, typeof(z2)), z2) |
| 483 | + for k = 1:nA |
| 484 | + Ctmp += adjoint(A[k, i]) * transpose(B[j, k]) |
| 485 | + end |
| 486 | + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) |
| 487 | + end |
| 488 | + else |
| 489 | + for i = 1:mA, j = 1:nB |
| 490 | + z2 = zero(A[1, i]'*B[j, 1]' + A[1, i]'*B[j, 1]') |
| 491 | + Ctmp = convert(promote_type(R, typeof(z2)), z2) |
| 492 | + for k = 1:nA |
| 493 | + Ctmp += A[k, i]'B[j, k]' |
| 494 | + end |
| 495 | + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) |
| 496 | + end |
| 497 | + end |
| 498 | + end |
| 499 | + end |
| 500 | + end # @inbounds |
| 501 | + C |
| 502 | +end |
| 503 | + |
319 | 504 | if VERSION < v"1.10.0-DEV.299" |
320 | 505 | top_set_bit(x::Base.BitInteger) = 8 * sizeof(x) - leading_zeros(x) |
321 | 506 | else |
|
0 commit comments