From e53f0c380d70bdc9bbe4663b36b2afa824799ed9 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 13 May 2025 14:50:05 -0400 Subject: [PATCH 1/6] Make orthnull more customizable --- src/implementations/orthnull.jl | 170 ++++++++++++++++++------------ src/implementations/truncation.jl | 5 +- 2 files changed, 103 insertions(+), 72 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 2680e7cc..68a32d9f 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -1,9 +1,9 @@ # Inputs # ------ -copy_input(::typeof(left_orth), A::AbstractMatrix) = copy_input(qr_compact, A) # do we ever need anything else -copy_input(::typeof(right_orth), A::AbstractMatrix) = copy_input(lq_compact, A) # do we ever need anything else -copy_input(::typeof(left_null), A::AbstractMatrix) = copy_input(qr_null, A) # do we ever need anything else -copy_input(::typeof(right_null), A::AbstractMatrix) = copy_input(lq_null, A) # do we ever need anything else +copy_input(::typeof(left_orth), A) = copy_input(qr_compact, A) # do we ever need anything else +copy_input(::typeof(right_orth), A) = copy_input(lq_compact, A) # do we ever need anything else +copy_input(::typeof(left_null), A) = copy_input(qr_null, A) # do we ever need anything else +copy_input(::typeof(right_null), A) = copy_input(lq_null, A) # do we ever need anything else function check_input(::typeof(left_orth!), A::AbstractMatrix, VC) m, n = size(A) @@ -81,7 +81,7 @@ end # Implementation of orth functions # -------------------------------- -function left_orth!(A::AbstractMatrix, VC; trunc=nothing, +function left_orth!(A, VC; trunc=nothing, kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true), alg_polar=(;), alg_svd=(;)) check_input(left_orth!, A, VC) @@ -89,32 +89,40 @@ function left_orth!(A::AbstractMatrix, VC; trunc=nothing, throw(ArgumentError("truncation not supported for left_orth with kind=$kind")) end if kind == :qr - alg_qr′ = _select_algorithm(qr_compact!, A, alg_qr) - return qr_compact!(A, VC, alg_qr′) + return left_orth_qr!(A, VC, alg_qr) elseif kind == :polar - size(A, 1) >= size(A, 2) || - throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`")) - alg_polar′ = _select_algorithm(left_polar!, A, alg_polar) - return left_polar!(A, VC, alg_polar′) - elseif kind == :svd && isnothing(trunc) - alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) - V, C = VC - S = Diagonal(initialize_output(svd_vals!, A, alg_svd′)) - U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd′) - return U, lmul!(S, Vᴴ) + return left_orth_polar!(A, VC, alg_polar) elseif kind == :svd - alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) - alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′) - V, C = VC - S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg)) - U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_svd_trunc) - return U, lmul!(S, Vᴴ) + return left_orth_svd!(A, VC, alg_svd, trunc) else throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`")) end end +function left_orth_qr!(A, VC, alg) + alg′ = select_algorithm(qr_compact!, A, alg) + return qr_compact!(A, VC, alg′) +end +function left_orth_polar!(A, VC, alg) + alg′ = select_algorithm(left_polar!, A, alg) + return left_polar!(A, VC, alg′) +end +function left_orth_svd!(A, VC, alg, trunc::Nothing=nothing) + alg′ = select_algorithm(svd_compact!, A, alg) + V, C = VC + S = Diagonal(initialize_output(svd_vals!, A, alg′)) + U, S, Vᴴ = svd_compact!(A, (V, S, C), alg′) + return U, lmul!(S, Vᴴ) +end +function left_orth_svd!(A, VC, alg, trunc) + alg′ = select_algorithm(svd_compact!, A, alg) + alg_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg′) + V, C = VC + S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg)) + U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_trunc) + return U, lmul!(S, Vᴴ) +end -function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing, +function right_orth!(A, CVᴴ; trunc=nothing, kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true), alg_polar=(;), alg_svd=(;)) check_input(right_orth!, A, CVᴴ) @@ -122,30 +130,38 @@ function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing, throw(ArgumentError("truncation not supported for right_orth with kind=$kind")) end if kind == :lq - alg_lq′ = _select_algorithm(lq_compact!, A, alg_lq) - return lq_compact!(A, CVᴴ, alg_lq′) + return right_orth_lq!(A, CVᴴ, alg) elseif kind == :polar - size(A, 2) >= size(A, 1) || - throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`")) - alg_polar′ = _select_algorithm(right_polar!, A, alg_polar) - return right_polar!(A, CVᴴ, alg_polar′) - elseif kind == :svd && isnothing(trunc) - alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) - C, Vᴴ = CVᴴ - S = Diagonal(initialize_output(svd_vals!, A, alg_svd′)) - U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′) - return rmul!(U, S), Vᴴ + return right_orth_polar!(A, CVᴴ, alg) elseif kind == :svd - alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) - alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′) - C, Vᴴ = CVᴴ - S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg)) - U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_svd_trunc) - return rmul!(U, S), Vᴴ + return right_orth_svd!(A, CVᴴ, alg, trunc) else throw(ArgumentError("`right_orth!` received unknown value `kind = $kind`")) end end +function right_orth_lq!(A, CVᴴ, alg) + alg′ = select_algorithm(lq_compact!, A, alg) + return lq_compact!(A, CVᴴ, alg′) +end +function right_orth_polar!(A, CVᴴ, alg) + alg′ = select_algorithm(right_polar!, A, alg) + return right_polar!(A, CVᴴ, alg′) +end +function right_orth_svd!(A, CVᴴ, alg, trunc::Nothing=nothing) + alg′ = select_algorithm(svd_compact!, A, alg) + C, Vᴴ = CVᴴ + S = Diagonal(initialize_output(svd_vals!, A, alg′)) + U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg′) + return rmul!(U, S), Vᴴ +end +function right_orth_svd!(A, CVᴴ, alg, trunc) + alg′ = select_algorithm(svd_compact!, A, alg) + alg_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg′) + C, Vᴴ = CVᴴ + S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg)) + U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_trunc) + return rmul!(U, S), Vᴴ +end # Implementation of null functions # -------------------------------- @@ -159,7 +175,7 @@ function null_truncation_strategy(; atol=nothing, rtol=nothing, maxnullity=nothi return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev=false) : trunc end -function left_null!(A::AbstractMatrix, N; trunc=nothing, +function left_null!(A, N; trunc=nothing, kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true), alg_svd=(;)) check_input(left_null!, A, N) @@ -167,26 +183,33 @@ function left_null!(A::AbstractMatrix, N; trunc=nothing, throw(ArgumentError("truncation not supported for left_null with kind=$kind")) end if kind == :qr - alg_qr′ = _select_algorithm(qr_null!, A, alg_qr) - return qr_null!(A, N, alg_qr′) - elseif kind == :svd && isnothing(trunc) - alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) - U, _, _ = svd_full!(A, alg_svd′) - (m, n) = size(A) - return copy!(N, view(U, 1:m, (n + 1):m)) + left_null_qr!(A, N, alg_qr) elseif kind == :svd - alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) - U, S, _ = svd_full!(A, alg_svd′) - trunc′ = trunc isa TruncationStrategy ? trunc : - trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : - throw(ArgumentError("Unknown truncation strategy: $trunc")) - return truncate!(left_null!, (U, S), trunc′) + left_null_svd!(A, N, alg_svd, trunc) else throw(ArgumentError("`left_null!` received unknown value `kind = $kind`")) end end +function left_null_qr!(A, N, alg) + alg′ = select_algorithm(qr_null!, A, alg) + return qr_null!(A, N, alg′) +end +function left_null_svd!(A, N, alg, trunc::Nothing=nothing) + alg′ = select_algorithm(svd_full!, A, alg) + U, _, _ = svd_full!(A, alg′) + (m, n) = size(A) + return copy!(N, view(U, 1:m, (n + 1):m)) +end +function left_null_svd!(A, N, alg, trunc) + alg′ = select_algorithm(svd_full!, A, alg) + U, S, _ = svd_full!(A, alg′) + trunc′ = trunc isa TruncationStrategy ? trunc : + trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : + throw(ArgumentError("Unknown truncation strategy: $trunc")) + return truncate!(left_null!, (U, S), trunc′) +end -function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing, +function right_null!(A, Nᴴ; trunc=nothing, kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true), alg_svd=(;)) check_input(right_null!, A, Nᴴ) @@ -194,21 +217,28 @@ function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing, throw(ArgumentError("truncation not supported for right_null with kind=$kind")) end if kind == :lq - alg_lq′ = _select_algorithm(lq_null!, A, alg_lq) - return lq_null!(A, Nᴴ, alg_lq′) - elseif kind == :svd && isnothing(trunc) - alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) - _, _, Vᴴ = svd_full!(A, alg_svd′) - (m, n) = size(A) - return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n)) + return right_null_lq!(A, Nᴴ, alg_lq) elseif kind == :svd - alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) - _, S, Vᴴ = svd_full!(A, alg_svd′) - trunc′ = trunc isa TruncationStrategy ? trunc : - trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : - throw(ArgumentError("Unknown truncation strategy: $trunc")) - return truncate!(right_null!, (S, Vᴴ), trunc′) + return right_null_svd!(A, Nᴴ, alg_svd) else throw(ArgumentError("`right_null!` received unknown value `kind = $kind`")) end end +function right_null_lq!(A, Nᴴ, alg) + alg′ = select_algorithm(lq_null!, A, alg) + return lq_null!(A, Nᴴ, alg′) +end +function right_null_svd!(A, Nᴴ, alg, trunc::Nothing=nothing) + alg′ = select_algorithm(svd_full!, A, alg) + _, _, Vᴴ = svd_full!(A, alg′) + (m, n) = size(A) + return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n)) +end +function right_null_svd!(A, Nᴴ, alg, trunc) + alg′ = select_algorithm(svd_full!, A, alg) + _, S, Vᴴ = svd_full!(A, alg′) + trunc′ = trunc isa TruncationStrategy ? trunc : + trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : + throw(ArgumentError("Unknown truncation strategy: $trunc")) + return truncate!(right_null!, (S, Vᴴ), trunc′) +end diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 95a5c002..10ed3a7a 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -98,8 +98,9 @@ struct TruncationIntersection{T<:Tuple{Vararg{TruncationStrategy}}} <: TruncationStrategy components::T end -TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...) = - TruncationIntersection((trunc, truncs...)) +function TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...) + return TruncationIntersection((trunc, truncs...)) +end function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy) return TruncationIntersection((trunc1, trunc2)) From 4448b07a38f29c8b639831203ef1d155e00c3b00 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 13 May 2025 14:55:37 -0400 Subject: [PATCH 2/6] Fix typo --- src/implementations/orthnull.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 68a32d9f..b513ac4b 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -130,11 +130,11 @@ function right_orth!(A, CVᴴ; trunc=nothing, throw(ArgumentError("truncation not supported for right_orth with kind=$kind")) end if kind == :lq - return right_orth_lq!(A, CVᴴ, alg) + return right_orth_lq!(A, CVᴴ, alg_lq) elseif kind == :polar - return right_orth_polar!(A, CVᴴ, alg) + return right_orth_polar!(A, CVᴴ, alg_polar) elseif kind == :svd - return right_orth_svd!(A, CVᴴ, alg, trunc) + return right_orth_svd!(A, CVᴴ, alg_svd, trunc) else throw(ArgumentError("`right_orth!` received unknown value `kind = $kind`")) end From b80118a9d78d5acd80498e3f6a802252db270342 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 14 May 2025 10:18:04 -0400 Subject: [PATCH 3/6] Fix some tests --- src/implementations/orthnull.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index b513ac4b..4f50f283 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -115,7 +115,7 @@ function left_orth_svd!(A, VC, alg, trunc::Nothing=nothing) end function left_orth_svd!(A, VC, alg, trunc) alg′ = select_algorithm(svd_compact!, A, alg) - alg_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg′) + alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) V, C = VC S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg)) U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_trunc) @@ -156,7 +156,7 @@ function right_orth_svd!(A, CVᴴ, alg, trunc::Nothing=nothing) end function right_orth_svd!(A, CVᴴ, alg, trunc) alg′ = select_algorithm(svd_compact!, A, alg) - alg_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg′) + alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) C, Vᴴ = CVᴴ S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg)) U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_trunc) From 7cc51bf6a780404f32ec48477258d5540c72c5ff Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 14 May 2025 10:30:14 -0400 Subject: [PATCH 4/6] Properly forward truncation to right_null --- src/implementations/orthnull.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 4f50f283..ab18f689 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -219,7 +219,7 @@ function right_null!(A, Nᴴ; trunc=nothing, if kind == :lq return right_null_lq!(A, Nᴴ, alg_lq) elseif kind == :svd - return right_null_svd!(A, Nᴴ, alg_svd) + return right_null_svd!(A, Nᴴ, alg_svd, trunc) else throw(ArgumentError("`right_null!` received unknown value `kind = $kind`")) end From c030ce5b848b7132e6c80949f7ad97490b692621 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 16 May 2025 18:02:12 -0400 Subject: [PATCH 5/6] orth for non-AbstractMatrix --- src/implementations/orthnull.jl | 26 ++++++++++++++++++++ src/interface/orthnull.jl | 16 ++++++------ test/orthnull.jl | 43 +++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 8 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index ab18f689..f53055db 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -107,6 +107,12 @@ function left_orth_polar!(A, VC, alg) return left_polar!(A, VC, alg′) end function left_orth_svd!(A, VC, alg, trunc::Nothing=nothing) + alg′ = select_algorithm(svd_compact!, A, alg) + U, S, Vᴴ = svd_compact!(A, alg′) + V, C = VC + return copy!(V, U), mul!(C, S, Vᴴ) +end +function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc::Nothing=nothing) alg′ = select_algorithm(svd_compact!, A, alg) V, C = VC S = Diagonal(initialize_output(svd_vals!, A, alg′)) @@ -114,6 +120,13 @@ function left_orth_svd!(A, VC, alg, trunc::Nothing=nothing) return U, lmul!(S, Vᴴ) end function left_orth_svd!(A, VC, alg, trunc) + alg′ = select_algorithm(svd_compact!, A, alg) + alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) + U, S, Vᴴ = svd_trunc!(A, alg_trunc) + V, C = VC + return copy!(V, U), mul!(C, S, Vᴴ) +end +function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc) alg′ = select_algorithm(svd_compact!, A, alg) alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) V, C = VC @@ -148,6 +161,12 @@ function right_orth_polar!(A, CVᴴ, alg) return right_polar!(A, CVᴴ, alg′) end function right_orth_svd!(A, CVᴴ, alg, trunc::Nothing=nothing) + alg′ = select_algorithm(svd_compact!, A, alg) + U, S, Vᴴ′ = svd_compact!(A, alg′) + C, Vᴴ = CVᴴ + return mul!(C, U, S), copy!(Vᴴ, Vᴴ′) +end +function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc::Nothing=nothing) alg′ = select_algorithm(svd_compact!, A, alg) C, Vᴴ = CVᴴ S = Diagonal(initialize_output(svd_vals!, A, alg′)) @@ -155,6 +174,13 @@ function right_orth_svd!(A, CVᴴ, alg, trunc::Nothing=nothing) return rmul!(U, S), Vᴴ end function right_orth_svd!(A, CVᴴ, alg, trunc) + alg′ = select_algorithm(svd_compact!, A, alg) + alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) + U, S, Vᴴ′ = svd_trunc!(A, alg_trunc) + C, Vᴴ = CVᴴ + return mul!(C, U, S), copy!(Vᴴ, Vᴴ′) +end +function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc) alg′ = select_algorithm(svd_compact!, A, alg) alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) C, Vᴴ = CVᴴ diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 4b00ec66..efc678ee 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -69,10 +69,10 @@ See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), [ """ function left_orth end function left_orth! end -function left_orth!(A::AbstractMatrix; kwargs...) +function left_orth!(A; kwargs...) return left_orth!(A, initialize_output(left_orth!, A); kwargs...) end -function left_orth(A::AbstractMatrix; kwargs...) +function left_orth(A; kwargs...) return left_orth!(copy_input(left_orth, A); kwargs...) end @@ -128,10 +128,10 @@ See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), [`r """ function right_orth end function right_orth! end -function right_orth!(A::AbstractMatrix; kwargs...) +function right_orth!(A; kwargs...) return right_orth!(A, initialize_output(right_orth!, A); kwargs...) end -function right_orth(A::AbstractMatrix; kwargs...) +function right_orth(A; kwargs...) return right_orth!(copy_input(right_orth, A); kwargs...) end @@ -180,10 +180,10 @@ See also [`right_null(!)`](@ref right_null), [`left_orth(!)`](@ref left_orth), [ """ function left_null end function left_null! end -function left_null!(A::AbstractMatrix; kwargs...) +function left_null!(A; kwargs...) return left_null!(A, initialize_output(left_null!, A); kwargs...) end -function left_null(A::AbstractMatrix; kwargs...) +function left_null(A; kwargs...) return left_null!(copy_input(left_null, A); kwargs...) end @@ -230,9 +230,9 @@ See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth), [`r """ function right_null end function right_null! end -function right_null!(A::AbstractMatrix; kwargs...) +function right_null!(A; kwargs...) return right_null!(A, initialize_output(right_null!, A); kwargs...) end -function right_null(A::AbstractMatrix; kwargs...) +function right_null(A; kwargs...) return right_null!(copy_input(right_null, A); kwargs...) end diff --git a/test/orthnull.jl b/test/orthnull.jl index 0823d3c5..2eaf1695 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -5,6 +5,41 @@ using StableRNGs using LinearAlgebra: LinearAlgebra, I using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow +# Used to test non-AbstractMatrix codepaths. +struct LinearMap{P<:AbstractMatrix} + parent::P +end +using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_svd_algorithm, + initialize_output +function MatrixAlgebraKit.copy_input(::typeof(qr_compact), A::LinearMap) + return LinearMap(copy_input(qr_compact, A.parent)) +end +function MatrixAlgebraKit.copy_input(::typeof(lq_compact), A::LinearMap) + return LinearMap(copy_input(lq_compact, A.parent)) +end +function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), A::LinearMap) + return initialize_output(left_orth!, A.parent) +end +function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), A::LinearMap) + return initialize_output(right_orth!, A.parent) +end +function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC) + return check_input(left_orth!, A.parent, VC) +end +function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC) + return check_input(right_orth!, A.parent, VC) +end +function MatrixAlgebraKit.default_svd_algorithm(A::LinearMap) + return default_svd_algorithm(A.parent) +end +function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), A::LinearMap, + alg::LAPACK_SVDAlgorithm) + return initialize_output(svd_compact!, A.parent, alg) +end +function MatrixAlgebraKit.svd_compact!(A::LinearMap, USVᴴ, alg::LAPACK_SVDAlgorithm) + return svd_compact!(A.parent, USVᴴ, alg) +end + @testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) @@ -23,6 +58,10 @@ using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow @test N' * N ≈ I @test V * V' + N * N' ≈ I + M = LinearMap(A) + V, C = @constinferred left_orth(M; kind=:svd) + @test V * C ≈ A + if m > n nullity = 5 V, C = @constinferred left_orth(A) @@ -162,6 +201,10 @@ end @test Nᴴ * Nᴴ' ≈ I @test Vᴴ' * Vᴴ + Nᴴ' * Nᴴ ≈ I + M = LinearMap(A) + C, Vᴴ = @constinferred right_orth(M; kind=:svd) + @test C * Vᴴ ≈ A + Ac = similar(A) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ) From 8887591e4ec2f3eafdbf446759a08eb1dbdd8771 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 16 May 2025 18:20:21 -0400 Subject: [PATCH 6/6] More robust non-AbstractMatrix test --- test/orthnull.jl | 42 ++++++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/test/orthnull.jl b/test/orthnull.jl index 2eaf1695..7f497eaf 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -2,42 +2,52 @@ using MatrixAlgebraKit using Test using TestExtras using StableRNGs -using LinearAlgebra: LinearAlgebra, I +using LinearAlgebra: LinearAlgebra, I, mul! using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow +using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_svd_algorithm, + initialize_output # Used to test non-AbstractMatrix codepaths. struct LinearMap{P<:AbstractMatrix} parent::P end -using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_svd_algorithm, - initialize_output +Base.parent(A::LinearMap) = getfield(A, :parent) +function Base.copy!(dest::LinearMap, src::LinearMap) + copy!(parent(dest), parent(src)) + return dest +end +function LinearAlgebra.mul!(C::LinearMap, A::LinearMap, B::LinearMap) + mul!(parent(C), parent(A), parent(B)) + return C +end + function MatrixAlgebraKit.copy_input(::typeof(qr_compact), A::LinearMap) - return LinearMap(copy_input(qr_compact, A.parent)) + return LinearMap(copy_input(qr_compact, parent(A))) end function MatrixAlgebraKit.copy_input(::typeof(lq_compact), A::LinearMap) - return LinearMap(copy_input(lq_compact, A.parent)) + return LinearMap(copy_input(lq_compact, parent(A))) end function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), A::LinearMap) - return initialize_output(left_orth!, A.parent) + return LinearMap.(initialize_output(left_orth!, parent(A))) end function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), A::LinearMap) - return initialize_output(right_orth!, A.parent) + return LinearMap.(initialize_output(right_orth!, parent(A))) end function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC) - return check_input(left_orth!, A.parent, VC) + return check_input(left_orth!, parent(A), parent.(VC)) end function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC) - return check_input(right_orth!, A.parent, VC) + return check_input(right_orth!, parent(A), parent.(VC)) end function MatrixAlgebraKit.default_svd_algorithm(A::LinearMap) - return default_svd_algorithm(A.parent) + return default_svd_algorithm(parent(A)) end function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), A::LinearMap, alg::LAPACK_SVDAlgorithm) - return initialize_output(svd_compact!, A.parent, alg) + return LinearMap.(initialize_output(svd_compact!, parent(A), alg)) end function MatrixAlgebraKit.svd_compact!(A::LinearMap, USVᴴ, alg::LAPACK_SVDAlgorithm) - return svd_compact!(A.parent, USVᴴ, alg) + return LinearMap.(svd_compact!(parent(A), parent.(USVᴴ), alg)) end @testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, @@ -59,8 +69,8 @@ end @test V * V' + N * N' ≈ I M = LinearMap(A) - V, C = @constinferred left_orth(M; kind=:svd) - @test V * C ≈ A + VM, CM = @constinferred left_orth(M; kind=:svd) + @test parent(VM) * parent(CM) ≈ A if m > n nullity = 5 @@ -202,8 +212,8 @@ end @test Vᴴ' * Vᴴ + Nᴴ' * Nᴴ ≈ I M = LinearMap(A) - C, Vᴴ = @constinferred right_orth(M; kind=:svd) - @test C * Vᴴ ≈ A + CM, VMᴴ = @constinferred right_orth(M; kind=:svd) + @test parent(CM) * parent(VMᴴ) ≈ A Ac = similar(A) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ))