7070 alpha:: Number , beta:: Number ) =
7171 generic_matvecmul! (y, adj_or_trans_char (A), _parent (A), x, MulAddMul (alpha, beta))
7272# BLAS cases
73- @inline mul! (y:: StridedVector{T} , A:: StridedMaybeAdjOrTransVecOrMat{T} , x:: StridedVector{T} ,
74- alpha:: Number , beta:: Number ) where {T<: BlasFloat } =
75- gemv! (y, adj_or_trans_char (A), _parent (A), x, alpha, beta)
76- # catch the real adjoint case and rewrap to transpose
77- @inline mul! (y:: StridedVector{T} , adjA:: Adjoint{<:Any,<:StridedVecOrMat{T}} , x:: StridedVector{T} ,
78- alpha:: Number , beta:: Number ) where {T<: BlasReal } =
79- mul! (y, transpose (adjA. parent), x, alpha, beta)
73+ # equal eltypes
74+ @inline generic_matvecmul! (y:: StridedVector{T} , tA, A:: StridedVecOrMat{T} , x:: StridedVector{T} ,
75+ _add:: MulAddMul = MulAddMul ()) where {T<: BlasFloat } =
76+ gemv! (y, tA, _parent (A), x, _add. alpha, _add. beta)
77+ # Real (possibly transposed) matrix times complex vector.
78+ # Multiply the matrix with the real and imaginary parts separately
79+ @inline generic_matvecmul! (y:: StridedVector{Complex{T}} , tA, A:: StridedVecOrMat{T} , x:: StridedVector{Complex{T}} ,
80+ _add:: MulAddMul = MulAddMul ()) where {T<: BlasReal } =
81+ gemv! (y, tA, _parent (A), x, _add. alpha, _add. beta)
8082# Complex matrix times real vector.
8183# Reinterpret the matrix as a real matrix and do real matvec computation.
82- @inline mul! (y:: StridedVector{Complex{T}} , A:: StridedVecOrMat{Complex{T}} , x:: StridedVector{T} ,
83- alpha:: Number , beta:: Number ) where {T<: BlasReal } =
84- gemv! (y, ' N' , A, x, alpha, beta)
85- # Real matrix times complex vector.
86- # Multiply the matrix with the real and imaginary parts separately
87- @inline mul! (y:: StridedVector{Complex{T}} , A:: StridedMaybeAdjOrTransMat{T} , x:: StridedVector{Complex{T}} ,
88- alpha:: Number , beta:: Number ) where {T<: BlasReal } =
89- gemv! (y, A isa StridedArray ? ' N' : ' T' , _parent (A), x, alpha, beta)
84+ # works only in cooperation with BLAS when A is untransposed (tA == 'N')
85+ # but that check is included in gemv! anyway
86+ @inline generic_matvecmul! (y:: StridedVector{Complex{T}} , tA, A:: StridedVecOrMat{Complex{T}} , x:: StridedVector{T} ,
87+ _add:: MulAddMul = MulAddMul ()) where {T<: BlasReal } =
88+ gemv! (y, tA, _parent (A), x, _add. alpha, _add. beta)
9089
9190# Vector-Matrix multiplication
9291(* )(x:: AdjointAbsVec , A:: AbstractMatrix ) = (A' * x' )'
@@ -341,66 +340,26 @@ julia> lmul!(F.Q, B)
341340"""
342341lmul! (A, B)
343342
344- # generic case
345- @inline mul! (C:: StridedMatrix{T} , A:: StridedMaybeAdjOrTransVecOrMat{T} , B:: StridedMaybeAdjOrTransVecOrMat{T} ,
346- alpha:: Number , beta:: Number ) where {T<: BlasFloat } =
347- gemm_wrapper! (C, adj_or_trans_char (A), adj_or_trans_char (B), _parent (A), _parent (B), MulAddMul (alpha, beta))
348-
349- # AtB & ABt (including B === A)
350- @inline function mul! (C:: StridedMatrix{T} , tA:: Transpose{<:Any,<:StridedVecOrMat{T}} , B:: StridedVecOrMat{T} ,
351- alpha:: Number , beta:: Number ) where {T<: BlasFloat }
352- A = tA. parent
353- if A === B
354- return syrk_wrapper! (C, ' T' , A, MulAddMul (alpha, beta))
355- else
356- return gemm_wrapper! (C, ' T' , ' N' , A, B, MulAddMul (alpha, beta))
357- end
358- end
359- @inline function mul! (C:: StridedMatrix{T} , A:: StridedVecOrMat{T} , tB:: Transpose{<:Any,<:StridedVecOrMat{T}} ,
360- alpha:: Number , beta:: Number ) where {T<: BlasFloat }
361- B = tB. parent
362- if A === B
363- return syrk_wrapper! (C, ' N' , A, MulAddMul (alpha, beta))
364- else
365- return gemm_wrapper! (C, ' N' , ' T' , A, B, MulAddMul (alpha, beta))
366- end
367- end
368- # real adjoint cases, also needed for disambiguation
369- @inline mul! (C:: StridedMatrix{T} , A:: StridedVecOrMat{T} , adjB:: Adjoint{<:Any,<:StridedVecOrMat{T}} ,
370- alpha:: Number , beta:: Number ) where {T<: BlasReal } =
371- mul! (C, A, transpose (adjB. parent), alpha, beta)
372- @inline mul! (C:: StridedMatrix{T} , adjA:: Adjoint{<:Any,<:StridedVecOrMat{T}} , B:: StridedVecOrMat{T} ,
373- alpha:: Real , beta:: Real ) where {T<: BlasReal } =
374- mul! (C, transpose (adjA. parent), B, alpha, beta)
375-
376- # AcB & ABc (including B === A)
377- @inline function mul! (C:: StridedMatrix{T} , adjA:: Adjoint{<:Any,<:StridedVecOrMat{T}} , B:: StridedVecOrMat{T} ,
378- alpha:: Number , beta:: Number ) where {T<: BlasComplex }
379- A = adjA. parent
380- if A === B
381- return herk_wrapper! (C, ' C' , A, MulAddMul (alpha, beta))
343+ @inline function generic_matmatmul! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
344+ _add:: MulAddMul = MulAddMul ()) where {T<: BlasFloat }
345+ if tA == ' T' && tB == ' N' && A === B
346+ return syrk_wrapper! (C, ' T' , A, _add)
347+ elseif tA == ' N' && tB == ' T' && A === B
348+ return syrk_wrapper! (C, ' N' , A, _add)
349+ elseif tA == ' C' && tB == ' N' && A === B
350+ return herk_wrapper! (C, ' C' , A, _add)
351+ elseif tA == ' N' && tB == ' C' && A === B
352+ return herk_wrapper! (C, ' N' , A, _add)
382353 else
383- return gemm_wrapper! (C, ' C' , ' N' , A, B, MulAddMul (alpha, beta))
384- end
385- end
386- @inline function mul! (C:: StridedMatrix{T} , A:: StridedVecOrMat{T} , adjB:: Adjoint{<:Any,<:StridedVecOrMat{T}} ,
387- alpha:: Number , beta:: Number ) where {T<: BlasComplex }
388- B = adjB. parent
389- if A === B
390- return herk_wrapper! (C, ' N' , A, MulAddMul (alpha, beta))
391- else
392- return gemm_wrapper! (C, ' N' , ' C' , A, B, MulAddMul (alpha, beta))
354+ return gemm_wrapper! (C, tA, tB, A, B, _add)
393355 end
394356end
395357
396358# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
397- @inline mul! (C:: StridedMatrix{Complex{T}} , A:: StridedMaybeAdjOrTransVecOrMat{Complex{T}} , B:: StridedMaybeAdjOrTransVecOrMat{T} ,
398- alpha:: Number , beta:: Number ) where {T<: BlasReal } =
399- gemm_wrapper! (C, adj_or_trans_char (A), adj_or_trans_char (B), _parent (A), _parent (B), MulAddMul (alpha, beta))
400- # catch the real adjoint case and interpret it as a transpose
401- @inline mul! (C:: StridedMatrix{Complex{T}} , A:: StridedVecOrMat{Complex{T}} , adjB:: Adjoint{<:Any,<:StridedVecOrMat{T}} ,
402- alpha:: Number , beta:: Number ) where {T<: BlasReal } =
403- mul! (C, A, transpose (adjB. parent), alpha, beta)
359+ @inline function generic_matmatmul! (C:: StridedVecOrMat{Complex{T}} , tA, tB, A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
360+ _add:: MulAddMul = MulAddMul ()) where {T<: BlasReal }
361+ gemm_wrapper! (C, tA, tB, A, B, _add)
362+ end
404363
405364
406365# Supporting functions for matrix multiplication
@@ -438,7 +397,7 @@ function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::
438397 ! iszero (stride (x, 1 )) # We only check input's stride here.
439398 return BLAS. gemv! (tA, alpha, A, x, beta, y)
440399 else
441- return generic_matvecmul ! (y, tA, A, x, MulAddMul (α, β))
400+ return _generic_matvecmul ! (y, tA, A, x, MulAddMul (α, β))
442401 end
443402end
444403
@@ -459,7 +418,7 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa
459418 BLAS. gemv! (tA, alpha, reinterpret (T, A), x, beta, reinterpret (T, y))
460419 return y
461420 else
462- return generic_matvecmul ! (y, tA, A, x, MulAddMul (α, β))
421+ return _generic_matvecmul ! (y, tA, A, x, MulAddMul (α, β))
463422 end
464423end
465424
@@ -482,7 +441,7 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa
482441 BLAS. gemv! (tA, alpha, A, xfl[2 , :], beta, yfl[2 , :])
483442 return y
484443 else
485- return generic_matvecmul ! (y, tA, A, x, MulAddMul (α, β))
444+ return _generic_matvecmul ! (y, tA, A, x, MulAddMul (α, β))
486445 end
487446end
488447
@@ -609,7 +568,7 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
609568 stride (C, 2 ) >= size (C, 1 ))
610569 return BLAS. gemm! (tA, tB, alpha, A, B, beta, C)
611570 end
612- generic_matmatmul ! (C, tA, tB, A, B, _add)
571+ _generic_matmatmul ! (C, tA, tB, A, B, _add)
613572end
614573
615574function gemm_wrapper! (C:: StridedVecOrMat{Complex{T}} , tA:: AbstractChar , tB:: AbstractChar ,
@@ -652,7 +611,7 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs
652611 BLAS. gemm! (tA, tB, alpha, reinterpret (T, A), B, beta, reinterpret (T, C))
653612 return C
654613 end
655- generic_matmatmul ! (C, tA, tB, A, B, _add)
614+ _generic_matmatmul ! (C, tA, tB, A, B, _add)
656615end
657616
658617# blas.jl defines matmul for floats; other integer and mixed precision
686645# NOTE: the generic version is also called as fallback for
687646# strides != 1 cases
688647
689- function generic_matvecmul! (C:: AbstractVector{R} , tA, A:: AbstractVecOrMat , B:: AbstractVector ,
690- _add:: MulAddMul = MulAddMul ()) where R
648+ generic_matvecmul! (C:: AbstractVector , tA, A:: AbstractVecOrMat , B:: AbstractVector ,
649+ _add:: MulAddMul = MulAddMul ()) =
650+ _generic_matvecmul! (C, tA, A, B, _add)
651+
652+ function _generic_matvecmul! (C:: AbstractVector , tA, A:: AbstractVecOrMat , B:: AbstractVector ,
653+ _add:: MulAddMul = MulAddMul ())
691654 require_one_based_indexing (C, A, B)
692655 mB = length (B)
693656 mA, nA = lapack_size (tA, A)
0 commit comments