Skip to content

Commit 41b1778

Browse files
authored
Combine diag methods for SymTridiagonal (#56014)
Currently, there are two branches, one for an `eltype` that is a `Number`, and the other that deals with generic `eltype`s. They do similar things, so we may combine these, and use branches wherever necessary to retain the performance. We also may replace explicit materialized arrays by generators in `copyto!`. Overall, this improves performance in `diag` for matrices of matrices, whereas the performance in the common case of matrices of numbers remains unchanged. ```julia julia> using StaticArrays, LinearAlgebra julia> s = SMatrix{2,2}(1:4); julia> S = SymTridiagonal(fill(s,100), fill(s,99)); julia> @Btime diag($S); 1.292 μs (5 allocations: 7.16 KiB) # nightly, v"1.12.0-DEV.1317" 685.012 ns (3 allocations: 3.19 KiB) # This PR ``` This PR also allows computing the `diag` for more values of the band index `n`: ```julia julia> diag(S,99) 1-element Vector{SMatrix{2, 2, Int64, 4}}: [0 0; 0 0] ``` This would work as long as `getindex` works for the `SymTridiagonal` for that band, and the zero element may be converted to the `eltype`.
1 parent 055e37e commit 41b1778

File tree

2 files changed

+33
-37
lines changed

2 files changed

+33
-37
lines changed

stdlib/LinearAlgebra/src/tridiag.jl

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -183,44 +183,27 @@ issymmetric(S::SymTridiagonal) = true
183183

184184
tr(S::SymTridiagonal) = sum(symmetric, S.dv)
185185

186-
@noinline function throw_diag_outofboundserror(n, sz)
187-
sz1, sz2 = sz
188-
throw(ArgumentError(LazyString(lazy"requested diagonal, $n, must be at least $(-sz1) ",
189-
lazy"and at most $sz2 for an $(sz1)-by-$(sz2) matrix")))
190-
end
186+
_diagiter(M::SymTridiagonal{<:Number}) = M.dv
187+
_diagiter(M::SymTridiagonal) = (symmetric(x, :U) for x in M.dv)
188+
_eviter_transposed(M::SymTridiagonal{<:Number}) = _evview(M)
189+
_eviter_transposed(M::SymTridiagonal) = (transpose(x) for x in _evview(M))
191190

192-
function diag(M::SymTridiagonal{T}, n::Integer=0) where T<:Number
193-
# every branch call similar(..., ::Int) to make sure the
194-
# same vector type is returned independent of n
195-
absn = abs(n)
196-
if absn == 0
197-
return copyto!(similar(M.dv, length(M.dv)), M.dv)
198-
elseif absn == 1
199-
return copyto!(similar(M.ev, length(M.dv)-1), _evview(M))
200-
elseif absn <= size(M,1)
201-
v = similar(M.dv, size(M,1)-absn)
202-
for i in eachindex(v)
203-
v[i] = M[BandIndex(n,i)]
204-
end
205-
return v
206-
else
207-
throw_diag_outofboundserror(n, size(M))
208-
end
209-
end
210191
function diag(M::SymTridiagonal, n::Integer=0)
211192
# every branch call similar(..., ::Int) to make sure the
212193
# same vector type is returned independent of n
194+
v = similar(M.dv, max(0, length(M.dv)-abs(n)))
213195
if n == 0
214-
return copyto!(similar(M.dv, length(M.dv)), symmetric.(M.dv, :U))
196+
return copyto!(v, _diagiter(M))
215197
elseif n == 1
216-
return copyto!(similar(M.ev, length(M.dv)-1), _evview(M))
198+
return copyto!(v, _evview(M))
217199
elseif n == -1
218-
return copyto!(similar(M.ev, length(M.dv)-1), transpose.(_evview(M)))
219-
elseif n <= size(M,1)
220-
throw(ArgumentError("requested diagonal contains undefined zeros of an array type"))
200+
return copyto!(v, _eviter_transposed(M))
221201
else
222-
throw_diag_outofboundserror(n, size(M))
202+
for i in eachindex(v)
203+
v[i] = M[BandIndex(n,i)]
204+
end
223205
end
206+
return v
224207
end
225208

226209
+(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv+B.dv, _evview(A)+_evview(B))

stdlib/LinearAlgebra/test/tridiag.jl

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,13 @@ end
287287
@test (@inferred diag(A, 1))::typeof(d) == (mat_type == Tridiagonal ? du : dl)
288288
@test (@inferred diag(A, -1))::typeof(d) == dl
289289
@test (@inferred diag(A, n-1))::typeof(d) == zeros(elty, 1)
290-
@test_throws ArgumentError diag(A, -n - 1)
291-
@test_throws ArgumentError diag(A, n + 1)
290+
if A isa SymTridiagonal
291+
@test isempty(@inferred diag(A, -n - 1))
292+
@test isempty(@inferred diag(A, n + 1))
293+
else
294+
@test_throws ArgumentError diag(A, -n - 1)
295+
@test_throws ArgumentError diag(A, n + 1)
296+
end
292297
GA = mat_type == Tridiagonal ? mat_type(GenericArray.((dl, d, du))...) : mat_type(GenericArray.((d, dl))...)
293298
@test (@inferred diag(GA))::typeof(GenericArray(d)) == GenericArray(d)
294299
@test (@inferred diag(GA, -1))::typeof(GenericArray(d)) == GenericArray(dl)
@@ -501,10 +506,11 @@ end
501506
@test @inferred diag(A, 1) == fill(M, n-1)
502507
@test @inferred diag(A, 0) == fill(Symmetric(M), n)
503508
@test @inferred diag(A, -1) == fill(transpose(M), n-1)
504-
@test_throws ArgumentError diag(A, -2)
505-
@test_throws ArgumentError diag(A, 2)
506-
@test_throws ArgumentError diag(A, n+1)
507-
@test_throws ArgumentError diag(A, -n-1)
509+
@test_broken diag(A, -2) == fill(M, n-2)
510+
@test_broken diag(A, 2) == fill(M, n-2)
511+
@test isempty(@inferred diag(A, n+1))
512+
@test isempty(@inferred diag(A, -n-1))
513+
508514
A[1,1] = Symmetric(2M)
509515
@test A[1,1] == Symmetric(2M)
510516
@test_throws ArgumentError A[1,1] = M
@@ -519,8 +525,8 @@ end
519525
@test @inferred diag(A, 1) == fill(M, n-1)
520526
@test @inferred diag(A, 0) == fill(M, n)
521527
@test @inferred diag(A, -1) == fill(M, n-1)
522-
@test_throws MethodError diag(A, -2)
523-
@test_throws MethodError diag(A, 2)
528+
@test_broken diag(A, -2) == fill(M, n-2)
529+
@test_broken diag(A, 2) == fill(M, n-2)
524530
@test_throws ArgumentError diag(A, n+1)
525531
@test_throws ArgumentError diag(A, -n-1)
526532

@@ -532,6 +538,13 @@ end
532538
A = Tridiagonal(ev, dv, ev)
533539
@test A == Matrix{eltype(A)}(A)
534540
end
541+
542+
M = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
543+
S = SymTridiagonal(fill(M,4), fill(M,3))
544+
@test diag(S,2) == fill(zero(M), 2)
545+
@test diag(S,-2) == fill(zero(M), 2)
546+
@test isempty(diag(S,4))
547+
@test isempty(diag(S,-4))
535548
end
536549

537550
@testset "Issue 12068" begin

0 commit comments

Comments
 (0)