@@ -238,10 +238,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) =
238238Base. isstored (A:: UpperOrLowerTriangular , i:: Int , j:: Int ) =
239239 _shouldforwardindex (A, i, j) ? Base. isstored (A. data, i, j) : false
240240
241- @propagate_inbounds getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , i:: Int , j:: Int ) where {T} =
242- _shouldforwardindex (A, i, j) ? A. data[i,j] : ifelse (i == j, oneunit (T), zero (T))
243- @propagate_inbounds getindex (A:: Union{LowerTriangular, UpperTriangular} , i:: Int , j:: Int ) =
244- _shouldforwardindex (A, i, j) ? A. data[i,j] : diagzero (A,i,j)
241+ @propagate_inbounds function getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , i:: Int , j:: Int ) where {T}
242+ if _shouldforwardindex (A, i, j)
243+ A. data[i,j]
244+ else
245+ @boundscheck checkbounds (A, i, j)
246+ ifelse (i == j, oneunit (T), zero (T))
247+ end
248+ end
249+ @propagate_inbounds function getindex (A:: Union{LowerTriangular, UpperTriangular} , i:: Int , j:: Int )
250+ if _shouldforwardindex (A, i, j)
251+ A. data[i,j]
252+ else
253+ @boundscheck checkbounds (A, i, j)
254+ @inbounds diagzero (A,i,j)
255+ end
256+ end
245257
246258_shouldforwardindex (U:: UpperTriangular , b:: BandIndex ) = b. band >= 0
247259_shouldforwardindex (U:: LowerTriangular , b:: BandIndex ) = b. band <= 0
@@ -250,10 +262,20 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0
250262
251263# these specialized getindex methods enable constant-propagation of the band
252264Base. @constprop :aggressive @propagate_inbounds function getindex (A:: Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}} , b:: BandIndex ) where {T}
253- _shouldforwardindex (A, b) ? A. data[b] : ifelse (b. band == 0 , oneunit (T), zero (T))
265+ if _shouldforwardindex (A, b)
266+ A. data[b]
267+ else
268+ @boundscheck checkbounds (A, b)
269+ ifelse (b. band == 0 , oneunit (T), zero (T))
270+ end
254271end
255272Base. @constprop :aggressive @propagate_inbounds function getindex (A:: Union{LowerTriangular, UpperTriangular} , b:: BandIndex )
256- _shouldforwardindex (A, b) ? A. data[b] : diagzero (A. data, b)
273+ if _shouldforwardindex (A, b)
274+ A. data[b]
275+ else
276+ @boundscheck checkbounds (A, b)
277+ @inbounds diagzero (A, b)
278+ end
257279end
258280
259281_zero_triangular_half_str (:: Type{<:UpperOrUnitUpperTriangular} ) = " lower"
@@ -265,14 +287,20 @@ _zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"
265287 throw (ArgumentError (
266288 lazy " cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)" ))
267289end
268- @noinline function throw_nononeerror (T, @nospecialize (x), i, j)
290+ @noinline function throw_nonuniterror (T, @nospecialize (x), i, j)
291+ check_compatible_type (T, x)
269292 Tn = nameof (T)
270293 throw (ArgumentError (
271294 lazy " cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)" ))
272295end
296+ function check_compatible_type (T, @nospecialize (x))
297+ ET = eltype (T)
298+ convert (ET, x) # check that the types are compatible with setindex!
299+ end
273300
274301@propagate_inbounds function setindex! (A:: UpperTriangular , x, i:: Integer , j:: Integer )
275302 if i > j
303+ @boundscheck checkbounds (A, i, j)
276304 iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
277305 else
278306 A. data[i,j] = x
282310
283311@propagate_inbounds function setindex! (A:: UnitUpperTriangular , x, i:: Integer , j:: Integer )
284312 if i > j
313+ @boundscheck checkbounds (A, i, j)
285314 iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
286315 elseif i == j
287- x == oneunit (x) || throw_nononeerror (typeof (A), x, i, j)
316+ @boundscheck checkbounds (A, i, j)
317+ x == oneunit (eltype (A)) || throw_nonuniterror (typeof (A), x, i, j)
288318 else
289319 A. data[i,j] = x
290320 end
293323
294324@propagate_inbounds function setindex! (A:: LowerTriangular , x, i:: Integer , j:: Integer )
295325 if i < j
326+ @boundscheck checkbounds (A, i, j)
296327 iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
297328 else
298329 A. data[i,j] = x
302333
303334@propagate_inbounds function setindex! (A:: UnitLowerTriangular , x, i:: Integer , j:: Integer )
304335 if i < j
336+ @boundscheck checkbounds (A, i, j)
305337 iszero (x) || throw_nonzeroerror (typeof (A), x, i, j)
306338 elseif i == j
307- x == oneunit (x) || throw_nononeerror (typeof (A), x, i, j)
339+ @boundscheck checkbounds (A, i, j)
340+ x == oneunit (eltype (A)) || throw_nonuniterror (typeof (A), x, i, j)
308341 else
309342 A. data[i,j] = x
310343 end
@@ -560,7 +593,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un
560593 @eval @inline function _copy! (A:: $UT , B:: $T )
561594 for dind in diagind (A, IndexStyle (A))
562595 if A[dind] != B[dind]
563- throw_nononeerror (typeof (A), B[dind], Tuple (dind)... )
596+ throw_nonuniterror (typeof (A), B[dind], Tuple (dind)... )
564597 end
565598 end
566599 _copy! ($ T (parent (A)), B)
@@ -742,7 +775,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
742775 checksize1 (A, B)
743776 _iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
744777 for j in axes (B. data,2 )
745- @inbounds _modify! (_add, c, A, (j,j))
778+ @inbounds _modify! (_add, B[ BandIndex ( 0 ,j)] * c, A, (j,j))
746779 for i in firstindex (B. data,1 ): (j - 1 )
747780 @inbounds _modify! (_add, B. data[i,j] * c, A. data, (i,j))
748781 end
@@ -753,7 +786,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
753786 checksize1 (A, B)
754787 _iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
755788 for j in axes (B. data,2 )
756- @inbounds _modify! (_add, c, A, (j,j))
789+ @inbounds _modify! (_add, c * B[ BandIndex ( 0 ,j)] , A, (j,j))
757790 for i in firstindex (B. data,1 ): (j - 1 )
758791 @inbounds _modify! (_add, c * B. data[i,j], A. data, (i,j))
759792 end
@@ -784,7 +817,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
784817 checksize1 (A, B)
785818 _iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
786819 for j in axes (B. data,2 )
787- @inbounds _modify! (_add, c, A, (j,j))
820+ @inbounds _modify! (_add, B[ BandIndex ( 0 ,j)] * c, A, (j,j))
788821 for i in (j + 1 ): lastindex (B. data,1 )
789822 @inbounds _modify! (_add, B. data[i,j] * c, A. data, (i,j))
790823 end
@@ -795,7 +828,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
795828 checksize1 (A, B)
796829 _iszero_alpha (_add) && return _rmul_or_fill! (A, _add. beta)
797830 for j in axes (B. data,2 )
798- @inbounds _modify! (_add, c, A, (j,j))
831+ @inbounds _modify! (_add, c * B[ BandIndex ( 0 ,j)] , A, (j,j))
799832 for i in (j + 1 ): lastindex (B. data,1 )
800833 @inbounds _modify! (_add, c * B. data[i,j], A. data, (i,j))
801834 end
@@ -805,36 +838,52 @@ end
805838
806839function _trirdiv! (A:: UpperTriangular , B:: UpperOrUnitUpperTriangular , c:: Number )
807840 checksize1 (A, B)
841+ isunit = B isa UnitUpperTriangular
808842 for j in axes (B,2 )
809- for i in firstindex (B,1 ): j
810- @inbounds A[i, j] = B[i, j] / c
843+ for i in firstindex (B,1 ): j- isunit
844+ @inbounds A. data[i, j] = B. data[i, j] / c
845+ end
846+ if isunit
847+ @inbounds A. data[j, j] = B[BandIndex (0 ,j)] / c
811848 end
812849 end
813850 return A
814851end
815852function _trirdiv! (A:: LowerTriangular , B:: LowerOrUnitLowerTriangular , c:: Number )
816853 checksize1 (A, B)
854+ isunit = B isa UnitLowerTriangular
817855 for j in axes (B,2 )
818- for i in j: lastindex (B,1 )
819- @inbounds A[i, j] = B[i, j] / c
856+ if isunit
857+ @inbounds A. data[j, j] = B[BandIndex (0 ,j)] / c
858+ end
859+ for i in j+ isunit: lastindex (B,1 )
860+ @inbounds A. data[i, j] = B. data[i, j] / c
820861 end
821862 end
822863 return A
823864end
824865function _trildiv! (A:: UpperTriangular , c:: Number , B:: UpperOrUnitUpperTriangular )
825866 checksize1 (A, B)
867+ isunit = B isa UnitUpperTriangular
826868 for j in axes (B,2 )
827- for i in firstindex (B,1 ): j
828- @inbounds A[i, j] = c \ B[i, j]
869+ for i in firstindex (B,1 ): j- isunit
870+ @inbounds A. data[i, j] = c \ B. data[i, j]
871+ end
872+ if isunit
873+ @inbounds A. data[j, j] = c \ B[BandIndex (0 ,j)]
829874 end
830875 end
831876 return A
832877end
833878function _trildiv! (A:: LowerTriangular , c:: Number , B:: LowerOrUnitLowerTriangular )
834879 checksize1 (A, B)
880+ isunit = B isa UnitLowerTriangular
835881 for j in axes (B,2 )
836- for i in j: lastindex (B,1 )
837- @inbounds A[i, j] = c \ B[i, j]
882+ if isunit
883+ @inbounds A. data[j, j] = c \ B[BandIndex (0 ,j)]
884+ end
885+ for i in j+ isunit: lastindex (B,1 )
886+ @inbounds A. data[i, j] = c \ B. data[i, j]
838887 end
839888 end
840889 return A
0 commit comments