Skip to content

Commit 15d7bd8

Browse files
authored
Simplify mul! dispatch (#49806)
1 parent fbbe9ed commit 15d7bd8

File tree

2 files changed

+41
-77
lines changed

2 files changed

+41
-77
lines changed

stdlib/LinearAlgebra/src/adjtrans.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ inplace_adj_or_trans(::Type{<:Transpose}) = transpose!
9797
adj_or_trans_char(::T) where {T<:AbstractArray} = adj_or_trans_char(T)
9898
adj_or_trans_char(::Type{<:AbstractArray}) = 'N'
9999
adj_or_trans_char(::Type{<:Adjoint}) = 'C'
100+
adj_or_trans_char(::Type{<:Adjoint{<:Real}}) = 'T'
100101
adj_or_trans_char(::Type{<:Transpose}) = 'T'
101102

102103
Base.dataids(A::Union{Adjoint, Transpose}) = Base.dataids(A.parent)

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 40 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -70,23 +70,22 @@ end
7070
alpha::Number, beta::Number) =
7171
generic_matvecmul!(y, adj_or_trans_char(A), _parent(A), x, MulAddMul(alpha, beta))
7272
# BLAS cases
73-
@inline mul!(y::StridedVector{T}, A::StridedMaybeAdjOrTransVecOrMat{T}, x::StridedVector{T},
74-
alpha::Number, beta::Number) where {T<:BlasFloat} =
75-
gemv!(y, adj_or_trans_char(A), _parent(A), x, alpha, beta)
76-
# catch the real adjoint case and rewrap to transpose
77-
@inline mul!(y::StridedVector{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T},
78-
alpha::Number, beta::Number) where {T<:BlasReal} =
79-
mul!(y, transpose(adjA.parent), x, alpha, beta)
73+
# equal eltypes
74+
@inline generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T},
75+
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat} =
76+
gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta)
77+
# Real (possibly transposed) matrix times complex vector.
78+
# Multiply the matrix with the real and imaginary parts separately
79+
@inline generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
80+
_add::MulAddMul=MulAddMul()) where {T<:BlasReal} =
81+
gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta)
8082
# Complex matrix times real vector.
8183
# Reinterpret the matrix as a real matrix and do real matvec computation.
82-
@inline mul!(y::StridedVector{Complex{T}}, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
83-
alpha::Number, beta::Number) where {T<:BlasReal} =
84-
gemv!(y, 'N', A, x, alpha, beta)
85-
# Real matrix times complex vector.
86-
# Multiply the matrix with the real and imaginary parts separately
87-
@inline mul!(y::StridedVector{Complex{T}}, A::StridedMaybeAdjOrTransMat{T}, x::StridedVector{Complex{T}},
88-
alpha::Number, beta::Number) where {T<:BlasReal} =
89-
gemv!(y, A isa StridedArray ? 'N' : 'T', _parent(A), x, alpha, beta)
84+
# The reinterpretation as a real matrix only works with BLAS when A is untransposed
85+
# (tA == 'N'); gemv! performs that check itself and otherwise uses the generic fallback
86+
@inline generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
87+
_add::MulAddMul=MulAddMul()) where {T<:BlasReal} =
88+
gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta)
9089

