Skip to content

Commit 7615d4c

Browse files
tkf authored and andreasnoack committed
Implement multiply-add interface in LinearAlgebra (#29634)
* Multiply-add interface for BLAS.herk!
* Multiply-add interface for BLAS.syrk!
* Multiply-add interface for gemv!
* Fix UndefRefError from C[i,j]. It may not be defined for Matrix{BigFloat}.
* Multiply-add interface for BLAS.gemm!
* Do not assume *(::Bool, ::eltype(C)) exists
* Test multiply-add interface
* Implement mul! in terms of addmul!
* Document addmul!
* Add multiply-add interface for symmetric matrices
* Add multiply-add interface for Number and UniformScaling
* Use lmul! for beta * C; eltype may not be commutative
* Add _lmul_or_fill!
* Add multiply-add interface for diagonal matrices
* Add multiply-add interface for bi- and tri-diagonal matrices
* Add multiply-add interface for triangular matrices
* Test multiply-add interface in test/generic.jl
* Fix addmul!(C, s::Number, X, alpha, beta) and addmul!(C, X, s::Number, alpha, beta)
* Special-case alpha=1 beta=0 using type parameter
* Test multiply-add interface in test/uniformscaling.jl
* Test multiply-add interface in test/diagonal.jl
* Use addmul! in SparseArrays
* Systematically test addmul!
* Make MulAddMul benchmark-friendly
* Fix _modify! docstring
* Use MulAddMul in A_mul_B_td!
* Comment out broken test_broken
* Relax rtol based on eltype of matrices A, B, C
* Pass around MulAddMul instead of alpha and beta for type stability
* Inline functions between *(::Matrix, ::Matrix) and gemm_wrapper!. This is required for recovering the performance of current master. Checked with: A = [1. 0; 2 0]; @benchmark $A * $A
* Annotate argument type MulAddMul
* Inline all addmul!
* Construct MulAddMul outside A_mul_B_td!
* Add multiply-add interface in test/tridiag.jl
* Mention combined multiply-add in NEWS.md [ci skip]
* Mention that mul!(C, A, B, α, β) is deprecated [ci skip]
* Change API definition to C = ABα + Cβ
* Fix indentation
* Fix triangular.jl
* Fix UndefRefError in multiplication with Diagonal
* Define Base.convert(::Type{Quaternion{T}}, s::Real). This is required for testset "* and mul! for non-commutative scaling" in LinearAlgebra/test/generic.jl to pass.
* Rename: addmul! -> mul!
* Fix doctest
* Workaround broadcast error with e.g., triangular matrix of BigFloat
1 parent f2c3d4d commit 7615d4c

File tree

18 files changed

+878
-287
lines changed

18 files changed

+878
-287
lines changed

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
1717
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec
1818
using Base: hvcat_fill, IndexLinear, promote_op, promote_typeof,
1919
@propagate_inbounds, @pure, reduce, typed_vcat, require_one_based_indexing
20-
using Base.Broadcast: Broadcasted
20+
using Base.Broadcast: Broadcasted, broadcasted
2121

2222
export
2323
# Modules

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 70 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -338,23 +338,23 @@ end
338338

339339
const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal}
340340
const BiTri = Union{Bidiagonal,Tridiagonal}
341-
mul!(C::AbstractMatrix, A::SymTridiagonal, B::BiTriSym) = A_mul_B_td!(C, A, B)
342-
mul!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym) = A_mul_B_td!(C, A, B)
343-
mul!(C::AbstractMatrix, A::AbstractTriangular, B::BiTriSym) = A_mul_B_td!(C, A, B)
344-
mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym) = A_mul_B_td!(C, A, B)
345-
mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym) = A_mul_B_td!(C, A, B)
346-
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, B::BiTriSym) = A_mul_B_td!(C, A, B)
347-
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::BiTriSym) = A_mul_B_td!(C, A, B)
348-
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractTriangular}, B::BiTriSym) = A_mul_B_td!(C, A, B)
349-
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractTriangular}, B::BiTriSym) = A_mul_B_td!(C, A, B)
350-
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, B::BiTriSym) = A_mul_B_td!(C, A, B)
351-
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, B::BiTriSym) = A_mul_B_td!(C, A, B)
352-
mul!(C::AbstractVector, A::BiTriSym, B::AbstractVector) = A_mul_B_td!(C, A, B)
353-
mul!(C::AbstractMatrix, A::BiTriSym, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)
354-
mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat) = A_mul_B_td!(C, A, B)
355-
mul!(C::AbstractMatrix, A::BiTriSym, B::Transpose{<:Any,<:AbstractVecOrMat}) = A_mul_B_td!(C, A, B) # around bidiag line 330
356-
mul!(C::AbstractMatrix, A::BiTriSym, B::Adjoint{<:Any,<:AbstractVecOrMat}) = A_mul_B_td!(C, A, B)
357-
mul!(C::AbstractVector, A::BiTriSym, B::Transpose{<:Any,<:AbstractVecOrMat}) = throw(MethodError(mul!, (C, A, B)))
341+
@inline mul!(C::AbstractMatrix, A::SymTridiagonal, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
342+
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
343+
@inline mul!(C::AbstractMatrix, A::AbstractTriangular, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
344+
@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
345+
@inline mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
346+
@inline mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
347+
@inline mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
348+
@inline mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractTriangular}, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
349+
@inline mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractTriangular}, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
350+
@inline mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
351+
@inline mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
352+
@inline mul!(C::AbstractVector, A::BiTriSym, B::AbstractVector, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
353+
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::AbstractVecOrMat, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
354+
@inline mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
355+
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta)) # around bidiag line 330
356+
@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta))
357+
# This argument combination is deliberately rejected. `throw` accepts exactly one exception, so the MethodError carries the full 5-argument call (including alpha and beta) instead of passing a stray second argument to `throw`.
@inline mul!(C::AbstractVector, A::BiTriSym, B::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = throw(MethodError(mul!, (C, A, B, alpha, beta)))
358358

359359
function check_A_mul_B!_sizes(C, A, B)
360360
require_one_based_indexing(C)
@@ -386,11 +386,15 @@ function _diag(A::Bidiagonal, k)
386386
end
387387
end
388388

389-
function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym)
389+
function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym,
390+
_add::MulAddMul = MulAddMul())
390391
check_A_mul_B!_sizes(C, A, B)
391392
n = size(A,1)
392-
n <= 3 && return mul!(C, Array(A), Array(B))
393-
fill!(C, zero(eltype(C)))
393+
n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
394+
# We use `_rmul_or_fill!` instead of `_modify!` here since using
395+
# `_modify!` in the following loop will not update the
396+
# off-diagonal elements for non-zero beta.
397+
_rmul_or_fill!(C, _add.beta)
394398
Al = _diag(A, -1)
395399
Ad = _diag(A, 0)
396400
Au = _diag(A, 1)
@@ -399,14 +403,14 @@ function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym)
399403
Bu = _diag(B, 1)
400404
@inbounds begin
401405
# first row of C
402-
C[1,1] = A[1,1]*B[1,1] + A[1, 2]*B[2, 1]
403-
C[1,2] = A[1,1]*B[1,2] + A[1,2]*B[2,2]
404-
C[1,3] = A[1,2]*B[2,3]
406+
C[1,1] += _add(A[1,1]*B[1,1] + A[1, 2]*B[2, 1])
407+
C[1,2] += _add(A[1,1]*B[1,2] + A[1,2]*B[2,2])
408+
C[1,3] += _add(A[1,2]*B[2,3])
405409
# second row of C
406-
C[2,1] = A[2,1]*B[1,1] + A[2,2]*B[2,1]
407-
C[2,2] = A[2,1]*B[1,2] + A[2,2]*B[2,2] + A[2,3]*B[3,2]
408-
C[2,3] = A[2,2]*B[2,3] + A[2,3]*B[3,3]
409-
C[2,4] = A[2,3]*B[3,4]
410+
C[2,1] += _add(A[2,1]*B[1,1] + A[2,2]*B[2,1])
411+
C[2,2] += _add(A[2,1]*B[1,2] + A[2,2]*B[2,2] + A[2,3]*B[3,2])
412+
C[2,3] += _add(A[2,2]*B[2,3] + A[2,3]*B[3,3])
413+
C[2,4] += _add(A[2,3]*B[3,4])
410414
for j in 3:n-2
411415
Ajj₋1 = Al[j-1]
412416
Ajj = Ad[j]
@@ -420,26 +424,27 @@ function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym)
420424
Bj₊1j = Bl[j]
421425
Bj₊1j₊1 = Bd[j+1]
422426
Bj₊1j₊2 = Bu[j+1]
423-
C[j,j-2] = Ajj₋1*Bj₋1j₋2
424-
C[j, j-1] = Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1
425-
C[j, j ] = Ajj₋1*Bj₋1j + Ajj*Bjj + Ajj₊1*Bj₊1j
426-
C[j, j+1] = Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1
427-
C[j, j+2] = Ajj₊1*Bj₊1j₊2
427+
C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2)
428+
C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1)
429+
C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj + Ajj₊1*Bj₊1j)
430+
C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1)
431+
C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2)
428432
end
429433
# row before last of C
430-
C[n-1,n-3] = A[n-1,n-2]*B[n-2,n-3]
431-
C[n-1,n-2] = A[n-1,n-1]*B[n-1,n-2] + A[n-1,n-2]*B[n-2,n-2]
432-
C[n-1,n-1] = A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1]
433-
C[n-1,n ] = A[n-1,n-1]*B[n-1,n ] + A[n-1, n]*B[n ,n ]
434+
C[n-1,n-3] += _add(A[n-1,n-2]*B[n-2,n-3])
435+
C[n-1,n-2] += _add(A[n-1,n-1]*B[n-1,n-2] + A[n-1,n-2]*B[n-2,n-2])
436+
C[n-1,n-1] += _add(A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1])
437+
C[n-1,n ] += _add(A[n-1,n-1]*B[n-1,n ] + A[n-1, n]*B[n ,n ])
434438
# last row of C
435-
C[n,n-2] = A[n,n-1]*B[n-1,n-2]
436-
C[n,n-1] = A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1]
437-
C[n,n ] = A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ]
439+
C[n,n-2] += _add(A[n,n-1]*B[n-1,n-2])
440+
C[n,n-1] += _add(A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1])
441+
C[n,n ] += _add(A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ])
438442
end # inbounds
439443
C
440444
end
441445

