Skip to content

Commit 5036b77

Browse files
jishnubKristofferC
authored andcommitted
Fix tr for Symmetric/Hermitian block matrices (#55522)
Since `Symmetric` and `Hermitian` symmetrize the diagonal elements of the parent, we can't forward `tr` to the parent unless it is already symmetric. This limits the existing `tr` methods to matrices of `Number`s, which is the common use-case. `tr` for `Symmetric` block matrices would now use the fallback implementation that explicitly computes the `diag`. This resolves the following discrepancy: ```julia julia> S = Symmetric(fill([1 2; 3 4], 3, 3)) 3×3 Symmetric{AbstractMatrix, Matrix{Matrix{Int64}}}: [1 2; 2 4] [1 2; 3 4] [1 2; 3 4] [1 3; 2 4] [1 2; 2 4] [1 2; 3 4] [1 3; 2 4] [1 3; 2 4] [1 2; 2 4] julia> tr(S) 2×2 Matrix{Int64}: 3 6 9 12 julia> sum(diag(S)) 2×2 Symmetric{Int64, Matrix{Int64}}: 3 6 6 12 ``` (cherry picked from commit 9738bc7)
1 parent fa29d0a commit 5036b77

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,8 +397,8 @@ Base.copy(A::Adjoint{<:Any,<:Symmetric}) =
397397
Base.copy(A::Transpose{<:Any,<:Hermitian}) =
398398
Hermitian(copy(transpose(A.parent.data)), ifelse(A.parent.uplo == 'U', :L, :U))
399399

400-
tr(A::Symmetric) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations)
401-
tr(A::Hermitian) = real(tr(A.data))
400+
tr(A::Symmetric{<:Number}) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations)
401+
tr(A::Hermitian{<:Number}) = real(tr(A.data))
402402

403403
Base.conj(A::HermOrSym) = typeof(A)(conj(A.data), A.uplo)
404404
Base.conj!(A::HermOrSym) = typeof(A)(conj!(A.data), A.uplo)

stdlib/LinearAlgebra/test/symmetric.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,4 +911,15 @@ end
911911
@test LinearAlgebra.hermitian(A, :L) === Hermitian(A, :L)
912912
end
913913

914+
@testset "tr for block matrices" begin
915+
m = [1 2; 3 4]
916+
for b in (m, m * (1 + im))
917+
M = fill(b, 3, 3)
918+
for ST in (Symmetric, Hermitian)
919+
S = ST(M)
920+
@test tr(S) == sum(diag(S))
921+
end
922+
end
923+
end
924+
914925
end # module TestSymmetric

0 commit comments

Comments
 (0)