From ac34ab75f4513f0ab8c30f56cf0e7c7a2c70d276 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 8 Apr 2025 21:57:21 -0400 Subject: [PATCH 01/14] [WIP] More general orth truncation --- Project.toml | 2 +- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 2 +- src/algorithms.jl | 2 +- src/implementations/orthnull.jl | 29 +++++++++++------------- src/interface/orthnull.jl | 2 +- src/pullbacks/polar.jl | 2 +- test/chainrules.jl | 2 +- test/orthnull.jl | 10 ++++---- 8 files changed, 24 insertions(+), 27 deletions(-) diff --git a/Project.toml b/Project.toml index d8ea7ab4..def62633 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MatrixAlgebraKit" uuid = "6c742aac-3347-4629-af66-fc926824e5e4" authors = ["Jutho and contributors"] -version = "0.1.1" +version = "0.2.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 5ef1fd9b..e21842c6 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -176,4 +176,4 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A::AbstractMatrix, PWᴴ, return PWᴴ, right_polar_pullback end -end \ No newline at end of file +end diff --git a/src/algorithms.jl b/src/algorithms.jl index ea260aa2..fc247317 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -178,4 +178,4 @@ macro check_size(x, sz, size=:size) string($sz) szx == $sz || throw(DimensionMismatch($err)) end) -end \ No newline at end of file +end diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index e134121b..b0a9d1af 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -81,37 +81,34 @@ end # Implementation of orth functions # -------------------------------- -function left_orth!(A::AbstractMatrix, VC; kwargs...) +function left_orth!(A::AbstractMatrix, VC; alg=nothing, trunc=nothing, + kind=isnothing(trunc) ? :qrpos : :svd, qr_kwargs=(;), polar_kwargs=(;), + svd_kwargs=(;)) check_input(left_orth!, A, VC) - atol = get(kwargs, :atol, 0) - rtol = get(kwargs, :rtol, 0) - kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :qrpos : :svd) - if !(iszero(atol) && iszero(rtol)) && kind != :svd - throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind")) + if !isnothing(trunc) && kind != :svd + throw(ArgumentError("truncation not supported for left_orth with kind=$kind")) end if kind == :qr - alg = get(kwargs, :alg, select_algorithm(qr_compact!, A)) + alg = @something alg select_algorithm(qr_compact!, A; qr_kwargs...) return qr_compact!(A, VC, alg) elseif kind == :qrpos - alg = get(kwargs, :alg, select_algorithm(qr_compact!, A; positive=true)) + alg = @something alg select_algorithm(qr_compact!, A; positive=true, qr_kwargs...) return qr_compact!(A, VC, alg) elseif kind == :polar size(A, 1) >= size(A, 2) || throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`")) - alg = get(kwargs, :alg, select_algorithm(left_polar!, A)) + alg = @something alg select_algorithm(left_polar!, A; polar_kwargs...) return left_polar!(A, VC, alg) - elseif kind == :svd && iszero(atol) && iszero(rtol) - alg = get(kwargs, :alg, select_algorithm(svd_compact!, A)) + elseif kind == :svd && isnothing(trunc) + alg = @something alg select_algorithm(svd_compact!, A; svd_kwargs...) V, C = VC S = Diagonal(initialize_output(svd_vals!, A, alg)) U, S, Vᴴ = svd_compact!(A, (V, S, C), alg) return U, lmul!(S, Vᴴ) elseif kind == :svd - alg_svd = select_algorithm(svd_compact!, A) - trunc = TruncationKeepAbove(atol, rtol) - alg = get(kwargs, :alg, TruncatedAlgorithm(alg_svd, trunc)) + alg = @something alg select_algorithm(svd_trunc!, A; trunc, svd_kwargs...) V, C = VC - S = Diagonal(initialize_output(svd_vals!, A, alg_svd)) + S = Diagonal(initialize_output(svd_vals!, A, alg.alg)) U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg) return U, lmul!(S, Vᴴ) else @@ -215,4 +212,4 @@ function right_null!(A::AbstractMatrix, Nᴴ; kwargs...) else throw(ArgumentError("`right_null!` received unknown value `kind = $kind`")) end -end \ No newline at end of file +end diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index de7c3153..78a60568 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -224,4 +224,4 @@ function right_null!(A::AbstractMatrix; kwargs...) end function right_null(A::AbstractMatrix; kwargs...) return right_null!(copy_input(right_null, A); kwargs...) -end \ No newline at end of file +end diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl index ed6fc17d..2eea389e 100644 --- a/src/pullbacks/polar.jl +++ b/src/pullbacks/polar.jl @@ -58,4 +58,4 @@ function right_polar_pullback!(ΔA::AbstractMatrix, PWᴴ, ΔPWᴴ) ΔA .+= PΔWᴴ end return ΔA -end \ No newline at end of file +end diff --git a/test/chainrules.jl b/test/chainrules.jl index 47de1c04..a93c5270 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -356,4 +356,4 @@ end test_rrule(config, right_null, A; fkwargs=(; kind=:lqpos), output_tangent=ΔNᴴ, atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) end -end \ No newline at end of file +end diff --git a/test/orthnull.jl b/test/orthnull.jl index 7b926b59..1430224e 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -29,7 +29,7 @@ @test V2 * V2' + N2 * N2' ≈ I atol = eps(real(T)) - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); atol=atol) + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=(; atol=atol)) N2 = @constinferred left_null!(copy!(Ac, A), N; atol=atol) @test V2 !== V @test C2 !== C @@ -41,7 +41,7 @@ @test V2 * V2' + N2 * N2' ≈ I rtol = eps(real(T)) - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); rtol=rtol) + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=(; rtol=rtol)) N2 = @constinferred left_null!(copy!(Ac, A), N; rtol=rtol) @test V2 !== V @test C2 !== C @@ -70,7 +70,7 @@ # with kind and tol kwargs if kind == :svd V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind, - atol=atol) + trunc=(; atol=atol)) N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind, atol=atol) @test V2 !== V @test C2 !== C @@ -82,7 +82,7 @@ @test V2 * V2' + N2 * N2' ≈ I V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind, - rtol=rtol) + trunc=(; rtol=rtol)) N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind, rtol=rtol) @test V2 !== V @test C2 !== C @@ -209,4 +209,4 @@ end end end end -end \ No newline at end of file +end From e67b7b1e82aa84c8eb1eb5c4d36487aeb15b047e Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 9 Apr 2025 18:37:44 -0400 Subject: [PATCH 02/14] More customization of algorithms --- src/implementations/orthnull.jl | 36 ++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index b0a9d1af..a1c59868 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -81,35 +81,43 @@ end # Implementation of orth functions # -------------------------------- -function left_orth!(A::AbstractMatrix, VC; alg=nothing, trunc=nothing, - kind=isnothing(trunc) ? :qrpos : :svd, qr_kwargs=(;), polar_kwargs=(;), - svd_kwargs=(;)) +function left_orth!(A::AbstractMatrix, VC; trunc=nothing, + kind=isnothing(trunc) ? :qrpos : :svd, alg_qr=(;), alg_polar=(;), + alg_svd=(;)) check_input(left_orth!, A, VC) if !isnothing(trunc) && kind != :svd throw(ArgumentError("truncation not supported for left_orth with kind=$kind")) end if kind == :qr - alg = @something alg select_algorithm(qr_compact!, A; qr_kwargs...) + alg_qr = alg_qr isa NamedTuple ? + select_algorithm(qr_compact!, A; alg_qr...) : alg_qr return qr_compact!(A, VC, alg) elseif kind == :qrpos - alg = @something alg select_algorithm(qr_compact!, A; positive=true, qr_kwargs...) - return qr_compact!(A, VC, alg) + alg_qr = alg_qr isa NamedTuple ? + select_algorithm(qr_compact!, A; positive=true, + alg_qr...) : alg_qr + return qr_compact!(A, VC, alg_qr) elseif kind == :polar size(A, 1) >= size(A, 2) || throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`")) - alg = @something alg select_algorithm(left_polar!, A; polar_kwargs...) - return left_polar!(A, VC, alg) + alg_polar = alg_polar isa NamedTuple ? + select_algorithm(left_polar!, A; alg_polar...) : + alg_polar + return left_polar!(A, VC, alg_polar) elseif kind == :svd && isnothing(trunc) - alg = @something alg select_algorithm(svd_compact!, A; svd_kwargs...) + alg_svd = alg_svd isa NamedTuple ? + select_algorithm(svd_compact!, A; alg_svd...) : alg_svd V, C = VC - S = Diagonal(initialize_output(svd_vals!, A, alg)) - U, S, Vᴴ = svd_compact!(A, (V, S, C), alg) + S = Diagonal(initialize_output(svd_vals!, A, alg_svd)) + U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd) return U, lmul!(S, Vᴴ) elseif kind == :svd - alg = @something alg select_algorithm(svd_trunc!, A; trunc, svd_kwargs...) + alg_svd = alg_svd isa NamedTuple ? + select_algorithm(svd_compact!, A; alg_svd...) : alg_svd + alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd) V, C = VC - S = Diagonal(initialize_output(svd_vals!, A, alg.alg)) - U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg) + S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg)) + U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_svd_trunc) return U, lmul!(S, Vᴴ) else throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`")) From 2d9499f3189abcb0133b3f9347f5a8bf4cd1dfad Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 17 Apr 2025 09:10:29 -0400 Subject: [PATCH 03/14] Simplify handling positive QR --- src/implementations/orthnull.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index a1c59868..17b543dc 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -82,7 +82,7 @@ end # Implementation of orth functions # -------------------------------- function left_orth!(A::AbstractMatrix, VC; trunc=nothing, - kind=isnothing(trunc) ? :qrpos : :svd, alg_qr=(;), alg_polar=(;), + kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true), alg_polar=(;), alg_svd=(;)) check_input(left_orth!, A, VC) if !isnothing(trunc) && kind != :svd @@ -92,11 +92,6 @@ function left_orth!(A::AbstractMatrix, VC; trunc=nothing, alg_qr = alg_qr isa NamedTuple ? select_algorithm(qr_compact!, A; alg_qr...) : alg_qr return qr_compact!(A, VC, alg) - elseif kind == :qrpos - alg_qr = alg_qr isa NamedTuple ? - select_algorithm(qr_compact!, A; positive=true, - alg_qr...) : alg_qr - return qr_compact!(A, VC, alg_qr) elseif kind == :polar size(A, 1) >= size(A, 2) || throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`")) From 5891f8108e9a152a42b7d0216434f49bd5b4818c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 17 Apr 2025 09:10:48 -0400 Subject: [PATCH 04/14] Format --- src/implementations/orthnull.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 17b543dc..518bba59 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -82,7 +82,8 @@ end # Implementation of orth functions # -------------------------------- function left_orth!(A::AbstractMatrix, VC; trunc=nothing, - kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true), alg_polar=(;), + kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true), + alg_polar=(;), alg_svd=(;)) check_input(left_orth!, A, VC) if !isnothing(trunc) && kind != :svd From 63a7f01a07e529f16537fb67bf84afff933ae5f3 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 17 Apr 2025 09:11:42 -0400 Subject: [PATCH 05/14] Format --- src/implementations/orthnull.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 518bba59..c1d0981c 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -83,8 +83,7 @@ end # -------------------------------- function left_orth!(A::AbstractMatrix, VC; trunc=nothing, kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true), - alg_polar=(;), - alg_svd=(;)) + alg_polar=(;), alg_svd=(;)) check_input(left_orth!, A, VC) if !isnothing(trunc) && kind != :svd throw(ArgumentError("truncation not supported for left_orth with kind=$kind")) From 2783bf51cfe48d4b140bd7273457034eea905fbd Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 17 Apr 2025 11:20:46 -0400 Subject: [PATCH 06/14] Refactor --- src/implementations/orthnull.jl | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index c1d0981c..4cc84a93 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -79,6 +79,13 @@ function initialize_output(::typeof(right_null!), A::AbstractMatrix) return Nᴴ end +function algorithm_or_select_algorithm(f, A::AbstractMatrix, alg::AbstractAlgorithm) + return alg +end +function algorithm_or_select_algorithm(f, A::AbstractMatrix, kwargs::NamedTuple) + return select_algorithm(f, A; kwargs...) +end + # Implementation of orth functions # -------------------------------- function left_orth!(A::AbstractMatrix, VC; trunc=nothing, @@ -89,27 +96,22 @@ function left_orth!(A::AbstractMatrix, VC; trunc=nothing, throw(ArgumentError("truncation not supported for left_orth with kind=$kind")) end if kind == :qr - alg_qr = alg_qr isa NamedTuple ? - select_algorithm(qr_compact!, A; alg_qr...) : alg_qr - return qr_compact!(A, VC, alg) + alg_qr′ = algorithm_or_select_algorithm(qr_compact!, A, alg_qr) + return qr_compact!(A, VC, alg_qr′) elseif kind == :polar size(A, 1) >= size(A, 2) || throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`")) - alg_polar = alg_polar isa NamedTuple ? - select_algorithm(left_polar!, A; alg_polar...) : - alg_polar - return left_polar!(A, VC, alg_polar) + alg_polar′ = algorithm_or_select_algorithm(left_polar!, A, alg_polar) + return left_polar!(A, VC, alg_polar′) elseif kind == :svd && isnothing(trunc) - alg_svd = alg_svd isa NamedTuple ? - select_algorithm(svd_compact!, A; alg_svd...) : alg_svd + alg_svd′ = algorithm_or_select_algorithm(svd_compact!, A, alg_svd) V, C = VC - S = Diagonal(initialize_output(svd_vals!, A, alg_svd)) - U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd) + S = Diagonal(initialize_output(svd_vals!, A, alg_svd′)) + U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd′) return U, lmul!(S, Vᴴ) elseif kind == :svd - alg_svd = alg_svd isa NamedTuple ? - select_algorithm(svd_compact!, A; alg_svd...) : alg_svd - alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd) + alg_svd′ = algorithm_or_select_algorithm(svd_compact!, A, alg_svd) + alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′) V, C = VC S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg)) U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_svd_trunc) From 19fd99489bd5111d565ab98ad70e8360e02b8db8 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 17 Apr 2025 12:07:08 -0400 Subject: [PATCH 07/14] Fix tests --- test/orthnull.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/orthnull.jl b/test/orthnull.jl index ab943005..591c292c 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -58,7 +58,7 @@ using LinearAlgebra: LinearAlgebra, I @test N2' * N2 ≈ I @test V2 * V2' + N2 * N2' ≈ I - for kind in (:qr, :qrpos, :polar, :svd) # explicit kind kwarg + for kind in (:qr, :polar, :svd) # explicit kind kwarg m < n && kind == :polar && continue V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind) @test V2 === V @@ -100,9 +100,9 @@ using LinearAlgebra: LinearAlgebra, I @test V2 * V2' + N2 * N2' ≈ I else @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind=kind, - atol=atol) + trunc=(; atol=atol)) @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind=kind, - rtol=rtol) + trunc=(; rtol=rtol)) @test_throws ArgumentError left_null!(copy!(Ac, A), N; kind=kind, atol=atol) @test_throws ArgumentError left_null!(copy!(Ac, A), N; kind=kind, rtol=rtol) end From e0dd0ef88c77dcf6b9a76333b954dc7a8012ff07 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 17 Apr 2025 13:16:48 -0400 Subject: [PATCH 08/14] Update right_orth --- src/implementations/orthnull.jl | 39 +++++++---------- src/interface/orthnull.jl | 78 +++++++++++++++++---------------- test/chainrules.jl | 12 ++--- test/orthnull.jl | 14 +++--- 4 files changed, 70 insertions(+), 73 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 4cc84a93..3c2a7856 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -121,38 +121,33 @@ function left_orth!(A::AbstractMatrix, VC; trunc=nothing, end end -function right_orth!(A::AbstractMatrix, CVᴴ; kwargs...) +function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing, + kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true), + alg_polar=(;), alg_svd=(;)) check_input(right_orth!, A, CVᴴ) - atol = get(kwargs, :atol, 0) - rtol = get(kwargs, :rtol, 0) - kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :lqpos : :svd) - if !(iszero(atol) && iszero(rtol)) && kind != :svd - throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind")) + if !isnothing(trunc) && kind != :svd + throw(ArgumentError("truncation not supported for right_orth with kind=$kind")) end if kind == :lq - alg = get(kwargs, :alg, select_algorithm(lq_compact!, A)) - return lq_compact!(A, CVᴴ, alg) - elseif kind == :lqpos - alg = get(kwargs, :alg, select_algorithm(lq_compact!, A; positive=true)) - return lq_compact!(A, CVᴴ, alg) + alg_lq′ = algorithm_or_select_algorithm(lq_compact!, A, alg_lq) + return lq_compact!(A, CVᴴ, alg_lq′) elseif kind == :polar size(A, 2) >= size(A, 1) || throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`")) - alg = get(kwargs, :alg, select_algorithm(right_polar!, A)) - return right_polar!(A, CVᴴ, alg) - elseif kind == :svd && iszero(atol) && iszero(rtol) - alg = get(kwargs, :alg, select_algorithm(svd_compact!, A)) + alg_polar′ = algorithm_or_select_algorithm(right_polar!, A, alg_polar) + return right_polar!(A, CVᴴ, alg_polar′) + elseif kind == :svd && isnothing(trunc) + alg_svd′ = algorithm_or_select_algorithm(svd_compact!, A, alg_svd) C, Vᴴ = CVᴴ - S = Diagonal(initialize_output(svd_vals!, A, alg)) - U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg) + S = Diagonal(initialize_output(svd_vals!, A, alg_svd′)) + U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′) return rmul!(U, S), Vᴴ elseif kind == :svd - alg_svd = select_algorithm(svd_compact!, A) - trunc = TruncationKeepAbove(atol, rtol) - alg = get(kwargs, :alg, TruncatedAlgorithm(alg_svd, trunc)) + alg_svd′ = algorithm_or_select_algorithm(svd_compact!, A, alg_svd) + alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′) C, Vᴴ = CVᴴ - S = Diagonal(initialize_output(svd_vals!, A, alg_svd)) - U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg) + S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg)) + U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_svd_trunc) return rmul!(U, S), Vᴴ else throw(ArgumentError("`right_orth!` received unknown value `kind = $kind`")) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 236184bd..ea2f218f 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -19,43 +19,44 @@ end # Orth functions # -------------- """ - left_orth(A; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> V, C - left_orth!(A, [VC]; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> V, C + left_orth(A; [kind::Symbol, trunc, alg_qr, alg_polar, alg_svd]) -> V, C + left_orth!(A, [VC]; [kind::Symbol, trunc, alg_qr, alg_polar, alg_svd]) -> V, C Compute an orthonormal basis `V` for the image of the matrix `A` of size `(m, n)`, as well as a matrix `C` (the corestriction) such that `A` factors as `A = V * C`. The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `atol` and `rtol` can be used to control the +that should be used to factor `A`, whereas `trunc` can be used to control the precision in determining the rank of `A` via its singular values. This is a high-level wrapper and will use one of the decompositions -`qr!`, `svd!`, and `left_polar!` to compute the orthogonal basis `V`, as controlled +`qr_compact!`, `svd_compact!`/`svd_trunc!`, and `left_polar!` to compute the orthogonal basis `V`, as controlled by the keyword arguments. When `kind` is provided, its possible values are -* `kind == :qrpos`: `V` and `C` are computed using the positive QR decomposition. - This requires `iszero(atol) && iszero(rtol)` and `left_orth!(A, [VC])` is equivalent to +* `kind == :qr`: `V` and `C` are computed using the QR decomposition. + This requires `isnothing(trunc)` and `left_orth!(A, [VC])` is equivalent to `qr_compact!(A, [VC], alg)` with a default value `alg = select_algorithm(qr_compact!, A; positive=true)` -* `kind == :qr`: `V` and `C` are computed using the QR decomposition, - This requires `iszero(atol) && iszero(rtol)` and `left_orth!(A, [VC])` is equivalent to - `qr_compact!(A, [VC], alg)` with a default value `alg = select_algorithm(qr_compact!, A)` - * `kind == :polar`: `V` and `C` are computed using the polar decomposition, - This requires `iszero(atol) && iszero(rtol)` and `left_orth!(A, [VC])` is equivalent to + This requires `isnothing(trunc)` and `left_orth!(A, [VC])` is equivalent to `left_polar!(A, [VC], alg)` with a default value `alg = select_algorithm(left_polar!, A)` -* `kind == :svd`: `V` and `C` are computed using the singular value decomposition `svd_trunc!`, - where `V` will contain the left singular vectors corresponding to the singular values that - are larger than `max(atol, rtol * σ₁)`, where `σ₁` is the largest singular value of `A`. - `C` is computed as the product of the singular values and the right singular vectors, - i.e. with `U, S, Vᴴ = svd_trunc!(A)`, we have `V = U` and `C = S * Vᴴ`. +* `kind == :svd`: `V` and `C` are computed using the singular value decomposition `svd_compact!` + if no truncation is specified through the `trunc` keyword argument or `svd_trunc!` + if truncation is specified through the `trunc` keyword argument. + `V` will contain the left singular vectors and `C` is computed as the product of the singular + values and the right singular vectors, i.e. with `U, S, Vᴴ = svd(A)`, we have + `V = U` and `C = S * Vᴴ`. -When `kind` is not provided, the default value is `:qrpos` when `iszero(atol) && iszero(rtol)` +When `kind` is not provided, the default value is `:qr` when `isnothing(trunc)` and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -using the `alg` keyword argument, which should be compatible with the chosen or default value -of `kind`. +for backend factorizations through the `alg_qr`, `alg_polar`, and `alg_svd` keyword arguments, +which will only be used if the corresponding factorization is called based on the other inputs. +If NamedTuples are passed as `alg_qr`, `alg_polar`, or `alg_svd`, a default algorithm is chosen +with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. +`alg_qr` defaults to `(; positive=true)` so that by default a positive QR decomposition will +be used. !!! note The bang method `left_orth!` optionally accepts the output structure and possibly destroys @@ -80,37 +81,38 @@ end Compute an orthonormal basis `V = adjoint(Vᴴ)` for the coimage of the matrix `A`, i.e. for the image of `adjoint(A)`, as well as a matrix `C` such that `A = C * Vᴴ`. The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `atol` and `rtol` can be used to control the +that should be used to factor `A`, whereas `trunc` can be used to control the precision in determining the rank of `A` via its singular values. This is a high-level wrapper and will use call one of the decompositions -`qr!`, `svd!`, and `left_polar!` to compute the orthogonal basis `V`, as controlled -by the keyword arguments. +`lq_compact!`, `svd_compact!`/`svd_trunc!`, and `right_polar!` to compute the +orthogonal basis `V`, as controlled by the keyword arguments. When `kind` is provided, its possible values are -* `kind == :lqpos`: `C` and `Vᴴ` are computed using the positive QR decomposition. - This requires `iszero(atol) && iszero(rtol)` and `right_orth!(A, [CVᴴ])` is equivalent to - `lq_compact!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A; positive=true)` - * `kind == :lq`: `C` and `Vᴴ` are computed using the QR decomposition, - This requires `iszero(atol) && iszero(rtol)` and `right_orth!(A, [CVᴴ])` is equivalent to - `lq_compact!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A))` + This requires `isnothing(trunc)` and `right_orth!(A, [CVᴴ])` is equivalent to + `lq_compact!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A; positive=true)` * `kind == :polar`: `C` and `Vᴴ` are computed using the polar decomposition, - This requires `iszero(atol) && iszero(rtol)` and `right_orth!(A, [CVᴴ])` is equivalent to - `right_polar!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(right_polar!, A))` + This requires `isnothing(trunc)` and `right_orth!(A, [CVᴴ])` is equivalent to + `right_polar!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(right_polar!, A)` -* `kind == :svd`: `C` and `Vᴴ` are computed using the singular value decomposition `svd_trunc!`, - where `V = adjoint(Vᴴ)` will contain the right singular vectors corresponding to the singular - values that are larger than `max(atol, rtol * σ₁)`, where `σ₁` is the largest singular value of `A`. - `C` is computed as the product of the singular values and the right singular vectors, - i.e. with `U, S, Vᴴ = svd_trunc!(A)`, we have `C = rmul!(U, S)` and `Vᴴ = Vᴴ`. +* `kind == :svd`: `C` and `Vᴴ` are computed using the singular value decomposition `svd_compact!` + if no truncation is specified through the `trunc` keyword argument or `svd_trunc!` + if truncation is specified through the `trunc` keyword argument. + `V = adjoint(Vᴴ)` will contain the right singular vectors corresponding to the singular + values and `C` is computed as the product of the singular values and the right singular vectors, + i.e. with `U, S, Vᴴ = svd(A)`, we have `C = rmul!(U, S)` and `Vᴴ = Vᴴ`. -When `kind` is not provided, the default value is `:lqpos` when `iszero(atol) && iszero(rtol)` +When `kind` is not provided, the default value is `:lq` when `isnothing(trunc)` and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -using the `alg` keyword argument, which should be compatible with the chosen or default value -of `kind`. +for backend factorizations through the `alg_lq`, `alg_polar`, and `alg_svd` keyword arguments, +which will only be used if the corresponding factorization is called based on the other inputs. +If NamedTuples are passed as `alg_lq`, `alg_polar`, or `alg_svd`, a default algorithm is chosen +with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. +`alg_lq` defaults to `(; positive=true)` so that by default a positive QR decomposition will +be used. !!! note The bang method `right_orth!` optionally accepts the output structure and possibly destroys diff --git a/test/chainrules.jl b/test/chainrules.jl index ab895a4a..a48d2aae 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -338,26 +338,26 @@ end config = Zygote.ZygoteRuleConfig() test_rrule(config, left_orth, A; atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) - test_rrule(config, left_orth, A; fkwargs=(; kind=:qrpos), + test_rrule(config, left_orth, A; fkwargs=(; kind=:qr), atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) m >= n && test_rrule(config, left_orth, A; fkwargs=(; kind=:polar), atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) - ΔN = left_orth(A; kind=:qrpos)[1] * randn(rng, T, min(m, n), m - min(m, n)) - test_rrule(config, left_null, A; fkwargs=(; kind=:qrpos), output_tangent=ΔN, + ΔN = left_orth(A; kind=:qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) + test_rrule(config, left_null, A; fkwargs=(; kind=:qr), output_tangent=ΔN, atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) test_rrule(config, right_orth, A; atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) - test_rrule(config, right_orth, A; fkwargs=(; kind=:lqpos), + test_rrule(config, right_orth, A; fkwargs=(; kind=:lq), atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) m <= n && test_rrule(config, right_orth, A; fkwargs=(; kind=:polar), atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind=:lqpos)[2] - test_rrule(config, right_null, A; fkwargs=(; kind=:lqpos), output_tangent=ΔNᴴ, + ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind=:lq)[2] + test_rrule(config, right_null, A; fkwargs=(; kind=:lq), output_tangent=ΔNᴴ, atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) end end diff --git a/test/orthnull.jl b/test/orthnull.jl index 591c292c..048167e3 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -141,7 +141,7 @@ end @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I atol = eps(real(T)) - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); atol=atol) + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; atol=atol)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; atol=atol) @test C2 !== C @test Vᴴ2 !== Vᴴ @@ -153,7 +153,7 @@ end @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I rtol = eps(real(T)) - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); rtol=rtol) + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; rtol=rtol)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; rtol=rtol) @test C2 !== C @test Vᴴ2 !== Vᴴ @@ -164,7 +164,7 @@ end @test Nᴴ2 * Nᴴ2' ≈ I @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I - for kind in (:lq, :lqpos, :polar, :svd) + for kind in (:lq, :polar, :svd) n < m && kind == :polar && continue C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind) @test C2 === C @@ -181,7 +181,7 @@ end if kind == :svd C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind, - atol=atol) + trunc=(; atol=atol)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind, atol=atol) @test C2 !== C @test Vᴴ2 !== Vᴴ @@ -193,7 +193,7 @@ end @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind, - rtol=rtol) + trunc=(; rtol=rtol)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind, rtol=rtol) @test C2 !== C @test Vᴴ2 !== Vᴴ @@ -205,9 +205,9 @@ end @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I else @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind, - atol=atol) + trunc=(; atol=atol)) @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind, - rtol=rtol) + trunc=(; rtol=rtol)) @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind=kind, atol=atol) @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind=kind, From 45bce81198d13a6dda0f6c371f284ca3e084efc3 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 17 Apr 2025 16:25:38 -0400 Subject: [PATCH 09/14] Update null and docs --- src/implementations/orthnull.jl | 83 +++++++++++++++++-------------- src/implementations/truncation.jl | 4 +- src/interface/orthnull.jl | 65 +++++++++++++----------- test/orthnull.jl | 30 ++++++----- 4 files changed, 103 insertions(+), 79 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 3c2a7856..849b2202 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -156,59 +156,70 @@ end # Implementation of null functions # -------------------------------- -function left_null!(A::AbstractMatrix, N; kwargs...) +function null_truncation_strategy(; atol=nothing, rtol=nothing, maxrank=nothing) + if isnothing(maxrank) && isnothing(atol) && isnothing(rtol) + return NoTruncation() + end + atol = @something atol 0 + rtol = @something rtol 0 + trunc = TruncationKeepBelow(atol, rtol) + return !isnothing(maxrank) ? trunc & truncrank(maxrank; rev=false) : trunc +end + +function left_null!(A::AbstractMatrix, N; trunc=nothing, + kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true), + alg_svd=(;)) check_input(left_null!, A, N) - atol = get(kwargs, :atol, 0) - rtol = get(kwargs, :rtol, 0) - kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :qrpos : :svd) - if !(iszero(atol) && iszero(rtol)) && kind != :svd - throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind")) + if !isnothing(trunc) && kind != :svd + throw(ArgumentError("truncation not supported for left_null with kind=$kind")) end if kind == :qr - alg = get(kwargs, :alg, select_algorithm(qr_null!, A)) - return qr_null!(A, N, alg) - elseif kind == :qrpos - alg = get(kwargs, :alg, select_algorithm(qr_null!, A; positive=true)) - return qr_null!(A, N, alg) - elseif kind == :svd && iszero(atol) && iszero(rtol) - alg = get(kwargs, :alg, select_algorithm(svd_full!, A)) - U, _, _ = svd_full!(A, alg) + alg_qr′ = algorithm_or_select_algorithm(qr_null!, A, alg_qr) + return qr_null!(A, N, alg_qr′) + elseif kind == :svd && isnothing(trunc) + alg_svd′ = algorithm_or_select_algorithm(svd_full!, A, alg_svd) + U, _, _ = svd_full!(A, alg_svd′) + (m, n) = size(A) + return copy!(N, view(U, 1:m, (n + 1):m)) + elseif kind == :svd && isnothing(trunc) + alg_svd′ = algorithm_or_select_algorithm(svd_full!, A, alg_svd) + U, _, _ = svd_full!(A, alg_svd′) (m, n) = size(A) return copy!(N, view(U, 1:m, (n + 1):m)) elseif kind == :svd - alg = get(kwargs, :alg, select_algorithm(svd_full!, A)) - U, S, _ = svd_full!(A, alg) - trunc = TruncationKeepBelow(atol, rtol) - return truncate!(left_null!, (U, S), trunc) + alg_svd′ = algorithm_or_select_algorithm(svd_full!, A, alg_svd) + U, S, _ = svd_full!(A, alg_svd′) + trunc′ = trunc isa TruncationStrategy ? trunc : + trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : + throw(ArgumentError("Unknown truncation strategy: $trunc")) + return truncate!(left_null!, (U, S), trunc′) else throw(ArgumentError("`left_null!` received unknown value `kind = $kind`")) end end -function right_null!(A::AbstractMatrix, Nᴴ; kwargs...) +function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing, + kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true), + alg_svd=(;)) check_input(right_null!, A, Nᴴ) - atol = get(kwargs, :atol, 0) - rtol = get(kwargs, :rtol, 0) - kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :lqpos : :svd) - if !(iszero(atol) && iszero(rtol)) && kind != :svd - throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind")) + if !isnothing(trunc) && kind != :svd + throw(ArgumentError("truncation not supported for right_null with kind=$kind")) end if kind == :lq - alg = get(kwargs, :alg, select_algorithm(lq_null!, A)) - return lq_null!(A, Nᴴ, alg) - elseif kind == :lqpos - alg = get(kwargs, :alg, select_algorithm(lq_null!, A; positive=true)) - return lq_null!(A, Nᴴ, alg) - elseif kind == :svd && iszero(atol) && iszero(rtol) - alg = get(kwargs, :alg, select_algorithm(svd_full!, A)) - _, _, Vᴴ = svd_full!(A, alg) + alg_lq′ = algorithm_or_select_algorithm(lq_null!, A, alg_lq) + return lq_null!(A, Nᴴ, alg_lq′) + elseif kind == :svd && isnothing(trunc) + alg_svd′ = algorithm_or_select_algorithm(svd_full!, A, alg_svd) + _, _, Vᴴ = svd_full!(A, alg_svd′) (m, n) = size(A) return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n)) elseif kind == :svd - alg = get(kwargs, :alg, select_algorithm(svd_full!, A)) - _, S, Vᴴ = svd_full!(A, alg) - trunc = TruncationKeepBelow(atol, rtol) - return truncate!(right_null!, (S, Vᴴ), trunc) + alg_svd′ = algorithm_or_select_algorithm(svd_full!, A, alg_svd) + _, S, Vᴴ = svd_full!(A, alg_svd′) + trunc′ = trunc isa TruncationStrategy ? trunc : + trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : + throw(ArgumentError("Unknown truncation strategy: $trunc")) + return truncate!(right_null!, (S, Vᴴ), trunc′) else throw(ArgumentError("`right_null!` received unknown value `kind = $kind`")) end diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 6088ed7e..0257c8e1 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -69,11 +69,11 @@ TruncationKeepBelow(atol::Real, rtol::Real) = TruncationKeepBelow(promote(atol, # TODO: better names for these functions of the above types """ - truncrank(howmany::Int, by=abs, rev=true) + truncrank(howmany::Int; by=abs, rev=true) Truncation strategy to keep the first `howmany` values when sorted according to `by` or the last `howmany` if `rev` is true. """ -truncrank(howmany::Int, by=abs, rev=true) = TruncationKeepSorted(howmany, by, rev) +truncrank(howmany::Int; by=abs, rev=true) = TruncationKeepSorted(howmany, by, rev) """ trunctol(atol::Real) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index ea2f218f..8c858fbc 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -28,6 +28,9 @@ The keyword argument `kind` can be used to specify the specific orthogonal decom that should be used to factor `A`, whereas `trunc` can be used to control the precision in determining the rank of `A` via its singular values. +`trunc` can either be a truncation strategy object or a NamedTuple with fields +`atol`, `rtol`, and `maxrank`. + This is a high-level wrapper and will use one of the decompositions `qr_compact!`, `svd_compact!`/`svd_trunc!`, and `left_polar!` to compute the orthogonal basis `V`, as controlled by the keyword arguments. @@ -75,8 +78,8 @@ function left_orth(A::AbstractMatrix; kwargs...) end """ - right_orth(A; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> C, Vᴴ - right_orth!(A, [CVᴴ]; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> C, Vᴴ + right_orth(A; [kind::Symbol, trunc, alg_lq, alg_polar, alg_svd]) -> C, Vᴴ + right_orth!(A, [CVᴴ]; [kind::Symbol, trunc, alg_lq, alg_polar, alg_svd]) -> C, Vᴴ Compute an orthonormal basis `V = adjoint(Vᴴ)` for the coimage of the matrix `A`, i.e. for the image of `adjoint(A)`, as well as a matrix `C` such that `A = C * Vᴴ`. @@ -84,6 +87,9 @@ The keyword argument `kind` can be used to specify the specific orthogonal decom that should be used to factor `A`, whereas `trunc` can be used to control the precision in determining the rank of `A` via its singular values. +`trunc` can either be a truncation strategy object or a NamedTuple with fields +`atol`, `rtol`, and `maxrank`. + This is a high-level wrapper and will use call one of the decompositions `lq_compact!`, `svd_compact!`/`svd_trunc!`, and `right_polar!` to compute the orthogonal basis `V`, as controlled by the keyword arguments. @@ -109,7 +115,7 @@ When `kind` is not provided, the default value is `:lq` when `isnothing(trunc)` and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm for backend factorizations through the `alg_lq`, `alg_polar`, and `alg_svd` keyword arguments, which will only be used if the corresponding factorization is called based on the other inputs. -If NamedTuples are passed as `alg_lq`, `alg_polar`, or `alg_svd`, a default algorithm is chosen +If `alg_lq`, `alg_polar`, or `alg_svd` are NamedTuples, a default algorithm is chosen with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. `alg_lq` defaults to `(; positive=true)` so that by default a positive QR decomposition will be used. @@ -133,36 +139,38 @@ end # Null functions # -------------- """ - left_null(A; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> N - left_null!(A, [N]; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> N + left_null(A; [kind::Symbol, trunc, alg_qr, alg_svd]) -> N + left_null!(A, [N]; [kind::Symbol, alg_qr, alg_svd]) -> N Compute an orthonormal basis `N` for the cokernel of the matrix `A` of size `(m, n)`, i.e. the nullspace of `adjoint(A)`, such that `adjoint(A)*N ≈ 0` and `N'*N ≈ I`. The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `atol` and `rtol` can be used to control the -precision in determining the rank of `A` via its singular values. +that should be used to factor `A`, whereas `trunc` can be used to control the +the rank of `A` via its singular values. + +`trunc` can either be a truncation strategy object or a NamedTuple with fields +`atol`, `rtol`, and `maxrank`. This is a high-level wrapper and will use one of the decompositions `qr!` or `svd!` to compute the orthogonal basis `N`, as controlled by the keyword arguments. When `kind` is provided, its possible values are -* `kind == :qrpos`: `N` is computed using the positive QR decomposition. - This requires `iszero(atol) && iszero(rtol)` and `left_null!(A, [N], kind=:qrpos)` is equivalent to +* `kind == :qr`: `N` is computed using the QR decomposition. + This requires `isnothing(trunc)` and `left_null!(A, [N], kind=:qr)` is equivalent to `qr_null!(A, [N], alg)` with a default value `alg = select_algorithm(qr_compact!, A; positive=true)` -* `kind == :qr`: `N` is computed using the (nonpositive) QR decomposition. - This requires `iszero(atol) && iszero(rtol)` and `left_null!(A, [N], kind=:qr)` is equivalent to - `qr_null!(A, [N], alg)` with a default value `alg = select_algorithm(qr_compact!, A)` - * `kind == :svd`: `N` is computed using the singular value decomposition and will contain - the left singular vectors corresponding to the singular values that - are smaller than `max(atol, rtol * σ₁)`, where `σ₁` is the largest singular value of `A`. + the left singular vectors corresponding to either the zero singular values if `trunc` + isn't specified or the singular values specified by `trunc`. -When `kind` is not provided, the default value is `:qrpos` when `iszero(atol) && iszero(rtol)` +When `kind` is not provided, the default value is `:qr` when `isnothing(trunc)` and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -using the `alg` keyword argument, which should be compatible with the chosen or default value -of `kind`. +using the `alg_qr` and `alg_svd` keyword arguments, which will only be used by the corresponding +factorization backend. If `alg_qr` or `alg_svd` are NamedTuples, a default algorithm is chosen +with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. +`alg_qr` defaults to `(; positive=true)` so that by default a positive QR decomposition will +be used. !!! note The bang method `left_null!` optionally accepts the output structure and possibly destroys @@ -181,33 +189,32 @@ function left_null(A::AbstractMatrix; kwargs...) end """ - right_null(A; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> Nᴴ - right_null!(A, [Nᴴ]; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> Nᴴ + right_null(A; [kind::Symbol, alg_lq, alg_svd]) -> Nᴴ + right_null!(A, [Nᴴ]; [kind::Symbol, alg_lq, alg_svd]) -> Nᴴ Compute an orthonormal basis `N = adjoint(Nᴴ)` for the kernel or nullspace of the matrix `A` of size `(m, n)`, such that `A*adjoint(Nᴴ) ≈ 0` and `Nᴴ*adjoint(Nᴴ) ≈ I`. The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `atol` and `rtol` can be used to control the -precision in determining the rank of `A` via its singular values. +that should be used to factor `A`, whereas `trunc` can be used to control the +the rank of `A` via its singular values. + +`trunc` can either be a truncation strategy object or a NamedTuple with fields +`atol`, `rtol`, and `maxrank`. This is a high-level wrapper and will use one of the decompositions `lq!` or `svd!` to compute the orthogonal basis `Nᴴ`, as controlled by the keyword arguments. When `kind` is provided, its possible values are -* `kind == :lqpos`: `Nᴴ` is computed using the positive LQ decomposition. - This requires `iszero(atol) && iszero(rtol)` and `right_null!(A, [Nᴴ], kind=:lqpos)` is equivalent to - `lq_null!(A, [Nᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A; positive=true)` - * `kind == :lq`: `Nᴴ` is computed using the (nonpositive) LQ decomposition. - This requires `iszero(atol) && iszero(rtol)` and `right_null!(A, [Nᴴ], kind=:lq)` is equivalent to - `lq_null!(A, [Nᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A)` + This requires `isnothing(trunc)` and `right_null!(A, [Nᴴ], kind=:lq)` is equivalent to + `lq_null!(A, [Nᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A; positive=true)` * `kind == :svd`: `N` is computed using the singular value decomposition and will contain the left singular vectors corresponding to the singular values that are smaller than `max(atol, rtol * σ₁)`, where `σ₁` is the largest singular value of `A`. -When `kind` is not provided, the default value is `:lqpos` when `iszero(atol) && iszero(rtol)` +When `kind` is not provided, the default value is `:lq` when `isnothing(trunc)` and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm using the `alg` keyword argument, which should be compatible with the chosen or default value of `kind`. diff --git a/test/orthnull.jl b/test/orthnull.jl index 048167e3..16ed0dee 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -36,7 +36,7 @@ using LinearAlgebra: LinearAlgebra, I atol = eps(real(T)) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=(; atol=atol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; atol=atol) + N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=(; atol=atol)) @test V2 !== V @test C2 !== C @test N2 !== C @@ -48,7 +48,7 @@ using LinearAlgebra: LinearAlgebra, I rtol = eps(real(T)) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=(; rtol=rtol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; rtol=rtol) + N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=(; rtol=rtol)) @test V2 !== V @test C2 !== C @test N2 !== C @@ -77,7 +77,8 @@ using LinearAlgebra: LinearAlgebra, I if kind == :svd V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind, trunc=(; atol=atol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind, atol=atol) + N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind, + trunc=(; atol=atol)) @test V2 !== V @test C2 !== C @test N2 !== C @@ -89,7 +90,8 @@ using LinearAlgebra: LinearAlgebra, I V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind, trunc=(; rtol=rtol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind, rtol=rtol) + N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind, + trunc=(; rtol=rtol)) @test V2 !== V @test C2 !== C @test N2 !== C @@ -103,8 +105,10 @@ using LinearAlgebra: LinearAlgebra, I trunc=(; atol=atol)) @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind=kind, trunc=(; rtol=rtol)) - @test_throws ArgumentError left_null!(copy!(Ac, A), N; kind=kind, atol=atol) - @test_throws ArgumentError left_null!(copy!(Ac, A), N; kind=kind, rtol=rtol) + @test_throws ArgumentError left_null!(copy!(Ac, A), N; kind=kind, + trunc=(; atol=atol)) + @test_throws ArgumentError left_null!(copy!(Ac, A), N; kind=kind, + trunc=(; rtol=rtol)) end end end @@ -142,7 +146,7 @@ end atol = eps(real(T)) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; atol=atol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; atol=atol) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc=(; atol=atol)) @test C2 !== C @test Vᴴ2 !== Vᴴ @test Nᴴ2 !== Nᴴ @@ -154,7 +158,7 @@ end rtol = eps(real(T)) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; rtol=rtol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; rtol=rtol) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc=(; rtol=rtol)) @test C2 !== C @test Vᴴ2 !== Vᴴ @test Nᴴ2 !== Nᴴ @@ -182,7 +186,8 @@ end if kind == :svd C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind, trunc=(; atol=atol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind, atol=atol) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind, + trunc=(; atol=atol)) @test C2 !== C @test Vᴴ2 !== Vᴴ @test Nᴴ2 !== Nᴴ @@ -194,7 +199,8 @@ end C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind, trunc=(; rtol=rtol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind, rtol=rtol) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind, + trunc=(; rtol=rtol)) @test C2 !== C @test Vᴴ2 !== Vᴴ @test Nᴴ2 !== Nᴴ @@ -209,9 +215,9 @@ end @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind, trunc=(; rtol=rtol)) @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind=kind, - atol=atol) + trunc=(; atol=atol)) @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind=kind, - rtol=rtol) + trunc=(; rtol=rtol)) end end end From 8fd5a1d930f3a57126504482b23f696e761a561f Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 17 Apr 2025 16:41:19 -0400 Subject: [PATCH 10/14] Add tests --- test/orthnull.jl | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/test/orthnull.jl b/test/orthnull.jl index 16ed0dee..4a772c4c 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -3,6 +3,7 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, I +using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow @testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) @@ -22,6 +23,19 @@ using LinearAlgebra: LinearAlgebra, I @test N' * N ≈ I @test V * V' + N * N' ≈ I + for alg_qr in ((; positive=true), (; positive=false), LAPACK_HouseholderQR()) + V, C = @constinferred left_orth(A; alg_qr) + N = @constinferred left_null(A; alg_qr) + @test V isa Matrix{T} && size(V) == (m, minmn) + @test C isa Matrix{T} && size(C) == (minmn, n) + @test N isa Matrix{T} && size(N) == (m, m - minmn) + @test V * C ≈ A + @test V' * V ≈ I + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test N' * N ≈ I + @test V * V' + N * N' ≈ I + end + Ac = similar(A) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C)) N2 = @constinferred left_null!(copy!(Ac, A), N) @@ -47,16 +61,19 @@ using LinearAlgebra: LinearAlgebra, I @test V2 * V2' + N2 * N2' ≈ I rtol = eps(real(T)) - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=(; rtol=rtol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=(; rtol=rtol)) - @test V2 !== V - @test C2 !== C - @test N2 !== C - @test V2 * C2 ≈ A - @test V2' * V2 ≈ I - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test N2' * N2 ≈ I - @test V2 * V2' + N2 * N2' ≈ I + for (trunc_orth, trunc_null) in (((; rtol=rtol), (; rtol=rtol)), + (TruncationKeepAbove(0, rtol), TruncationKeepBelow(0, rtol))) + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=trunc_orth) + N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=trunc_null) + @test V2 !== V + @test C2 !== C + @test N2 !== C + @test V2 * C2 ≈ A + @test V2' * V2 ≈ I + @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test N2' * N2 ≈ I + @test V2 * V2' + N2 * N2' ≈ I + end for kind in (:qr, :polar, :svd) # explicit kind kwarg m < n && kind == :polar && continue From cc74b85afeff79b4980cede683f0138cf7cd66dc Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 17 Apr 2025 17:19:44 -0400 Subject: [PATCH 11/14] Change maxrank to maxnullity in null methods --- src/implementations/orthnull.jl | 6 +++--- src/interface/orthnull.jl | 4 ++-- test/orthnull.jl | 13 +++++++++++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 849b2202..e46351c7 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -156,14 +156,14 @@ end # Implementation of null functions # -------------------------------- -function null_truncation_strategy(; atol=nothing, rtol=nothing, maxrank=nothing) - if isnothing(maxrank) && isnothing(atol) && isnothing(rtol) +function null_truncation_strategy(; atol=nothing, rtol=nothing, maxnullity=nothing) + if isnothing(maxnullity) && isnothing(atol) && isnothing(rtol) return NoTruncation() end atol = @something atol 0 rtol = @something rtol 0 trunc = TruncationKeepBelow(atol, rtol) - return !isnothing(maxrank) ? trunc & truncrank(maxrank; rev=false) : trunc + return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev=false) : trunc end function left_null!(A::AbstractMatrix, N; trunc=nothing, diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 8c858fbc..6b66fc4c 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -149,7 +149,7 @@ that should be used to factor `A`, whereas `trunc` can be used to control the the rank of `A` via its singular values. `trunc` can either be a truncation strategy object or a NamedTuple with fields -`atol`, `rtol`, and `maxrank`. +`atol`, `rtol`, and `maxnullity`. This is a high-level wrapper and will use one of the decompositions `qr!` or `svd!` to compute the orthogonal basis `N`, as controlled by the keyword arguments. @@ -199,7 +199,7 @@ that should be used to factor `A`, whereas `trunc` can be used to control the the rank of `A` via its singular values. `trunc` can either be a truncation strategy object or a NamedTuple with fields -`atol`, `rtol`, and `maxrank`. +`atol`, `rtol`, and `maxnullity`. This is a high-level wrapper and will use one of the decompositions `lq!` or `svd!` to compute the orthogonal basis `Nᴴ`, as controlled by the keyword arguments. diff --git a/test/orthnull.jl b/test/orthnull.jl index 4a772c4c..0823d3c5 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -23,6 +23,19 @@ using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow @test N' * N ≈ I @test V * V' + N * N' ≈ I + if m > n + nullity = 5 + V, C = @constinferred left_orth(A) + N = @constinferred left_null(A; trunc=(; maxnullity=nullity)) + @test V isa Matrix{T} && size(V) == (m, minmn) + @test C isa Matrix{T} && size(C) == (minmn, n) + @test N isa Matrix{T} && size(N) == (m, nullity) + @test V * C ≈ A + @test V' * V ≈ I + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test N' * N ≈ I + end + for alg_qr in ((; positive=true), (; positive=false), LAPACK_HouseholderQR()) V, C = @constinferred left_orth(A; alg_qr) N = @constinferred left_null(A; alg_qr) From 022ab9a60a7a9a6188605c836c14b2d36b1b6c6b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 24 Apr 2025 12:46:10 -0400 Subject: [PATCH 12/14] Apply suggestions --- src/implementations/orthnull.jl | 5 ----- src/interface/orthnull.jl | 28 +++++++++++++++------------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index e46351c7..e19ca04d 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -181,11 +181,6 @@ function left_null!(A::AbstractMatrix, N; trunc=nothing, U, _, _ = svd_full!(A, alg_svd′) (m, n) = size(A) return copy!(N, view(U, 1:m, (n + 1):m)) - elseif kind == :svd && isnothing(trunc) - alg_svd′ = algorithm_or_select_algorithm(svd_full!, A, alg_svd) - U, _, _ = svd_full!(A, alg_svd′) - (m, n) = size(A) - return copy!(N, view(U, 1:m, (n + 1):m)) elseif kind == :svd alg_svd′ = algorithm_or_select_algorithm(svd_full!, A, alg_svd) U, S, _ = svd_full!(A, alg_svd′) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 6b66fc4c..9d004c9a 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -32,8 +32,8 @@ precision in determining the rank of `A` via its singular values. `atol`, `rtol`, and `maxrank`. This is a high-level wrapper and will use one of the decompositions -`qr_compact!`, `svd_compact!`/`svd_trunc!`, and `left_polar!` to compute the orthogonal basis `V`, as controlled -by the keyword arguments. +[`qr_compact!`](@ref), [`svd_compact!`](@ref)/[`svd_trunc!`](@ref), and[`left_polar!`](@ref) +to compute the orthogonal basis `V`, as controlled by the keyword arguments. When `kind` is provided, its possible values are @@ -45,9 +45,8 @@ When `kind` is provided, its possible values are This requires `isnothing(trunc)` and `left_orth!(A, [VC])` is equivalent to `left_polar!(A, [VC], alg)` with a default value `alg = select_algorithm(left_polar!, A)` -* `kind == :svd`: `V` and `C` are computed using the singular value decomposition `svd_compact!` - if no truncation is specified through the `trunc` keyword argument or `svd_trunc!` - if truncation is specified through the `trunc` keyword argument. +* `kind == :svd`: `V` and `C` are computed using the singular value decomposition `svd_trunc!` when a + truncation strategy is specified using the `trunc` keyword argument, and using `svd_compact!` otherwise. `V` will contain the left singular vectors and `C` is computed as the product of the singular values and the right singular vectors, i.e. with `U, S, Vᴴ = svd(A)`, we have `V = U` and `C = S * Vᴴ`. @@ -91,8 +90,9 @@ precision in determining the rank of `A` via its singular values. `atol`, `rtol`, and `maxrank`. This is a high-level wrapper and will use call one of the decompositions -`lq_compact!`, `svd_compact!`/`svd_trunc!`, and `right_polar!` to compute the -orthogonal basis `V`, as controlled by the keyword arguments. +[`lq_compact!`](@ref), [`svd_compact!`](@ref)/[`svd_trunc!`](@ref), and +[`right_polar!`](@ref) to compute the orthogonal basis `V`, as controlled by the +keyword arguments. When `kind` is provided, its possible values are @@ -104,9 +104,8 @@ When `kind` is provided, its possible values are This requires `isnothing(trunc)` and `right_orth!(A, [CVᴴ])` is equivalent to `right_polar!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(right_polar!, A)` -* `kind == :svd`: `C` and `Vᴴ` are computed using the singular value decomposition `svd_compact!` - if no truncation is specified through the `trunc` keyword argument or `svd_trunc!` - if truncation is specified through the `trunc` keyword argument. +* `kind == :svd`: `C` and `Vᴴ` are computed using the singular value decomposition `svd_trunc!` when + a truncation strategy is specified using the `trunc` keyword argument, and using `svd_compact!` otherwise. `V = adjoint(Vᴴ)` will contain the right singular vectors corresponding to the singular values and `C` is computed as the product of the singular values and the right singular vectors, i.e. with `U, S, Vᴴ = svd(A)`, we have `C = rmul!(U, S)` and `Vᴴ = Vᴴ`. @@ -117,7 +116,7 @@ for backend factorizations through the `alg_lq`, `alg_polar`, and `alg_svd` keyw which will only be used if the corresponding factorization is called based on the other inputs. If `alg_lq`, `alg_polar`, or `alg_svd` are NamedTuples, a default algorithm is chosen with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. -`alg_lq` defaults to `(; positive=true)` so that by default a positive QR decomposition will +`alg_lq` defaults to `(; positive=true)` so that by default a positive LQ decomposition will be used. !!! note @@ -216,8 +215,11 @@ When `kind` is provided, its possible values are When `kind` is not provided, the default value is `:lq` when `isnothing(trunc)` and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -using the `alg` keyword argument, which should be compatible with the chosen or default value -of `kind`. +using the `alg_lq` and `alg_svd` keyword arguments, which will only be used by the corresponding +factorization backend. If `alg_lq` or `alg_svd` are NamedTuples, a default algorithm is chosen +with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. +`alg_lq` defaults to `(; positive=true)` so that by default a positive LQ decomposition will +be used. !!! note The bang method `right_null!` optionally accepts the output structure and possibly destroys From 5892481e618c7254953e16e6b2d0f2b74c6adddb Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 24 Apr 2025 15:08:27 -0400 Subject: [PATCH 13/14] Fix typo in docstring Co-authored-by: Lukas Devos --- src/interface/orthnull.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 9d004c9a..4b00ec66 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -89,7 +89,7 @@ precision in determining the rank of `A` via its singular values. `trunc` can either be a truncation strategy object or a NamedTuple with fields `atol`, `rtol`, and `maxrank`. -This is a high-level wrapper and will use call one of the decompositions +This is a high-level wrapper and will use one of the decompositions [`lq_compact!`](@ref), [`svd_compact!`](@ref)/[`svd_trunc!`](@ref), and [`right_polar!`](@ref) to compute the orthogonal basis `V`, as controlled by the keyword arguments. From e14e9b2e1b8dfa7446304d68769e726f781cf156 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 24 Apr 2025 15:27:58 -0400 Subject: [PATCH 14/14] Move algorithm selection logic (#1) --- src/algorithms.jl | 7 +++++++ src/implementations/orthnull.jl | 35 +++++++++++++-------------------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index fc247317..34a3f3d1 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -61,6 +61,13 @@ implementing the function `f` on inputs of type `A`. """ function select_algorithm end +function _select_algorithm(f, A::AbstractMatrix, alg::AbstractAlgorithm) + return alg +end +function _select_algorithm(f, A::AbstractMatrix, alg::NamedTuple) + return select_algorithm(f, A; alg...) +end + @doc """ copy_input(f, A) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index e19ca04d..2680e7cc 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -79,13 +79,6 @@ function initialize_output(::typeof(right_null!), A::AbstractMatrix) return Nᴴ end -function algorithm_or_select_algorithm(f, A::AbstractMatrix, alg::AbstractAlgorithm) - return alg -end -function algorithm_or_select_algorithm(f, A::AbstractMatrix, kwargs::NamedTuple) - return select_algorithm(f, A; kwargs...) -end - # Implementation of orth functions # -------------------------------- function left_orth!(A::AbstractMatrix, VC; trunc=nothing, @@ -96,21 +89,21 @@ function left_orth!(A::AbstractMatrix, VC; trunc=nothing, throw(ArgumentError("truncation not supported for left_orth with kind=$kind")) end if kind == :qr - alg_qr′ = algorithm_or_select_algorithm(qr_compact!, A, alg_qr) + alg_qr′ = _select_algorithm(qr_compact!, A, alg_qr) return qr_compact!(A, VC, alg_qr′) elseif kind == :polar size(A, 1) >= size(A, 2) || throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`")) - alg_polar′ = algorithm_or_select_algorithm(left_polar!, A, alg_polar) + alg_polar′ = _select_algorithm(left_polar!, A, alg_polar) return left_polar!(A, VC, alg_polar′) elseif kind == :svd && isnothing(trunc) - alg_svd′ = algorithm_or_select_algorithm(svd_compact!, A, alg_svd) + alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) V, C = VC S = Diagonal(initialize_output(svd_vals!, A, alg_svd′)) U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd′) return U, lmul!(S, Vᴴ) elseif kind == :svd - alg_svd′ = algorithm_or_select_algorithm(svd_compact!, A, alg_svd) + alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′) V, C = VC S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg)) @@ -129,21 +122,21 @@ function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing, throw(ArgumentError("truncation not supported for right_orth with kind=$kind")) end if kind == :lq - alg_lq′ = algorithm_or_select_algorithm(lq_compact!, A, alg_lq) + alg_lq′ = _select_algorithm(lq_compact!, A, alg_lq) return lq_compact!(A, CVᴴ, alg_lq′) elseif kind == :polar size(A, 2) >= size(A, 1) || throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`")) - alg_polar′ = algorithm_or_select_algorithm(right_polar!, A, alg_polar) + alg_polar′ = _select_algorithm(right_polar!, A, alg_polar) return right_polar!(A, CVᴴ, alg_polar′) elseif kind == :svd && isnothing(trunc) - alg_svd′ = algorithm_or_select_algorithm(svd_compact!, A, alg_svd) + alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) C, Vᴴ = CVᴴ S = Diagonal(initialize_output(svd_vals!, A, alg_svd′)) U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′) return rmul!(U, S), Vᴴ elseif kind == :svd - alg_svd′ = algorithm_or_select_algorithm(svd_compact!, A, alg_svd) + alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′) C, Vᴴ = CVᴴ S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg)) @@ -174,15 +167,15 @@ function left_null!(A::AbstractMatrix, N; trunc=nothing, throw(ArgumentError("truncation not supported for left_null with kind=$kind")) end if kind == :qr - alg_qr′ = algorithm_or_select_algorithm(qr_null!, A, alg_qr) + alg_qr′ = _select_algorithm(qr_null!, A, alg_qr) return qr_null!(A, N, alg_qr′) elseif kind == :svd && isnothing(trunc) - alg_svd′ = algorithm_or_select_algorithm(svd_full!, A, alg_svd) + alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) U, _, _ = svd_full!(A, alg_svd′) (m, n) = size(A) return copy!(N, view(U, 1:m, (n + 1):m)) elseif kind == :svd - alg_svd′ = algorithm_or_select_algorithm(svd_full!, A, alg_svd) + alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) U, S, _ = svd_full!(A, alg_svd′) trunc′ = trunc isa TruncationStrategy ? trunc : trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : @@ -201,15 +194,15 @@ function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing, throw(ArgumentError("truncation not supported for right_null with kind=$kind")) end if kind == :lq - alg_lq′ = algorithm_or_select_algorithm(lq_null!, A, alg_lq) + alg_lq′ = _select_algorithm(lq_null!, A, alg_lq) return lq_null!(A, Nᴴ, alg_lq′) elseif kind == :svd && isnothing(trunc) - alg_svd′ = algorithm_or_select_algorithm(svd_full!, A, alg_svd) + alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) _, _, Vᴴ = svd_full!(A, alg_svd′) (m, n) = size(A) return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n)) elseif kind == :svd - alg_svd′ = algorithm_or_select_algorithm(svd_full!, A, alg_svd) + alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) _, S, Vᴴ = svd_full!(A, alg_svd′) trunc′ = trunc isa TruncationStrategy ? trunc : trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :