@@ -363,7 +363,7 @@ function _rdiv!(B::AbstractVecOrMat, A::AbstractVecOrMat, D::Diagonal)
363363 require_one_based_indexing (A)
364364 dd = D. diag
365365 m, n = size (A, 1 ), size (A, 2 )
366- if (k = length (dd)) ≠ n
366+ if (k = length (dd)) != n
367367 throw (DimensionMismatch (" left hand side has $n columns but D is $k by $k " ))
368368 end
369369 @inbounds for j in 1 : n
395395\ (D:: Diagonal , B:: Diagonal ) = ldiv! (similar (B, promote_op (\ , eltype (D), eltype (B))), D, B)
396396/ (A:: Diagonal , D:: Diagonal ) = _rdiv! (similar (A, promote_op (/ , eltype (A), eltype (D))), A, D)
397397function _rdiv! (Dc:: Diagonal , Db:: Diagonal , Da:: Diagonal )
398- n, k = length (Db. diag), length (Db . diag)
398+ n, k = length (Db. diag), length (Da . diag)
399399 n == k || throw (DimensionMismatch (" left hand side has $n columns but D is $k by $k " ))
400400 j = findfirst (iszero, Da. diag)
401401 isnothing (j) || throw (SingularException (j))
@@ -404,6 +404,88 @@ function _rdiv!(Dc::Diagonal, Db::Diagonal, Da::Diagonal)
404404end
405405ldiv! (Dc:: Diagonal , Da:: Diagonal , Db:: Diagonal ) = Diagonal (ldiv! (Dc. diag, Da, Db. diag))
406406
407+ # optimizations for (Sym)Tridiagonal and Diagonal
408+ @propagate_inbounds _getudiag (T:: Tridiagonal , i) = T. du[i]
409+ @propagate_inbounds _getudiag (S:: SymTridiagonal , i) = S. ev[i]
410+ @propagate_inbounds _getdiag (T:: Tridiagonal , i) = T. d[i]
411+ @propagate_inbounds _getdiag (S:: SymTridiagonal , i) = symmetric (S. dv[i], :U ):: symmetric_type (eltype (S. dv))
412+ @propagate_inbounds _getldiag (T:: Tridiagonal , i) = T. dl[i]
413+ @propagate_inbounds _getldiag (S:: SymTridiagonal , i) = transpose (S. ev[i])
414+
415+ function (\ )(D:: Diagonal , S:: SymTridiagonal )
416+ T = promote_op (\ , eltype (D), eltype (S))
417+ du = similar (S. ev, T, max (length (S. dv)- 1 , 0 ))
418+ d = similar (S. dv, T, length (S. dv))
419+ dl = similar (S. ev, T, max (length (S. dv)- 1 , 0 ))
420+ ldiv! (Tridiagonal (dl, d, du), D, S)
421+ end
422+ (\ )(D:: Diagonal , T:: Tridiagonal ) = ldiv! (similar (T, promote_op (\ , eltype (D), eltype (T))), D, T)
423+ function ldiv! (T:: Tridiagonal , D:: Diagonal , S:: Union{SymTridiagonal,Tridiagonal} )
424+ m = size (S, 1 )
425+ dd = D. diag
426+ if (k = length (dd)) != m
427+ throw (DimensionMismatch (" diagonal matrix is $k by $k but right hand side has $m rows" ))
428+ end
429+ if length (T. d) != m
430+ throw (DimensionMismatch (" target matrix size $(size (T)) does not match input matrix size $(size (S)) " ))
431+ end
432+ m == 0 && return T
433+ j = findfirst (iszero, dd)
434+ isnothing (j) || throw (SingularException (j))
435+ ddj = dd[1 ]
436+ T. d[1 ] = ddj \ _getdiag (S, 1 )
437+ @inbounds if m > 1
438+ T. du[1 ] = ddj \ _getudiag (S, 1 )
439+ for j in 2 : m- 1
440+ ddj = dd[j]
441+ T. dl[j- 1 ] = ddj \ _getldiag (S, j- 1 )
442+ T. d[j] = ddj \ _getdiag (S, j)
443+ T. du[j] = ddj \ _getudiag (S, j)
444+ end
445+ ddj = dd[m]
446+ T. dl[m- 1 ] = ddj \ _getldiag (S, m- 1 )
447+ T. d[m] = ddj \ _getdiag (S, m)
448+ end
449+ return T
450+ end
451+
452+ function (/ )(S:: SymTridiagonal , D:: Diagonal )
453+ T = promote_op (\ , eltype (D), eltype (S))
454+ du = similar (S. ev, T, max (length (S. dv)- 1 , 0 ))
455+ d = similar (S. dv, T, length (S. dv))
456+ dl = similar (S. ev, T, max (length (S. dv)- 1 , 0 ))
457+ _rdiv! (Tridiagonal (dl, d, du), S, D)
458+ end
459+ (/ )(T:: Tridiagonal , D:: Diagonal ) = _rdiv! (similar (T, promote_op (/ , eltype (T), eltype (D))), T, D)
460+ function _rdiv! (T:: Tridiagonal , S:: Union{SymTridiagonal,Tridiagonal} , D:: Diagonal )
461+ n = size (S, 2 )
462+ dd = D. diag
463+ if (k = length (dd)) != n
464+ throw (DimensionMismatch (" left hand side has $n columns but D is $k by $k " ))
465+ end
466+ if length (T. d) != n
467+ throw (DimensionMismatch (" target matrix size $(size (T)) does not match input matrix size $(size (S)) " ))
468+ end
469+ n == 0 && return T
470+ j = findfirst (iszero, dd)
471+ isnothing (j) || throw (SingularException (j))
472+ ddj = dd[1 ]
473+ T. d[1 ] = _getdiag (S, 1 ) / ddj
474+ @inbounds if n > 1
475+ T. dl[1 ] = _getldiag (S, 1 ) / ddj
476+ for j in 2 : n- 1
477+ ddj = dd[j]
478+ T. dl[j] = _getldiag (S, j) / ddj
479+ T. d[j] = _getdiag (S, j) / ddj
480+ T. du[j- 1 ] = _getudiag (S, j- 1 ) / ddj
481+ end
482+ ddj = dd[n]
483+ T. d[n] = _getdiag (S, n) / ddj
484+ T. du[n- 1 ] = _getudiag (S, n- 1 ) / ddj
485+ end
486+ return T
487+ end
488+
407489# Optimizations for [l/r]mul!, l/rdiv!, *, / and \ between Triangular and Diagonal.
408490# These functions are generally more efficient if we calculate the whole data field.
409491# The following code implements them in a unified pattern to avoid missing.
@@ -623,7 +705,7 @@ dot(A::AbstractMatrix, B::Diagonal) = conj(dot(B, A))
623705
624706function _mapreduce_prod (f, x, D:: Diagonal , y)
625707 if isempty (x) && isempty (D) && isempty (y)
626- return zero (Base . promote_op (f, eltype (x), eltype (D), eltype (y)))
708+ return zero (promote_op (f, eltype (x), eltype (D), eltype (y)))
627709 else
628710 return mapreduce (t -> f (t[1 ], t[2 ], t[3 ]), + , zip (x, D. diag, y))
629711 end
0 commit comments