@@ -397,13 +397,13 @@ function lmul!(D::Diagonal, T::Tridiagonal)
397397 return T
398398end
399399
# Kernel for a `Diagonal` on the left: for every index `(row, col)` of `B`, the
# product `D.diag[row] * B[row, col]` is combined into `out[row, col]` through
# `_modify!` with `MulAddMul(alpha, beta)`.
# The `MulAddMul` is built inside `@stable_muladdmul` at the call site rather than
# being passed in as an argument.
# No size or bounds checks are performed here; `_mul_diag!` does them and also
# handles `iszero(alpha)` before dispatching to this kernel.
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Number, beta::Number)
    dd = D.diag
    rows, cols = axes(B, 1), axes(B, 2)
    @inbounds for col in cols
        @simd for row in rows
            scaled = dd[row] * B[row, col]
            @stable_muladdmul _modify!(MulAddMul(alpha, beta), scaled, out, (row, col))
        end
    end
    return out
end
# Two wrappers of the same orientation (both upper, or both lower) are reported
# as having matching zeros.
function _has_matching_zeros(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular)
    return true
end
function _has_matching_zeros(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular)
    return true
end
@@ -418,116 +418,118 @@ function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col)
418418end
# Row range of column `col` that lies outside the stored triangle of `B`:
# the rows below the diagonal for upper-triangular wrappers, and the rows above
# the diagonal for lower-triangular ones.
function _rowrange_tri_zeros(B::UpperOrUnitUpperTriangular, col)
    return (col + 1):size(B, 1)
end
function _rowrange_tri_zeros(B::LowerOrUnitLowerTriangular, col)
    return 1:(col - 1)
end
# Kernel for a `Diagonal` on the left of a triangular `B`.
# Each product `D.diag[i] * B[i,j]` is combined into `out` through `_modify!`
# with `MulAddMul(alpha, beta)`, constructed inside `@stable_muladdmul`.
# When `out` and `B` have matching zeros, the stored part is read from and written
# to the parents directly, and the zero part of `B` is not visited at all.
function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, alpha::Number, beta::Number)
    isunit = B isa UnitUpperOrUnitLowerTriangular
    matching = _has_matching_zeros(out, B)
    if matching
        out_maybeparent = parent(out)
        B_maybeparent = parent(B)
    else
        out_maybeparent = out
        B_maybeparent = B
    end
    for col in axes(B, 2)
        # unit triangular matrices: the diagonal entry is handled on its own,
        # going through the wrappers rather than the parents
        if isunit
            @inbounds begin
                diagval = D.diag[col] * B[col, col]
                @stable_muladdmul _modify!(MulAddMul(alpha, beta), diagval, out, (col, col))
            end
        end
        # entries of out that line up with the stored entries of B
        stored_rows = _rowrange_tri_stored(B, col)
        @inbounds @simd for row in stored_rows
            val = D.diag[row] * B_maybeparent[row, col]
            @stable_muladdmul _modify!(MulAddMul(alpha, beta), val, out_maybeparent, (row, col))
        end
        # entries of out that line up with the zeros of B; these are only
        # touched when out and B do not share the same zero pattern
        if !matching
            zero_rows = _rowrange_tri_zeros(B, col)
            @inbounds @simd for row in zero_rows
                val = D.diag[row] * B[row, col]
                @stable_muladdmul _modify!(MulAddMul(alpha, beta), val, out, (row, col))
            end
        end
    end
    return out
end
445445
# Kernel for a `Diagonal` on the right: column `col` of `A` is combined with
# `D.diag[col]` and written into `out` through `_modify!`.
# `alpha` is applied once per column to the diagonal entry (via
# `MulAddMul(alpha, false)`), so the per-element update uses `MulAddMul(true, beta)`
# and does not apply `alpha` a second time.
# No size or bounds checks are performed here.
@inline function __muldiag_nonzeroalpha_right!(out, A, D::Diagonal, alpha::Number, beta::Number)
    rows, cols = axes(A, 1), axes(A, 2)
    @inbounds for col in cols
        dcol = D.diag[col]
        scaled_d = @stable_muladdmul MulAddMul(alpha, false)(dcol)
        @simd for row in rows
            val = A[row, col] * scaled_d
            @stable_muladdmul _modify!(MulAddMul(true, beta), val, out, (row, col))
        end
    end
    return out
end
455+
# Generic `A * D` case: forward to the right-multiplication kernel.
__muldiag_nonzeroalpha!(out, A, D::Diagonal, alpha::Number, beta::Number) =
    __muldiag_nonzeroalpha_right!(out, A, D, alpha, beta)
