From c96c4948111a54f0af799b4e485da595024add56 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 27 Apr 2025 23:07:39 +0530 Subject: [PATCH 1/6] Out-of-place `triu`/`tril` for `Symmetric` in each branch --- src/symmetric.jl | 24 ++++++++++++------------ src/triangular.jl | 5 +++++ test/symmetric.jl | 23 +++++++++++++++++++++++ 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/src/symmetric.jl b/src/symmetric.jl index 089eefa2..96e9c717 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -520,25 +520,25 @@ Base.conj!(A::HermOrSym) = typeof(A)(parentof_applytri(conj!, A), A.uplo) # tril/triu function tril(A::Hermitian, k::Integer=0) if A.uplo == 'U' && k <= 0 - return tril!(copy(A.data'),k) + return tril_maybe_inplace(copy(A.data'),k) elseif A.uplo == 'U' && k > 0 - return tril!(copy(A.data'),-1) + tril!(triu(A.data),k) + return tril_maybe_inplace(copy(A.data'),-1) + tril_maybe_inplace(triu(A.data),k) elseif A.uplo == 'L' && k <= 0 return tril(A.data,k) else - return tril(A.data,-1) + tril!(triu!(copy(A.data')),k) + return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(A.data')),k) end end function tril(A::Symmetric, k::Integer=0) if A.uplo == 'U' && k <= 0 - return tril!(copy(transpose(A.data)),k) + return tril_maybe_inplace(copy(transpose(A.data)),k) elseif A.uplo == 'U' && k > 0 - return tril!(copy(transpose(A.data)),-1) + tril!(triu(A.data),k) + return tril_maybe_inplace(copy(transpose(A.data)),-1) + tril_maybe_inplace(triu(A.data),k) elseif A.uplo == 'L' && k <= 0 return tril(A.data,k) else - return tril(A.data,-1) + tril!(triu!(copy(transpose(A.data))),k) + return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(transpose(A.data))),k) end end @@ -546,11 +546,11 @@ function triu(A::Hermitian, k::Integer=0) if A.uplo == 'U' && k >= 0 return triu(A.data,k) elseif A.uplo == 'U' && k < 0 - return triu(A.data,1) + triu!(tril!(copy(A.data')),k) + return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(A.data')),k) elseif A.uplo == 'L' && k >= 0 - return triu!(copy(A.data'),k) + return triu_maybe_inplace(copy(A.data'),k) else - return triu!(copy(A.data'),1) + triu!(tril(A.data),k) + return triu_maybe_inplace(copy(A.data'),1) + triu_maybe_inplace(tril(A.data),k) end end @@ -558,11 +558,11 @@ function triu(A::Symmetric, k::Integer=0) if A.uplo == 'U' && k >= 0 return triu(A.data,k) elseif A.uplo == 'U' && k < 0 - return triu(A.data,1) + triu!(tril!(copy(transpose(A.data))),k) + return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(transpose(A.data))),k) elseif A.uplo == 'L' && k >= 0 - return triu!(copy(transpose(A.data)),k) + return triu_maybe_inplace(copy(transpose(A.data)),k) else - return triu!(copy(transpose(A.data)),1) + triu!(tril(A.data),k) + return triu_maybe_inplace(copy(transpose(A.data)),1) + triu_maybe_inplace(tril(A.data),k) end end diff --git a/src/triangular.jl b/src/triangular.jl index d82ddd87..025abc2e 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -530,6 +530,11 @@ function tril!(A::UnitLowerTriangular, k::Integer=0) return tril!(LowerTriangular(A.data), k) end +tril_maybe_inplace(A, k::Integer=0) = tril(A, k) +triu_maybe_inplace(A, k::Integer=0) = triu(A, k) +tril_maybe_inplace(A::StridedMatrix, k::Integer=0) = tril!(A, k) +triu_maybe_inplace(A::StridedMatrix, k::Integer=0) = triu!(A, k) + adjoint(A::LowerTriangular) = UpperTriangular(adjoint(A.data)) adjoint(A::UpperTriangular) = LowerTriangular(adjoint(A.data)) adjoint(A::UnitLowerTriangular) = UnitUpperTriangular(adjoint(A.data)) diff --git a/test/symmetric.jl b/test/symmetric.jl index 45125591..cd97af56 100644 --- a/test/symmetric.jl +++ b/test/symmetric.jl @@ -1350,4 +1350,27 @@ end @test LinearAlgebra.uplo(H) == :L end +@testset "triu/tril with immutable arrays" begin + struct ImmutableMatrix{T,A<:AbstractMatrix{T}} <: AbstractMatrix{T} + a :: A + end + Base.size(A::ImmutableMatrix) = size(A.a) + Base.getindex(A::ImmutableMatrix, i::Int, j::Int) = getindex(A.a, i, j) + Base.copy(A::ImmutableMatrix) = A + LinearAlgebra.adjoint(A::ImmutableMatrix) = ImmutableMatrix(adjoint(A.a)) + LinearAlgebra.transpose(A::ImmutableMatrix) = ImmutableMatrix(transpose(A.a)) + + A = ImmutableMatrix([1 2; 3 4]) + for T in (Symmetric, Hermitian), uplo in (:U, :L) + H = T(A, uplo) + MH = Matrix(H) + @test triu(H,-1) == triu(MH,-1) + @test triu(H) == triu(MH) + @test triu(H,1) == triu(MH,1) + @test tril(H,1) == tril(MH,1) + @test tril(H) == tril(MH) + @test tril(H,-1) == tril(MH,-1) + end +end + end # module TestSymmetric From cd2aa88aea3ba587bb45b3d9cb8ab29bcbabeed0 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 27 Apr 2025 23:09:56 +0530 Subject: [PATCH 2/6] Undo `_conjugation` change --- src/symmetric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/symmetric.jl b/src/symmetric.jl index 96e9c717..af3f6d2f 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -332,7 +332,7 @@ Base.dataids(A::HermOrSym) = Base.dataids(parent(A)) Base.unaliascopy(A::Hermitian) = Hermitian(Base.unaliascopy(parent(A)), sym_uplo(A.uplo)) Base.unaliascopy(A::Symmetric) = Symmetric(Base.unaliascopy(parent(A)), sym_uplo(A.uplo)) -_conjugation(::Union{Symmetric, Hermitian{<:Real}}) = transpose +_conjugation(::Symmetric) = transpose _conjugation(::Hermitian) = adjoint diag(A::Symmetric) = symmetric.(diag(parent(A)), sym_uplo(A.uplo)) From fd17a170c99c2bd536dc398aad5ad7b48d4b9c56 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 16 Jun 2025 14:10:01 +0530 Subject: [PATCH 3/6] Use `ImmutableArrays` test helper --- test/symmetric.jl | 11 +---------- test/testhelpers/ImmutableArrays.jl | 3 +++ 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/test/symmetric.jl b/test/symmetric.jl index cd97af56..7177e257 100644 --- a/test/symmetric.jl +++ b/test/symmetric.jl @@ -1351,16 +1351,7 @@ end end @testset "triu/tril with immutable arrays" begin - struct ImmutableMatrix{T,A<:AbstractMatrix{T}} <: AbstractMatrix{T} - a :: A - end - Base.size(A::ImmutableMatrix) = size(A.a) - Base.getindex(A::ImmutableMatrix, i::Int, j::Int) = getindex(A.a, i, j) - Base.copy(A::ImmutableMatrix) = A - LinearAlgebra.adjoint(A::ImmutableMatrix) = ImmutableMatrix(adjoint(A.a)) - LinearAlgebra.transpose(A::ImmutableMatrix) = ImmutableMatrix(transpose(A.a)) - - A = ImmutableMatrix([1 2; 3 4]) + A = ImmutableArray([1 2; 3 4]) for T in (Symmetric, Hermitian), uplo in (:U, :L) H = T(A, uplo) MH = Matrix(H) diff --git a/test/testhelpers/ImmutableArrays.jl b/test/testhelpers/ImmutableArrays.jl index 8f2d23be..014e8110 100644 --- a/test/testhelpers/ImmutableArrays.jl +++ b/test/testhelpers/ImmutableArrays.jl @@ -28,4 +28,7 @@ AbstractArray{T,N}(A::ImmutableArray{S,N}) where {S,T,N} = ImmutableArray(Abstra Base.copy(A::ImmutableArray) = ImmutableArray(copy(A.data)) Base.zero(A::ImmutableArray) = ImmutableArray(zero(A.data)) +Base.adjoint(A::ImmutableArray) = ImmutableArray(adjoint(A.data)) +Base.transpose(A::ImmutableArray) = ImmutableArray(transpose(A.data)) + end From c16fa00e83df7a1b6bbef399a3b33a6d3d114868 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 16 Jun 2025 14:16:37 +0530 Subject: [PATCH 4/6] Revert `_conjugation` change --- src/symmetric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/symmetric.jl b/src/symmetric.jl index af3f6d2f..96e9c717 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -332,7 +332,7 @@ Base.dataids(A::HermOrSym) = Base.dataids(parent(A)) Base.unaliascopy(A::Hermitian) = Hermitian(Base.unaliascopy(parent(A)), sym_uplo(A.uplo)) Base.unaliascopy(A::Symmetric) = Symmetric(Base.unaliascopy(parent(A)), sym_uplo(A.uplo)) -_conjugation(::Symmetric) = transpose +_conjugation(::Union{Symmetric, Hermitian{<:Real}}) = transpose _conjugation(::Hermitian) = adjoint diag(A::Symmetric) = symmetric.(diag(parent(A)), sym_uplo(A.uplo)) From 7d34ec6c2c344984d867df797330a612c5751cbd Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 20 Jun 2025 16:12:41 +0530 Subject: [PATCH 5/6] Preserve wrapper in AbstractMatrix constructor for Adjoint/Transpose --- src/adjtrans.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/adjtrans.jl b/src/adjtrans.jl index 72be05ab..6b346734 100644 --- a/src/adjtrans.jl +++ b/src/adjtrans.jl @@ -332,6 +332,10 @@ wrapperop(::Transpose) = transpose _wrapperop(x) = wrapperop(x) _wrapperop(::Adjoint{<:Real}) = transpose +# equivalent to wrapperop, but returns the type of the wrapper +wrappertype(::Adjoint) = Adjoint +wrappertype(::Transpose) = Transpose + # the following fallbacks can be removed if Adjoint/Transpose are restricted to AbstractVecOrMat size(A::AdjOrTrans) = reverse(size(A.parent)) axes(A::AdjOrTrans) = reverse(axes(A.parent)) @@ -391,7 +395,7 @@ similar(A::AdjOrTrans, ::Type{T}) where {T} = similar(A.parent, T, axes(A)) similar(A::AdjOrTrans, ::Type{T}, dims::Dims{N}) where {T,N} = similar(A.parent, T, dims) # AbstractMatrix{T} constructor for adjtrans vector: preserve wrapped type -AbstractMatrix{T}(A::AdjOrTransAbsVec) where {T} = wrapperop(A)(AbstractVector{T}(A.parent)) +AbstractMatrix{T}(A::AdjOrTransAbsVec) where {T} = wrappertype(A)(AbstractVector{T}(A.parent)) # sundry basic definitions parent(A::AdjOrTrans) = A.parent From 88457508d2a44ebbe010fb2169a424a01ea9cd4a Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 22 Sep 2025 13:37:04 +0530 Subject: [PATCH 6/6] Remove `wrappertype` --- src/adjtrans.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/adjtrans.jl b/src/adjtrans.jl index 6b346734..72be05ab 100644 --- a/src/adjtrans.jl +++ b/src/adjtrans.jl @@ -332,10 +332,6 @@ wrapperop(::Transpose) = transpose _wrapperop(x) = wrapperop(x) _wrapperop(::Adjoint{<:Real}) = transpose -# equivalent to wrapperop, but returns the type of the wrapper -wrappertype(::Adjoint) = Adjoint -wrappertype(::Transpose) = Transpose - # the following fallbacks can be removed if Adjoint/Transpose are restricted to AbstractVecOrMat size(A::AdjOrTrans) = reverse(size(A.parent)) axes(A::AdjOrTrans) = reverse(axes(A.parent)) @@ -395,7 +391,7 @@ similar(A::AdjOrTrans, ::Type{T}) where {T} = similar(A.parent, T, axes(A)) similar(A::AdjOrTrans, ::Type{T}, dims::Dims{N}) where {T,N} = similar(A.parent, T, dims) # AbstractMatrix{T} constructor for adjtrans vector: preserve wrapped type -AbstractMatrix{T}(A::AdjOrTransAbsVec) where {T} = wrappertype(A)(AbstractVector{T}(A.parent)) +AbstractMatrix{T}(A::AdjOrTransAbsVec) where {T} = wrapperop(A)(AbstractVector{T}(A.parent)) # sundry basic definitions parent(A::AdjOrTrans) = A.parent