442-
function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::Diagonal)
446+
function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::Diagonal,
447+
_add::MulAddMul = MulAddMul())
443448
check_A_mul_B!_sizes(C, A, B)
444449
n = size(A,1)
445450
n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) # dense fallback for small n must still apply alpha and beta, as the sibling A_mul_B_td! methods do
@@ -450,29 +455,30 @@ function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::Diagonal)
450455
Bd = B.diag
451456
@inbounds begin
452457
# first row of C
453-
C[1,1] = A[1,1]*B[1,1]
454-
C[1,2] = A[1,2]*B[2,2]
458+
_modify!(_add, A[1,1]*B[1,1], C, (1,1))
459+
_modify!(_add, A[1,2]*B[2,2], C, (1,2))
455460
# second row of C
456-
C[2,1] = A[2,1]*B[1,1]
457-
C[2,2] = A[2,2]*B[2,2]
458-
C[2,3] = A[2,3]*B[3,3]
461+
_modify!(_add, A[2,1]*B[1,1], C, (2,1))
462+
_modify!(_add, A[2,2]*B[2,2], C, (2,2))
463+
_modify!(_add, A[2,3]*B[3,3], C, (2,3))
459464
for j in 3:n-2
460-
C[j, j-1] = Al[j-1]*Bd[j-1]
461-
C[j, j ] = Ad[j ]*Bd[j ]
462-
C[j, j+1] = Au[j ]*Bd[j+1]
465+
_modify!(_add, Al[j-1]*Bd[j-1], C, (j, j-1))
466+
_modify!(_add, Ad[j ]*Bd[j ], C, (j, j ))
467+
_modify!(_add, Au[j ]*Bd[j+1], C, (j, j+1))
463468
end
464469
# row before last of C
465-
C[n-1,n-2] = A[n-1,n-2]*B[n-2,n-2]
466-
C[n-1,n-1] = A[n-1,n-1]*B[n-1,n-1]
467-
C[n-1,n ] = A[n-1, n]*B[n ,n ]
470+
_modify!(_add, A[n-1,n-2]*B[n-2,n-2], C, (n-1,n-2))
471+
_modify!(_add, A[n-1,n-1]*B[n-1,n-1], C, (n-1,n-1))
472+
_modify!(_add, A[n-1, n]*B[n ,n ], C, (n-1,n ))
468473
# last row of C
469-
C[n,n-1] = A[n,n-1]*B[n-1,n-1]
470-
C[n,n ] = A[n,n ]*B[n, n ]
474+
_modify!(_add, A[n,n-1]*B[n-1,n-1], C, (n,n-1))
475+
_modify!(_add, A[n,n ]*B[n, n ], C, (n,n ))
471476
end # inbounds
472477
C
473478
end
474479