# Kernel for a `Diagonal` on the right of a triangular `A`.
# `alpha` is multiplied into the diagonal element of `D` once per column, so every
# subsequent update uses `MulAddMul(true, beta)` and skips `alpha`.
function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, alpha::Number, beta::Number)
    isunit = A isa UnitUpperOrUnitLowerTriangular
    # when A and out share the same upper/lower triangular structure,
    # reads and writes may go straight to the parents
    matching = _has_matching_zeros(out, A)
    if matching
        out_maybeparent = parent(out)
        A_maybeparent = parent(A)
    else
        out_maybeparent = out
        A_maybeparent = A
    end
    for col in axes(A, 2)
        dcol = @inbounds D.diag[col]
        scaled_d = @stable_muladdmul MulAddMul(alpha, false)(dcol)
        # unit triangular matrices: the diagonal entry is handled on its own,
        # going through the wrappers rather than the parents
        if isunit
            @inbounds begin
                diagval = A[col, col] * scaled_d
                @stable_muladdmul _modify!(MulAddMul(true, beta), diagval, out, (col, col))
            end
        end
        # entries of out that line up with the stored entries of A
        stored_rows = _rowrange_tri_stored(A, col)
        @inbounds @simd for row in stored_rows
            val = A_maybeparent[row, col] * scaled_d
            @stable_muladdmul _modify!(MulAddMul(true, beta), val, out_maybeparent, (row, col))
        end
        # entries of out that line up with the zeros of A; these are only
        # touched when out and A do not share the same zero pattern
        if !matching
            zero_rows = _rowrange_tri_zeros(A, col)
            @inbounds @simd for row in zero_rows
                val = A[row, col] * scaled_d
                @stable_muladdmul _modify!(MulAddMul(true, beta), val, out, (row, col))
            end
        end
    end
    return out
end
490+
# ambiguity resolution: with two `Diagonal` factors and a generic `out`, both the
# `(out, D::Diagonal, B, ...)` and `(out, A, D::Diagonal, ...)` methods apply;
# route this case to the right-multiplication kernel.
__muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number) =
    __muldiag_nonzeroalpha_right!(out, D1, D2, alpha, beta)
488495
# All three operands are `Diagonal`: only the diagonal vectors are visited.
# Each product `D1.diag[k] * D2.diag[k]` is combined into `out.diag[k]` through
# `_modify!` with `MulAddMul(alpha, beta)`.
@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number)
    lhs, rhs, dest = D1.diag, D2.diag, out.diag
    @inbounds @simd for k in eachindex(lhs, rhs, dest)
        prod_k = lhs[k] * rhs[k]
        @stable_muladdmul _modify!(MulAddMul(alpha, beta), prod_k, dest, k)
    end
    return out
end
506505
# Common entry point for products involving a `Diagonal`.
# It checks one-based indexing and the operand sizes, and deals with the
# zero-`alpha` case itself (via `_rmul_or_fill!(out, beta)`), so the
# `__muldiag_nonzeroalpha!` methods only need to specialize the non-trivial case.
function _mul_diag!(out, A, B, alpha, beta)
    require_one_based_indexing(out, A, B)
    _muldiag_size_check(size(out), size(A), size(B))
    if iszero(alpha)
        # the product term vanishes; only `beta` acts on `out`
        _rmul_or_fill!(out, beta)
        return out
    end
    __muldiag_nonzeroalpha!(out, A, B, alpha, beta)
    return out
end
520518
# `_mul!` methods for products with a `Diagonal` factor; all of them forward to
# `_mul_diag!`.

# Diagonal * vector, written into a vector or a matrix destination
for OT in (:AbstractVector, :AbstractMatrix)
    @eval function _mul!(out::$OT, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number)
        return _mul_diag!(out, D, V, alpha, beta)
    end
end

# Diagonal * matrix and matrix * Diagonal; the `AbstractTriangular` methods are
# defined explicitly alongside the generic `AbstractMatrix` ones
for MT in (:AbstractMatrix, :AbstractTriangular)
    @eval begin
        function _mul!(out::AbstractMatrix, D::Diagonal, B::$MT, alpha::Number, beta::Number)
            return _mul_diag!(out, D, B, alpha, beta)
        end
        function _mul!(out::AbstractMatrix, A::$MT, D::Diagonal, alpha::Number, beta::Number)
            return _mul_diag!(out, A, D, alpha, beta)
        end
    end
end

# Diagonal * Diagonal
function _mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number)
    return _mul_diag!(C, Da, Db, alpha, beta)
end
531533
532534function (* )(Da:: Diagonal , A:: AbstractMatrix , Db:: Diagonal )
533535 _muldiag_size_check (size (Da), size (A))
0 commit comments