@@ -238,10 +238,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) =
238238Base. isstored (A:: UpperOrLowerTriangular , i:: Int , j:: Int ) =
239239 _shouldforwardindex (A, i, j) ? Base. isstored (A. data, i, j) : false
240240
241- @propagate_inbounds getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , i:: Int , j:: Int ) where {T} =
242- _shouldforwardindex (A, i, j) ? A. data[i,j] : ifelse (i == j, oneunit (T), zero (T))
243- @propagate_inbounds getindex (A:: Union{LowerTriangular, UpperTriangular} , i:: Int , j:: Int ) =
244- _shouldforwardindex (A, i, j) ? A. data[i,j] : diagzero (A,i,j)
241+ @propagate_inbounds function getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , i:: Int , j:: Int ) where {T}
242+ if _shouldforwardindex (A, i, j)
243+ A. data[i,j]
244+ else
245+ @boundscheck checkbounds (A, i, j)
246+ ifelse (i == j, oneunit (T), zero (T))
247+ end
248+ end
249+ @propagate_inbounds function getindex (A:: Union{LowerTriangular, UpperTriangular} , i:: Int , j:: Int )
250+ if _shouldforwardindex (A, i, j)
251+ A. data[i,j]
252+ else
253+ @boundscheck checkbounds (A, i, j)
254+ @inbounds diagzero (A,i,j)
255+ end
256+ end
245257
246258_shouldforwardindex (U:: UpperTriangular , b:: BandIndex ) = b. band >= 0
247259_shouldforwardindex (U:: LowerTriangular , b:: BandIndex ) = b. band <= 0
@@ -250,10 +262,20 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0
250262
251263# these specialized getindex methods enable constant-propagation of the band
252264Base. @constprop :aggressive @propagate_inbounds function getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , b:: BandIndex ) where {T}
253- _shouldforwardindex (A, b) ? A. data[b] : ifelse (b. band == 0 , oneunit (T), zero (T))
265+ if _shouldforwardindex (A, b)
266+ A. data[b]
267+ else
268+ @boundscheck checkbounds (A, b)
269+ ifelse (b. band == 0 , oneunit (T), zero (T))
270+ end
254271end
255272Base. @constprop :aggressive @propagate_inbounds function getindex (A:: Union{LowerTriangular, UpperTriangular} , b:: BandIndex )
256- _shouldforwardindex (A, b) ? A. data[b] : diagzero (A. data, b)
273+ if _shouldforwardindex (A, b)
274+ A. data[b]
275+ else
276+ @boundscheck checkbounds (A, b)
277+ @inbounds diagzero (A, b)
278+ end
257279end
258280
259281_zero_triangular_half_str (:: Type{<:UpperOrUnitUpperTriangular} ) = " lower"
@@ -265,14 +287,20 @@ _zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"
265287 throw (ArgumentError (
266288 lazy " cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)" ))
267289end
268- @noinline function throw_nononeerror (T, @nospecialize (x), i, j)
290+ @noinline function throw_nonuniterror (T, @nospecialize (x), i, j)
291+ check_compatible_type (T, x)
269292 Tn = nameof (T)
270293 throw (ArgumentError (
271294 lazy " cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)" ))
272295end
296+ function check_compatible_type (T, @nospecialize (x))
297+ ET = eltype (T)
298+ convert (ET, x) # check that the types are compatible with setindex!
299+ end
273300
274301@propagate_inbounds function setindex! (A:: UpperTriangular , x, i:: Integer , j:: Integer )
275302 if i > j
303+ @boundscheck checkbounds (A, i, j)
276304 iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
277305 else
278306 A. data[i,j] = x
282310
283311@propagate_inbounds function setindex! (A:: UnitUpperTriangular , x, i:: Integer , j:: Integer )
284312 if i > j
313+ @boundscheck checkbounds (A, i, j)
285314 iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
286315 elseif i == j
287- x == oneunit (x) || throw_nononeerror (typeof (A), x, i, j)
316+ @boundscheck checkbounds (A, i, j)
317+ x == oneunit (eltype (A)) || throw_nonuniterror (typeof (A), x, i, j)
288318 else
289319 A. data[i,j] = x
290320 end
293323
294324@propagate_inbounds function setindex! (A:: LowerTriangular , x, i:: Integer , j:: Integer )
295325 if i < j
326+ @boundscheck checkbounds (A, i, j)
296327 iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
297328 else
298329 A. data[i,j] = x
302333
303334@propagate_inbounds function setindex! (A:: UnitLowerTriangular , x, i:: Integer , j:: Integer )
304335 if i < j
336+ @boundscheck checkbounds (A, i, j)
305337 iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
306338 elseif i == j
307- x == oneunit (x) || throw_nononeerror (typeof (A), x, i, j)
339+ @boundscheck checkbounds (A, i, j)
340+ x == oneunit (eltype (A)) || throw_nonuniterror (typeof (A), x, i, j)
308341 else
309342 A. data[i,j] = x
310343 end
@@ -560,7 +593,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un
560593 @eval @inline function _copy! (A:: $UT , B:: $T )
561594 for dind in diagind (A, IndexStyle (A))
562595 if A[dind] != B[dind]
563- throw_nononeerror (typeof (A), B[dind], Tuple (dind)... )
596+ throw_nonuniterror (typeof (A), B[dind], Tuple (dind)... )
564597 end
565598 end
566599 _copy! ($ T (parent (A)), B)
@@ -740,7 +773,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
740773 checksize1 (A, B)
741774 _iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
742775 for j in axes (B. data,2 )
743- @inbounds _modify! (_add, c, A, (j,j))
776+ @inbounds _modify! (_add, B[ BandIndex ( 0 ,j)] * c, A, (j,j))
744777 for i in firstindex (B. data,1 ): (j - 1 )
745778 @inbounds _modify! (_add, B. data[i,j] * c, A. data, (i,j))
746779 end
@@ -751,7 +784,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
751784 checksize1 (A, B)
752785 _iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
753786 for j in axes (B. data,2 )
754- @inbounds _modify! (_add, c, A, (j,j))
787+ @inbounds _modify! (_add, c * B[ BandIndex ( 0 ,j)] , A, (j,j))
755788 for i in firstindex (B. data,1 ): (j - 1 )
756789 @inbounds _modify! (_add, c * B. data[i,j], A. data, (i,j))
757790 end
@@ -782,7 +815,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
782815 checksize1 (A, B)
783816 _iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
784817 for j in axes (B. data,2 )
785- @inbounds _modify! (_add, c, A, (j,j))
818+ @inbounds _modify! (_add, B[ BandIndex ( 0 ,j)] * c, A, (j,j))
786819 for i in (j + 1 ): lastindex (B. data,1 )
787820 @inbounds _modify! (_add, B. data[i,j] * c, A. data, (i,j))
788821 end
@@ -793,7 +826,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
793826 checksize1 (A, B)
794827 _iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
795828 for j in axes (B. data,2 )
796- @inbounds _modify! (_add, c, A, (j,j))
829+ @inbounds _modify! (_add, c * B[ BandIndex ( 0 ,j)] , A, (j,j))
797830 for i in (j + 1 ): lastindex (B. data,1 )
798831 @inbounds _modify! (_add, c * B. data[i,j], A. data, (i,j))
799832 end
@@ -803,36 +836,52 @@ end
803836
804837function _trirdiv! (A:: UpperTriangular , B:: UpperOrUnitUpperTriangular , c:: Number )
805838 checksize1 (A, B)
839+ isunit = B isa UnitUpperTriangular
806840 for j in axes (B,2 )
807- for i in firstindex (B,1 ): j
808- @inbounds A[i, j] = B[i, j] / c
841+ for i in firstindex (B,1 ): j- isunit
842+ @inbounds A. data[i, j] = B. data[i, j] / c
843+ end
844+ if isunit
845+ @inbounds A. data[j, j] = B[BandIndex (0 ,j)] / c
809846 end
810847 end
811848 return A
812849end
813850function _trirdiv! (A:: LowerTriangular , B:: LowerOrUnitLowerTriangular , c:: Number )
814851 checksize1 (A, B)
852+ isunit = B isa UnitLowerTriangular
815853 for j in axes (B,2 )
816- for i in j: lastindex (B,1 )
817- @inbounds A[i, j] = B[i, j] / c
854+ if isunit
855+ @inbounds A. data[j, j] = B[BandIndex (0 ,j)] / c
856+ end
857+ for i in j+ isunit: lastindex (B,1 )
858+ @inbounds A. data[i, j] = B. data[i, j] / c
818859 end
819860 end
820861 return A
821862end
822863function _trildiv! (A:: UpperTriangular , c:: Number , B:: UpperOrUnitUpperTriangular )
823864 checksize1 (A, B)
865+ isunit = B isa UnitUpperTriangular
824866 for j in axes (B,2 )
825- for i in firstindex (B,1 ): j
826- @inbounds A[i, j] = c \ B[i, j]
867+ for i in firstindex (B,1 ): j- isunit
868+ @inbounds A. data[i, j] = c \ B. data[i, j]
869+ end
870+ if isunit
871+ @inbounds A. data[j, j] = c \ B[BandIndex (0 ,j)]
827872 end
828873 end
829874 return A
830875end
831876function _trildiv! (A:: LowerTriangular , c:: Number , B:: LowerOrUnitLowerTriangular )
832877 checksize1 (A, B)
878+ isunit = B isa UnitLowerTriangular
833879 for j in axes (B,2 )
834- for i in j: lastindex (B,1 )
835- @inbounds A[i, j] = c \ B[i, j]
880+ if isunit
881+ @inbounds A. data[j, j] = c \ B[BandIndex (0 ,j)]
882+ end
883+ for i in j+ isunit: lastindex (B,1 )
884+ @inbounds A. data[i, j] = c \ B. data[i, j]
836885 end
837886 end
838887 return A
0 commit comments