9190
# Vector-Matrix multiplication
9291
(*)(x::AdjointAbsVec, A::AbstractMatrix) = (A'*x')'
@@ -341,66 +340,26 @@ julia> lmul!(F.Q, B)
341340
"""
342341
lmul!(A, B)
343342

344-
# generic case
345-
@inline mul!(C::StridedMatrix{T}, A::StridedMaybeAdjOrTransVecOrMat{T}, B::StridedMaybeAdjOrTransVecOrMat{T},
346-
alpha::Number, beta::Number) where {T<:BlasFloat} =
347-
gemm_wrapper!(C, adj_or_trans_char(A), adj_or_trans_char(B), _parent(A), _parent(B), MulAddMul(alpha, beta))
348-
349-
# AtB & ABt (including B === A)
350-
@inline function mul!(C::StridedMatrix{T}, tA::Transpose{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
351-
alpha::Number, beta::Number) where {T<:BlasFloat}
352-
A = tA.parent
353-
if A === B
354-
return syrk_wrapper!(C, 'T', A, MulAddMul(alpha, beta))
355-
else
356-
return gemm_wrapper!(C, 'T', 'N', A, B, MulAddMul(alpha, beta))
357-
end
358-
end
359-
@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, tB::Transpose{<:Any,<:StridedVecOrMat{T}},
360-
alpha::Number, beta::Number) where {T<:BlasFloat}
361-
B = tB.parent
362-
if A === B
363-
return syrk_wrapper!(C, 'N', A, MulAddMul(alpha, beta))
364-
else
365-
return gemm_wrapper!(C, 'N', 'T', A, B, MulAddMul(alpha, beta))
366-
end
367-
end
368-
# real adjoint cases, also needed for disambiguation
369-
@inline mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
370-
alpha::Number, beta::Number) where {T<:BlasReal} =
371-
mul!(C, A, transpose(adjB.parent), alpha, beta)
372-
@inline mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
373-
alpha::Real, beta::Real) where {T<:BlasReal} =
374-
mul!(C, transpose(adjA.parent), B, alpha, beta)
375-
376-
# AcB & ABc (including B === A)
377-
@inline function mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
378-
alpha::Number, beta::Number) where {T<:BlasComplex}
379-
A = adjA.parent
380-
if A === B
381-
return herk_wrapper!(C, 'C', A, MulAddMul(alpha, beta))
343+
@inline function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
344+
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
345+
if tA == 'T' && tB == 'N' && A === B
346+
return syrk_wrapper!(C, 'T', A, _add)
347+
elseif tA == 'N' && tB == 'T' && A === B
348+
return syrk_wrapper!(C, 'N', A, _add)
349+
elseif tA == 'C' && tB == 'N' && A === B
350+
return herk_wrapper!(C, 'C', A, _add)
351+
elseif tA == 'N' && tB == 'C' && A === B
352+
return herk_wrapper!(C, 'N', A, _add)
382353
else
383-
return gemm_wrapper!(C, 'C', 'N', A, B, MulAddMul(alpha, beta))
384-
end
385-
end
386-
@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
387-
alpha::Number, beta::Number) where {T<:BlasComplex}
388-
B = adjB.parent
389-
if A === B
390-
return herk_wrapper!(C, 'N', A, MulAddMul(alpha, beta))
391-
else
392-
return gemm_wrapper!(C, 'N', 'C', A, B, MulAddMul(alpha, beta))
354+
return gemm_wrapper!(C, tA, tB, A, B, _add)
393355
end
394356
end
395357

396358
# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
397-
@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedMaybeAdjOrTransVecOrMat{Complex{T}}, B::StridedMaybeAdjOrTransVecOrMat{T},
398-
alpha::Number, beta::Number) where {T<:BlasReal} =
399-
gemm_wrapper!(C, adj_or_trans_char(A), adj_or_trans_char(B), _parent(A), _parent(B), MulAddMul(alpha, beta))
400-
# catch the real adjoint case and interpret it as a transpose
401-
@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
402-
alpha::Number, beta::Number) where {T<:BlasReal} =
403-
mul!(C, A, transpose(adjB.parent), alpha, beta)
359+
@inline function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
360+
_add::MulAddMul=MulAddMul()) where {T<:BlasReal}
361+
gemm_wrapper!(C, tA, tB, A, B, _add)
362+
end
404363

405364

406365
# Supporting functions for matrix multiplication
@@ -438,7 +397,7 @@ function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::
438397
!iszero(stride(x, 1)) # We only check input's stride here.
439398
return BLAS.gemv!(tA, alpha, A, x, beta, y)
440399
else
441-
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
400+
return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
442401
end
443402
end
444403

@@ -459,7 +418,7 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa
459418
BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y))
460419
return y
461420
else
462-
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
421+
return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
463422
end
464423
end
465424

@@ -482,7 +441,7 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa
482441
BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :])
483442
return y
484443
else
485-
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
444+
return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
486445
end
487446
end
488447

@@ -609,7 +568,7 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
609568
stride(C, 2) >= size(C, 1))
610569
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
611570
end
612-
generic_matmatmul!(C, tA, tB, A, B, _add)
571+
_generic_matmatmul!(C, tA, tB, A, B, _add)
613572
end
614573

615574
function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
@@ -652,7 +611,7 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs
652611
BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
653612
return C
654613
end
655-
generic_matmatmul!(C, tA, tB, A, B, _add)
614+
_generic_matmatmul!(C, tA, tB, A, B, _add)
656615
end
657616

658617
# blas.jl defines matmul for floats; other integer and mixed precision
@@ -686,8 +645,12 @@ end
686645
# NOTE: the generic version is also called as fallback for
687646
# strides != 1 cases
688647

689-
function generic_matvecmul!(C::AbstractVector{R}, tA, A::AbstractVecOrMat, B::AbstractVector,
690-
_add::MulAddMul = MulAddMul()) where R
648+
generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
649+
_add::MulAddMul = MulAddMul()) =
650+
_generic_matvecmul!(C, tA, A, B, _add)
651+
652+
function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
653+
_add::MulAddMul = MulAddMul())
691654
require_one_based_indexing(C, A, B)
692655
mB = length(B)
693656
mA, nA = lapack_size(tA, A)

0 commit comments

Comments
 (0)