Skip to content

Commit 67c93b9

Browse files
authored
diag for BandedMatrixes for off-limit bands (#56065)
Currently, one can only obtain the `diag` for a `BandedMatrix` (such as a `Diagonal`) when the band index is bounded by the size of the matrix. This PR relaxes this requirement to match the behavior for arrays, where `diag` returns an empty vector for a large band index instead of throwing an error. ```julia julia> D = Diagonal(ones(4)) 4×4 Diagonal{Float64, Vector{Float64}}: 1.0 ⋅ ⋅ ⋅ ⋅ 1.0 ⋅ ⋅ ⋅ ⋅ 1.0 ⋅ ⋅ ⋅ ⋅ 1.0 julia> diag(D, 10) Float64[] julia> diag(Array(D), 10) Float64[] ``` Something similar for `SymTridiagonal` is being done in #56014
1 parent 6029173 commit 67c93b9

File tree

6 files changed

+21
-35
lines changed

6 files changed

+21
-35
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -404,20 +404,17 @@ end
404404
function diag(M::Bidiagonal, n::Integer=0)
405405
# every branch call similar(..., ::Int) to make sure the
406406
# same vector type is returned independent of n
407+
v = similar(M.dv, max(0, length(M.dv)-abs(n)))
407408
if n == 0
408-
return copyto!(similar(M.dv, length(M.dv)), M.dv)
409+
copyto!(v, M.dv)
409410
elseif (n == 1 && M.uplo == 'U') || (n == -1 && M.uplo == 'L')
410-
return copyto!(similar(M.ev, length(M.ev)), M.ev)
411+
copyto!(v, M.ev)
411412
elseif -size(M,1) <= n <= size(M,1)
412-
v = similar(M.dv, size(M,1)-abs(n))
413413
for i in eachindex(v)
414414
v[i] = M[BandIndex(n,i)]
415415
end
416-
return v
417-
else
418-
throw(ArgumentError(LazyString(lazy"requested diagonal, $n, must be at least $(-size(M, 1)) ",
419-
lazy"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
420416
end
417+
return v
421418
end
422419

423420
function +(A::Bidiagonal, B::Bidiagonal)

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -773,18 +773,15 @@ permutedims(D::Diagonal, perm) = (Base.checkdims_perm(axes(D), axes(D), perm); D
773773
function diag(D::Diagonal, k::Integer=0)
774774
# every branch call similar(..., ::Int) to make sure the
775775
# same vector type is returned independent of k
776+
v = similar(D.diag, max(0, length(D.diag)-abs(k)))
776777
if k == 0
777-
return copyto!(similar(D.diag, length(D.diag)), D.diag)
778-
elseif -size(D,1) <= k <= size(D,1)
779-
v = similar(D.diag, size(D,1)-abs(k))
778+
copyto!(v, D.diag)
779+
else
780780
for i in eachindex(v)
781781
v[i] = D[BandIndex(k, i)]
782782
end
783-
return v
784-
else
785-
throw(ArgumentError(LazyString(lazy"requested diagonal, $k, must be at least $(-size(D, 1)) ",
786-
lazy"and at most $(size(D, 2)) for an $(size(D, 1))-by-$(size(D, 2)) matrix")))
787783
end
784+
return v
788785
end
789786
tr(D::Diagonal) = sum(tr, D.diag)
790787
det(D::Diagonal) = prod(det, D.diag)

stdlib/LinearAlgebra/src/tridiag.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -662,22 +662,19 @@ issymmetric(S::Tridiagonal) = all(issymmetric, S.d) && all(Iterators.map((x, y)
662662
function diag(M::Tridiagonal, n::Integer=0)
663663
# every branch call similar(..., ::Int) to make sure the
664664
# same vector type is returned independent of n
665+
v = similar(M.d, max(0, length(M.d)-abs(n)))
665666
if n == 0
666-
return copyto!(similar(M.d, length(M.d)), M.d)
667+
copyto!(v, M.d)
667668
elseif n == -1
668-
return copyto!(similar(M.dl, length(M.dl)), M.dl)
669+
copyto!(v, M.dl)
669670
elseif n == 1
670-
return copyto!(similar(M.du, length(M.du)), M.du)
671+
copyto!(v, M.du)
671672
elseif abs(n) <= size(M,1)
672-
v = similar(M.d, size(M,1)-abs(n))
673673
for i in eachindex(v)
674674
v[i] = M[BandIndex(n,i)]
675675
end
676-
return v
677-
else
678-
throw(ArgumentError(LazyString(lazy"requested diagonal, $n, must be at least $(-size(M, 1)) ",
679-
lazy"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
680676
end
677+
return v
681678
end
682679

683680
@inline function Base.isassigned(A::Tridiagonal, i::Int, j::Int)

stdlib/LinearAlgebra/test/bidiag.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,8 +398,8 @@ Random.seed!(1)
398398
@test (@inferred diag(T))::typeof(dv) == dv
399399
@test (@inferred diag(T, uplo === :U ? 1 : -1))::typeof(dv) == ev
400400
@test (@inferred diag(T,2))::typeof(dv) == zeros(elty, n-2)
401-
@test_throws ArgumentError diag(T, -n - 1)
402-
@test_throws ArgumentError diag(T, n + 1)
401+
@test isempty(@inferred diag(T, -n - 1))
402+
@test isempty(@inferred diag(T, n + 1))
403403
# test diag with another wrapped vector type
404404
gdv, gev = GenericArray(dv), GenericArray(ev)
405405
G = Bidiagonal(gdv, gev, uplo)

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ Random.seed!(1)
109109
end
110110

111111
@testset "diag" begin
112-
@test_throws ArgumentError diag(D, n+1)
113-
@test_throws ArgumentError diag(D, -n-1)
112+
@test isempty(@inferred diag(D, n+1))
113+
@test isempty(@inferred diag(D, -n-1))
114114
@test (@inferred diag(D))::typeof(dd) == dd
115115
@test (@inferred diag(D, 0))::typeof(dd) == dd
116116
@test (@inferred diag(D, 1))::typeof(dd) == zeros(elty, n-1)

stdlib/LinearAlgebra/test/tridiag.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -287,13 +287,8 @@ end
287287
@test (@inferred diag(A, 1))::typeof(d) == (mat_type == Tridiagonal ? du : dl)
288288
@test (@inferred diag(A, -1))::typeof(d) == dl
289289
@test (@inferred diag(A, n-1))::typeof(d) == zeros(elty, 1)
290-
if A isa SymTridiagonal
291-
@test isempty(@inferred diag(A, -n - 1))
292-
@test isempty(@inferred diag(A, n + 1))
293-
else
294-
@test_throws ArgumentError diag(A, -n - 1)
295-
@test_throws ArgumentError diag(A, n + 1)
296-
end
290+
@test isempty(@inferred diag(A, -n - 1))
291+
@test isempty(@inferred diag(A, n + 1))
297292
GA = mat_type == Tridiagonal ? mat_type(GenericArray.((dl, d, du))...) : mat_type(GenericArray.((d, dl))...)
298293
@test (@inferred diag(GA))::typeof(GenericArray(d)) == GenericArray(d)
299294
@test (@inferred diag(GA, -1))::typeof(GenericArray(d)) == GenericArray(dl)
@@ -527,8 +522,8 @@ end
527522
@test @inferred diag(A, -1) == fill(M, n-1)
528523
@test_broken diag(A, -2) == fill(M, n-2)
529524
@test_broken diag(A, 2) == fill(M, n-2)
530-
@test_throws ArgumentError diag(A, n+1)
531-
@test_throws ArgumentError diag(A, -n-1)
525+
@test isempty(@inferred diag(A, n+1))
526+
@test isempty(@inferred diag(A, -n-1))
532527

533528
for n in 0:2
534529
dv, ev = fill(M, n), fill(M, max(n-1,0))

0 commit comments

Comments
 (0)