475-
function A_mul_B_td!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat)
480+
function A_mul_B_td!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat,
481+
_add::MulAddMul = MulAddMul())
476482
require_one_based_indexing(C)
477483
require_one_based_indexing(B)
478484
nA = size(A,1)
@@ -483,28 +489,29 @@ function A_mul_B_td!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat)
483489
if size(C,2) != nB
484490
throw(DimensionMismatch("A has second dimension $nA, B has $(size(B,2)), C has $(size(C,2)) but all must match"))
485491
end
486-
nA <= 3 && return mul!(C, Array(A), Array(B))
492+
nA <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
487493
l = _diag(A, -1)
488494
d = _diag(A, 0)
489495
u = _diag(A, 1)
490496
@inbounds begin
491497
for j = 1:nB
492498
b₀, b₊ = B[1, j], B[2, j]
493-
C[1, j] = d[1]*b₀ + u[1]*b₊
499+
_modify!(_add, d[1]*b₀ + u[1]*b₊, C, (1, j))
494500
for i = 2:nA - 1
495501
b₋, b₀, b₊ = b₀, b₊, B[i + 1, j]
496-
C[i, j] = l[i - 1]*b₋ + d[i]*b₀ + u[i]*b₊
502+
_modify!(_add, l[i - 1]*b₋ + d[i]*b₀ + u[i]*b₊, C, (i, j))
497503
end
498-
C[nA, j] = l[nA - 1]*b₀ + d[nA]*b₊
504+
_modify!(_add, l[nA - 1]*b₀ + d[nA]*b₊, C, (nA, j))
499505
end
500506
end
501507
C
502508
end
503509

504-
function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym)
510+
function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym,
511+
_add::MulAddMul = MulAddMul())
505512
check_A_mul_B!_sizes(C, A, B)
506513
n = size(A,1)
507-
n <= 3 && return mul!(C, Array(A), Array(B))
514+
n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
508515
m = size(B,2)
509516
Bl = _diag(B, -1)
510517
Bd = _diag(B, 0)
@@ -516,16 +523,16 @@ function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym)
516523
Bmm = Bd[m]
517524
Bm₋1m = Bu[m-1]
518525
for i in 1:n
519-
C[i, 1] = A[i,1] * B11 + A[i, 2] * B21
520-
C[i, m] = A[i, m-1] * Bm₋1m + A[i, m] * Bmm
526+
_modify!(_add, A[i,1] * B11 + A[i, 2] * B21, C, (i, 1))
527+
_modify!(_add, A[i, m-1] * Bm₋1m + A[i, m] * Bmm, C, (i, m))
521528
end
522529
# middle columns of C
523530
for j = 2:m-1
524531
Bj₋1j = Bu[j-1]
525532
Bjj = Bd[j]
526533
Bj₊1j = Bl[j]
527534
for i = 1:n
528-
C[i, j] = A[i, j-1] * Bj₋1j + A[i, j]*Bjj + A[i, j+1] * Bj₊1j
535+
_modify!(_add, A[i, j-1] * Bj₋1j + A[i, j]*Bjj + A[i, j+1] * Bj₊1j, C, (i, j))
529536
end
530537
end
531538
end # inbounds

stdlib/LinearAlgebra/src/blas.jl

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,9 @@ for (fname, elty) in ((:dgemv_,:Float64),
571571
# CHARACTER TRANS
572572
#* .. Array Arguments ..
573573
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
574-
function gemv!(trans::AbstractChar, alpha::($elty), A::AbstractVecOrMat{$elty}, X::AbstractVector{$elty}, beta::($elty), Y::AbstractVector{$elty})
574+
function gemv!(trans::AbstractChar, alpha::Union{($elty), Bool},
575+
A::AbstractVecOrMat{$elty}, X::AbstractVector{$elty},
576+
beta::Union{($elty), Bool}, Y::AbstractVector{$elty})
575577
require_one_based_indexing(A, X, Y)
576578
m,n = size(A,1),size(A,2)
577579
if trans == 'N' && (length(X) != n || length(Y) != m)
@@ -656,7 +658,10 @@ for (fname, elty) in ((:dgbmv_,:Float64),
656658
# CHARACTER TRANS
657659
# * .. Array Arguments ..
658660
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
659-
function gbmv!(trans::AbstractChar, m::Integer, kl::Integer, ku::Integer, alpha::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty}, beta::($elty), y::AbstractVector{$elty})
661+
function gbmv!(trans::AbstractChar, m::Integer, kl::Integer, ku::Integer,
662+
alpha::Union{($elty), Bool}, A::AbstractMatrix{$elty},
663+
x::AbstractVector{$elty}, beta::Union{($elty), Bool},
664+
y::AbstractVector{$elty})
660665
require_one_based_indexing(A, x, y)
661666
chkstride1(A)
662667
ccall((@blasfunc($fname), libblas), Cvoid,
@@ -704,7 +709,9 @@ for (fname, elty, lib) in ((:dsymv_,:Float64,libblas),
704709
# CHARACTER UPLO
705710
# .. Array Arguments ..
706711
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
707-
function symv!(uplo::AbstractChar, alpha::($elty), A::AbstractMatrix{$elty}, x::AbstractVector{$elty}, beta::($elty), y::AbstractVector{$elty})
712+
function symv!(uplo::AbstractChar, alpha::Union{($elty), Bool},
713+
A::AbstractMatrix{$elty}, x::AbstractVector{$elty},
714+
beta::Union{($elty), Bool}, y::AbstractVector{$elty})
708715
require_one_based_indexing(A, x, y)
709716
m, n = size(A)
710717
if m != n
@@ -756,7 +763,7 @@ symv(ul, A, x)
756763
for (fname, elty) in ((:zhemv_,:ComplexF64),
757764
(:chemv_,:ComplexF32))
758765
@eval begin
759-
function hemv!(uplo::AbstractChar, α::$elty, A::AbstractMatrix{$elty}, x::AbstractVector{$elty}, β::$elty, y::AbstractVector{$elty})
766+
function hemv!(uplo::AbstractChar, α::Union{$elty, Bool}, A::AbstractMatrix{$elty}, x::AbstractVector{$elty}, β::Union{$elty, Bool}, y::AbstractVector{$elty})
760767
require_one_based_indexing(A, x, y)
761768
m, n = size(A)
762769
if m != n
@@ -1112,7 +1119,11 @@ for (gemm, elty) in
11121119
# CHARACTER TRANSA,TRANSB
11131120
# * .. Array Arguments ..
11141121
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
1115-
function gemm!(transA::AbstractChar, transB::AbstractChar, alpha::($elty), A::AbstractVecOrMat{$elty}, B::AbstractVecOrMat{$elty}, beta::($elty), C::AbstractVecOrMat{$elty})
1122+
function gemm!(transA::AbstractChar, transB::AbstractChar,
1123+
alpha::Union{($elty), Bool},
1124+
A::AbstractVecOrMat{$elty}, B::AbstractVecOrMat{$elty},
1125+
beta::Union{($elty), Bool},
1126+
C::AbstractVecOrMat{$elty})
11161127
# if any([stride(A,1), stride(B,1), stride(C,1)] .!= 1)
11171128
# error("gemm!: BLAS module requires contiguous matrix columns")
11181129
# end # should this be checked on every call?
@@ -1175,7 +1186,9 @@ for (mfname, elty) in ((:dsymm_,:Float64),
11751186
# CHARACTER SIDE,UPLO
11761187
# .. Array Arguments ..
11771188
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
1178-
function symm!(side::AbstractChar, uplo::AbstractChar, alpha::($elty), A::AbstractMatrix{$elty}, B::AbstractMatrix{$elty}, beta::($elty), C::AbstractMatrix{$elty})
1189+
function symm!(side::AbstractChar, uplo::AbstractChar, alpha::Union{($elty), Bool},
1190+
A::AbstractMatrix{$elty}, B::AbstractMatrix{$elty},
1191+
beta::Union{($elty), Bool}, C::AbstractMatrix{$elty})
11791192
require_one_based_indexing(A, B, C)
11801193
m, n = size(C)
11811194
j = checksquare(A)
@@ -1244,7 +1257,9 @@ for (mfname, elty) in ((:zhemm_,:ComplexF64),
12441257
# CHARACTER SIDE,UPLO
12451258
# .. Array Arguments ..
12461259
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
1247-
function hemm!(side::AbstractChar, uplo::AbstractChar, alpha::($elty), A::AbstractMatrix{$elty}, B::AbstractMatrix{$elty}, beta::($elty), C::AbstractMatrix{$elty})
1260+
function hemm!(side::AbstractChar, uplo::AbstractChar, alpha::Union{($elty), Bool},
1261+
A::AbstractMatrix{$elty}, B::AbstractMatrix{$elty},
1262+
beta::Union{($elty), Bool}, C::AbstractMatrix{$elty})
12481263
require_one_based_indexing(A, B, C)
12491264
m, n = size(C)
12501265
j = checksquare(A)
@@ -1309,8 +1324,8 @@ for (fname, elty) in ((:dsyrk_,:Float64),
13091324
# * .. Array Arguments ..
13101325
# REAL A(LDA,*),C(LDC,*)
13111326
function syrk!(uplo::AbstractChar, trans::AbstractChar,
1312-
alpha::($elty), A::AbstractVecOrMat{$elty},
1313-
beta::($elty), C::AbstractMatrix{$elty})
1327+
alpha::Union{($elty), Bool}, A::AbstractVecOrMat{$elty},
1328+
beta::Union{($elty), Bool}, C::AbstractMatrix{$elty})
13141329
require_one_based_indexing(A, C)
13151330
n = checksquare(C)
13161331
nn = size(A, trans == 'N' ? 1 : 2)
@@ -1366,8 +1381,9 @@ for (fname, elty, relty) in ((:zherk_, :ComplexF64, :Float64),
13661381
# * ..
13671382
# * .. Array Arguments ..
13681383
# COMPLEX A(LDA,*),C(LDC,*)
1369-
function herk!(uplo::AbstractChar, trans::AbstractChar, α::$relty, A::AbstractVecOrMat{$elty},
1370-
β::$relty, C::AbstractMatrix{$elty})
1384+
function herk!(uplo::AbstractChar, trans::AbstractChar,
1385+
α::Union{$relty, Bool}, A::AbstractVecOrMat{$elty},
1386+
β::Union{$relty, Bool}, C::AbstractMatrix{$elty})
13711387
require_one_based_indexing(A, C)
13721388
n = checksquare(C)
13731389
nn = size(A, trans == 'N' ? 1 : 2)

0 commit comments

Comments
 (0)