diff --git a/src/algorithms.jl b/src/algorithms.jl index fc247317..34a3f3d1 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -61,6 +61,13 @@ implementing the function `f` on inputs of type `A`. """ function select_algorithm end +function _select_algorithm(f, A::AbstractMatrix, alg::AbstractAlgorithm) + return alg +end +function _select_algorithm(f, A::AbstractMatrix, alg::NamedTuple) + return select_algorithm(f, A; alg...) +end + @doc """ copy_input(f, A) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index c34df721..2680e7cc 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -81,76 +81,66 @@ end # Implementation of orth functions # -------------------------------- -function left_orth!(A::AbstractMatrix, VC; kwargs...) +function left_orth!(A::AbstractMatrix, VC; trunc=nothing, + kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true), + alg_polar=(;), alg_svd=(;)) check_input(left_orth!, A, VC) - atol = get(kwargs, :atol, 0) - rtol = get(kwargs, :rtol, 0) - kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :qrpos : :svd) - if !(iszero(atol) && iszero(rtol)) && kind != :svd - throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind")) + if !isnothing(trunc) && kind != :svd + throw(ArgumentError("truncation not supported for left_orth with kind=$kind")) end if kind == :qr - alg = get(kwargs, :alg, select_algorithm(qr_compact!, A)) - return qr_compact!(A, VC, alg) - elseif kind == :qrpos - alg = get(kwargs, :alg, select_algorithm(qr_compact!, A; positive=true)) - return qr_compact!(A, VC, alg) + alg_qr′ = _select_algorithm(qr_compact!, A, alg_qr) + return qr_compact!(A, VC, alg_qr′) elseif kind == :polar size(A, 1) >= size(A, 2) || throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`")) - alg = get(kwargs, :alg, select_algorithm(left_polar!, A)) - return left_polar!(A, VC, alg) - elseif kind == :svd && iszero(atol) && iszero(rtol) - alg = get(kwargs, :alg, select_algorithm(svd_compact!, A)) + alg_polar′ = _select_algorithm(left_polar!, A, alg_polar) + return left_polar!(A, VC, alg_polar′) + elseif kind == :svd && isnothing(trunc) + alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) V, C = VC - S = Diagonal(initialize_output(svd_vals!, A, alg)) - U, S, Vᴴ = svd_compact!(A, (V, S, C), alg) + S = Diagonal(initialize_output(svd_vals!, A, alg_svd′)) + U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd′) return U, lmul!(S, Vᴴ) elseif kind == :svd - alg_svd = select_algorithm(svd_compact!, A) - trunc = TruncationKeepAbove(atol, rtol) - alg = get(kwargs, :alg, TruncatedAlgorithm(alg_svd, trunc)) + alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) + alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′) V, C = VC - S = Diagonal(initialize_output(svd_vals!, A, alg_svd)) - U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg) + S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg)) + U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_svd_trunc) return U, lmul!(S, Vᴴ) else throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`")) end end -function right_orth!(A::AbstractMatrix, CVᴴ; kwargs...) +function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing, + kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true), + alg_polar=(;), alg_svd=(;)) check_input(right_orth!, A, CVᴴ) - atol = get(kwargs, :atol, 0) - rtol = get(kwargs, :rtol, 0) - kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :lqpos : :svd) - if !(iszero(atol) && iszero(rtol)) && kind != :svd - throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind")) + if !isnothing(trunc) && kind != :svd + throw(ArgumentError("truncation not supported for right_orth with kind=$kind")) end if kind == :lq - alg = get(kwargs, :alg, select_algorithm(lq_compact!, A)) - return lq_compact!(A, CVᴴ, alg) - elseif kind == :lqpos - alg = get(kwargs, :alg, select_algorithm(lq_compact!, A; positive=true)) - return lq_compact!(A, CVᴴ, alg) + alg_lq′ = _select_algorithm(lq_compact!, A, alg_lq) + return lq_compact!(A, CVᴴ, alg_lq′) elseif kind == :polar size(A, 2) >= size(A, 1) || throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`")) - alg = get(kwargs, :alg, select_algorithm(right_polar!, A)) - return right_polar!(A, CVᴴ, alg) - elseif kind == :svd && iszero(atol) && iszero(rtol) - alg = get(kwargs, :alg, select_algorithm(svd_compact!, A)) + alg_polar′ = _select_algorithm(right_polar!, A, alg_polar) + return right_polar!(A, CVᴴ, alg_polar′) + elseif kind == :svd && isnothing(trunc) + alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) C, Vᴴ = CVᴴ - S = Diagonal(initialize_output(svd_vals!, A, alg)) - U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg) + S = Diagonal(initialize_output(svd_vals!, A, alg_svd′)) + U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′) return rmul!(U, S), Vᴴ elseif kind == :svd - alg_svd = select_algorithm(svd_compact!, A) - trunc = TruncationKeepAbove(atol, rtol) - alg = get(kwargs, :alg, TruncatedAlgorithm(alg_svd, trunc)) + alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd) + alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′) C, Vᴴ = CVᴴ - S = Diagonal(initialize_output(svd_vals!, A, alg_svd)) - U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg) + S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg)) + U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_svd_trunc) return rmul!(U, S), Vᴴ else throw(ArgumentError("`right_orth!` received unknown value `kind = $kind`")) @@ -159,59 +149,65 @@ end # Implementation of null functions # -------------------------------- -function left_null!(A::AbstractMatrix, N; kwargs...) +function null_truncation_strategy(; atol=nothing, rtol=nothing, maxnullity=nothing) + if isnothing(maxnullity) && isnothing(atol) && isnothing(rtol) + return NoTruncation() + end + atol = @something atol 0 + rtol = @something rtol 0 + trunc = TruncationKeepBelow(atol, rtol) + return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev=false) : trunc +end + +function left_null!(A::AbstractMatrix, N; trunc=nothing, + kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true), + alg_svd=(;)) check_input(left_null!, A, N) - atol = get(kwargs, :atol, 0) - rtol = get(kwargs, :rtol, 0) - kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :qrpos : :svd) - if !(iszero(atol) && iszero(rtol)) && kind != :svd - throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind")) + if !isnothing(trunc) && kind != :svd + throw(ArgumentError("truncation not supported for left_null with kind=$kind")) end if kind == :qr - alg = get(kwargs, :alg, select_algorithm(qr_null!, A)) - return qr_null!(A, N, alg) - elseif kind == :qrpos - alg = get(kwargs, :alg, select_algorithm(qr_null!, A; positive=true)) - return qr_null!(A, N, alg) - elseif kind == :svd && iszero(atol) && iszero(rtol) - alg = get(kwargs, :alg, select_algorithm(svd_full!, A)) - U, _, _ = svd_full!(A, alg) + alg_qr′ = _select_algorithm(qr_null!, A, alg_qr) + return qr_null!(A, N, alg_qr′) + elseif kind == :svd && isnothing(trunc) + alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) + U, _, _ = svd_full!(A, alg_svd′) (m, n) = size(A) return copy!(N, view(U, 1:m, (n + 1):m)) elseif kind == :svd - alg = get(kwargs, :alg, select_algorithm(svd_full!, A)) - U, S, _ = svd_full!(A, alg) - trunc = TruncationKeepBelow(atol, rtol) - return truncate!(left_null!, (U, S), trunc) + alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) + U, S, _ = svd_full!(A, alg_svd′) + trunc′ = trunc isa TruncationStrategy ? trunc : + trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : + throw(ArgumentError("Unknown truncation strategy: $trunc")) + return truncate!(left_null!, (U, S), trunc′) else throw(ArgumentError("`left_null!` received unknown value `kind = $kind`")) end end -function right_null!(A::AbstractMatrix, Nᴴ; kwargs...) +function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing, + kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true), + alg_svd=(;)) check_input(right_null!, A, Nᴴ) - atol = get(kwargs, :atol, 0) - rtol = get(kwargs, :rtol, 0) - kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :lqpos : :svd) - if !(iszero(atol) && iszero(rtol)) && kind != :svd - throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind")) + if !isnothing(trunc) && kind != :svd + throw(ArgumentError("truncation not supported for right_null with kind=$kind")) end if kind == :lq - alg = get(kwargs, :alg, select_algorithm(lq_null!, A)) - return lq_null!(A, Nᴴ, alg) - elseif kind == :lqpos - alg = get(kwargs, :alg, select_algorithm(lq_null!, A; positive=true)) - return lq_null!(A, Nᴴ, alg) - elseif kind == :svd && iszero(atol) && iszero(rtol) - alg = get(kwargs, :alg, select_algorithm(svd_full!, A)) - _, _, Vᴴ = svd_full!(A, alg) + alg_lq′ = _select_algorithm(lq_null!, A, alg_lq) + return lq_null!(A, Nᴴ, alg_lq′) + elseif kind == :svd && isnothing(trunc) + alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) + _, _, Vᴴ = svd_full!(A, alg_svd′) (m, n) = size(A) return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n)) elseif kind == :svd - alg = get(kwargs, :alg, select_algorithm(svd_full!, A)) - _, S, Vᴴ = svd_full!(A, alg) - trunc = TruncationKeepBelow(atol, rtol) - return truncate!(right_null!, (S, Vᴴ), trunc) + alg_svd′ = _select_algorithm(svd_full!, A, alg_svd) + _, S, Vᴴ = svd_full!(A, alg_svd′) + trunc′ = trunc isa TruncationStrategy ? trunc : + trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : + throw(ArgumentError("Unknown truncation strategy: $trunc")) + return truncate!(right_null!, (S, Vᴴ), trunc′) else throw(ArgumentError("`right_null!` received unknown value `kind = $kind`")) end diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 236184bd..4b00ec66 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -19,43 +19,46 @@ end # Orth functions # -------------- """ - left_orth(A; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> V, C - left_orth!(A, [VC]; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> V, C + left_orth(A; [kind::Symbol, trunc, alg_qr, alg_polar, alg_svd]) -> V, C + left_orth!(A, [VC]; [kind::Symbol, trunc, alg_qr, alg_polar, alg_svd]) -> V, C Compute an orthonormal basis `V` for the image of the matrix `A` of size `(m, n)`, as well as a matrix `C` (the corestriction) such that `A` factors as `A = V * C`. The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `atol` and `rtol` can be used to control the +that should be used to factor `A`, whereas `trunc` can be used to control the precision in determining the rank of `A` via its singular values. +`trunc` can either be a truncation strategy object or a NamedTuple with fields +`atol`, `rtol`, and `maxrank`. + This is a high-level wrapper and will use one of the decompositions -`qr!`, `svd!`, and `left_polar!` to compute the orthogonal basis `V`, as controlled -by the keyword arguments. +[`qr_compact!`](@ref), [`svd_compact!`](@ref)/[`svd_trunc!`](@ref), and[`left_polar!`](@ref) +to compute the orthogonal basis `V`, as controlled by the keyword arguments. When `kind` is provided, its possible values are -* `kind == :qrpos`: `V` and `C` are computed using the positive QR decomposition. - This requires `iszero(atol) && iszero(rtol)` and `left_orth!(A, [VC])` is equivalent to +* `kind == :qr`: `V` and `C` are computed using the QR decomposition. + This requires `isnothing(trunc)` and `left_orth!(A, [VC])` is equivalent to `qr_compact!(A, [VC], alg)` with a default value `alg = select_algorithm(qr_compact!, A; positive=true)` -* `kind == :qr`: `V` and `C` are computed using the QR decomposition, - This requires `iszero(atol) && iszero(rtol)` and `left_orth!(A, [VC])` is equivalent to - `qr_compact!(A, [VC], alg)` with a default value `alg = select_algorithm(qr_compact!, A)` - * `kind == :polar`: `V` and `C` are computed using the polar decomposition, - This requires `iszero(atol) && iszero(rtol)` and `left_orth!(A, [VC])` is equivalent to + This requires `isnothing(trunc)` and `left_orth!(A, [VC])` is equivalent to `left_polar!(A, [VC], alg)` with a default value `alg = select_algorithm(left_polar!, A)` -* `kind == :svd`: `V` and `C` are computed using the singular value decomposition `svd_trunc!`, - where `V` will contain the left singular vectors corresponding to the singular values that - are larger than `max(atol, rtol * σ₁)`, where `σ₁` is the largest singular value of `A`. - `C` is computed as the product of the singular values and the right singular vectors, - i.e. with `U, S, Vᴴ = svd_trunc!(A)`, we have `V = U` and `C = S * Vᴴ`. +* `kind == :svd`: `V` and `C` are computed using the singular value decomposition `svd_trunc!` when a + truncation strategy is specified using the `trunc` keyword argument, and using `svd_compact!` otherwise. + `V` will contain the left singular vectors and `C` is computed as the product of the singular + values and the right singular vectors, i.e. with `U, S, Vᴴ = svd(A)`, we have + `V = U` and `C = S * Vᴴ`. -When `kind` is not provided, the default value is `:qrpos` when `iszero(atol) && iszero(rtol)` +When `kind` is not provided, the default value is `:qr` when `isnothing(trunc)` and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -using the `alg` keyword argument, which should be compatible with the chosen or default value -of `kind`. +for backend factorizations through the `alg_qr`, `alg_polar`, and `alg_svd` keyword arguments, +which will only be used if the corresponding factorization is called based on the other inputs. +If NamedTuples are passed as `alg_qr`, `alg_polar`, or `alg_svd`, a default algorithm is chosen +with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. +`alg_qr` defaults to `(; positive=true)` so that by default a positive QR decomposition will +be used. !!! note The bang method `left_orth!` optionally accepts the output structure and possibly destroys @@ -74,43 +77,47 @@ function left_orth(A::AbstractMatrix; kwargs...) end """ - right_orth(A; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> C, Vᴴ - right_orth!(A, [CVᴴ]; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> C, Vᴴ + right_orth(A; [kind::Symbol, trunc, alg_lq, alg_polar, alg_svd]) -> C, Vᴴ + right_orth!(A, [CVᴴ]; [kind::Symbol, trunc, alg_lq, alg_polar, alg_svd]) -> C, Vᴴ Compute an orthonormal basis `V = adjoint(Vᴴ)` for the coimage of the matrix `A`, i.e. for the image of `adjoint(A)`, as well as a matrix `C` such that `A = C * Vᴴ`. The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `atol` and `rtol` can be used to control the +that should be used to factor `A`, whereas `trunc` can be used to control the precision in determining the rank of `A` via its singular values. -This is a high-level wrapper and will use call one of the decompositions -`qr!`, `svd!`, and `left_polar!` to compute the orthogonal basis `V`, as controlled -by the keyword arguments. +`trunc` can either be a truncation strategy object or a NamedTuple with fields +`atol`, `rtol`, and `maxrank`. -When `kind` is provided, its possible values are +This is a high-level wrapper and will use one of the decompositions +[`lq_compact!`](@ref), [`svd_compact!`](@ref)/[`svd_trunc!`](@ref), and +[`right_polar!`](@ref) to compute the orthogonal basis `V`, as controlled by the +keyword arguments. -* `kind == :lqpos`: `C` and `Vᴴ` are computed using the positive QR decomposition. - This requires `iszero(atol) && iszero(rtol)` and `right_orth!(A, [CVᴴ])` is equivalent to - `lq_compact!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A; positive=true)` +When `kind` is provided, its possible values are * `kind == :lq`: `C` and `Vᴴ` are computed using the QR decomposition, - This requires `iszero(atol) && iszero(rtol)` and `right_orth!(A, [CVᴴ])` is equivalent to - `lq_compact!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A))` + This requires `isnothing(trunc)` and `right_orth!(A, [CVᴴ])` is equivalent to + `lq_compact!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A; positive=true)` * `kind == :polar`: `C` and `Vᴴ` are computed using the polar decomposition, - This requires `iszero(atol) && iszero(rtol)` and `right_orth!(A, [CVᴴ])` is equivalent to - `right_polar!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(right_polar!, A))` + This requires `isnothing(trunc)` and `right_orth!(A, [CVᴴ])` is equivalent to + `right_polar!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(right_polar!, A)` -* `kind == :svd`: `C` and `Vᴴ` are computed using the singular value decomposition `svd_trunc!`, - where `V = adjoint(Vᴴ)` will contain the right singular vectors corresponding to the singular - values that are larger than `max(atol, rtol * σ₁)`, where `σ₁` is the largest singular value of `A`. - `C` is computed as the product of the singular values and the right singular vectors, - i.e. with `U, S, Vᴴ = svd_trunc!(A)`, we have `C = rmul!(U, S)` and `Vᴴ = Vᴴ`. +* `kind == :svd`: `C` and `Vᴴ` are computed using the singular value decomposition `svd_trunc!` when + a truncation strategy is specified using the `trunc` keyword argument, and using `svd_compact!` otherwise. + `V = adjoint(Vᴴ)` will contain the right singular vectors corresponding to the singular + values and `C` is computed as the product of the singular values and the right singular vectors, + i.e. with `U, S, Vᴴ = svd(A)`, we have `C = rmul!(U, S)` and `Vᴴ = Vᴴ`. -When `kind` is not provided, the default value is `:lqpos` when `iszero(atol) && iszero(rtol)` +When `kind` is not provided, the default value is `:lq` when `isnothing(trunc)` and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -using the `alg` keyword argument, which should be compatible with the chosen or default value -of `kind`. +for backend factorizations through the `alg_lq`, `alg_polar`, and `alg_svd` keyword arguments, +which will only be used if the corresponding factorization is called based on the other inputs. +If `alg_lq`, `alg_polar`, or `alg_svd` are NamedTuples, a default algorithm is chosen +with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. +`alg_lq` defaults to `(; positive=true)` so that by default a positive LQ decomposition will +be used. !!! note The bang method `right_orth!` optionally accepts the output structure and possibly destroys @@ -131,36 +138,38 @@ end # Null functions # -------------- """ - left_null(A; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> N - left_null!(A, [N]; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> N + left_null(A; [kind::Symbol, trunc, alg_qr, alg_svd]) -> N + left_null!(A, [N]; [kind::Symbol, alg_qr, alg_svd]) -> N Compute an orthonormal basis `N` for the cokernel of the matrix `A` of size `(m, n)`, i.e. the nullspace of `adjoint(A)`, such that `adjoint(A)*N ≈ 0` and `N'*N ≈ I`. The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `atol` and `rtol` can be used to control the -precision in determining the rank of `A` via its singular values. +that should be used to factor `A`, whereas `trunc` can be used to control the +the rank of `A` via its singular values. + +`trunc` can either be a truncation strategy object or a NamedTuple with fields +`atol`, `rtol`, and `maxnullity`. This is a high-level wrapper and will use one of the decompositions `qr!` or `svd!` to compute the orthogonal basis `N`, as controlled by the keyword arguments. When `kind` is provided, its possible values are -* `kind == :qrpos`: `N` is computed using the positive QR decomposition. - This requires `iszero(atol) && iszero(rtol)` and `left_null!(A, [N], kind=:qrpos)` is equivalent to +* `kind == :qr`: `N` is computed using the QR decomposition. + This requires `isnothing(trunc)` and `left_null!(A, [N], kind=:qr)` is equivalent to `qr_null!(A, [N], alg)` with a default value `alg = select_algorithm(qr_compact!, A; positive=true)` -* `kind == :qr`: `N` is computed using the (nonpositive) QR decomposition. - This requires `iszero(atol) && iszero(rtol)` and `left_null!(A, [N], kind=:qr)` is equivalent to - `qr_null!(A, [N], alg)` with a default value `alg = select_algorithm(qr_compact!, A)` - * `kind == :svd`: `N` is computed using the singular value decomposition and will contain - the left singular vectors corresponding to the singular values that - are smaller than `max(atol, rtol * σ₁)`, where `σ₁` is the largest singular value of `A`. + the left singular vectors corresponding to either the zero singular values if `trunc` + isn't specified or the singular values specified by `trunc`. -When `kind` is not provided, the default value is `:qrpos` when `iszero(atol) && iszero(rtol)` +When `kind` is not provided, the default value is `:qr` when `isnothing(trunc)` and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -using the `alg` keyword argument, which should be compatible with the chosen or default value -of `kind`. +using the `alg_qr` and `alg_svd` keyword arguments, which will only be used by the corresponding +factorization backend. If `alg_qr` or `alg_svd` are NamedTuples, a default algorithm is chosen +with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. +`alg_qr` defaults to `(; positive=true)` so that by default a positive QR decomposition will +be used. !!! note The bang method `left_null!` optionally accepts the output structure and possibly destroys @@ -179,36 +188,38 @@ function left_null(A::AbstractMatrix; kwargs...) end """ - right_null(A; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> Nᴴ - right_null!(A, [Nᴴ]; [kind::Symbol, atol::Real=0, rtol::Real=0, alg]) -> Nᴴ + right_null(A; [kind::Symbol, alg_lq, alg_svd]) -> Nᴴ + right_null!(A, [Nᴴ]; [kind::Symbol, alg_lq, alg_svd]) -> Nᴴ Compute an orthonormal basis `N = adjoint(Nᴴ)` for the kernel or nullspace of the matrix `A` of size `(m, n)`, such that `A*adjoint(Nᴴ) ≈ 0` and `Nᴴ*adjoint(Nᴴ) ≈ I`. The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `atol` and `rtol` can be used to control the -precision in determining the rank of `A` via its singular values. +that should be used to factor `A`, whereas `trunc` can be used to control the +the rank of `A` via its singular values. + +`trunc` can either be a truncation strategy object or a NamedTuple with fields +`atol`, `rtol`, and `maxnullity`. This is a high-level wrapper and will use one of the decompositions `lq!` or `svd!` to compute the orthogonal basis `Nᴴ`, as controlled by the keyword arguments. When `kind` is provided, its possible values are -* `kind == :lqpos`: `Nᴴ` is computed using the positive LQ decomposition. - This requires `iszero(atol) && iszero(rtol)` and `right_null!(A, [Nᴴ], kind=:lqpos)` is equivalent to - `lq_null!(A, [Nᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A; positive=true)` - * `kind == :lq`: `Nᴴ` is computed using the (nonpositive) LQ decomposition. - This requires `iszero(atol) && iszero(rtol)` and `right_null!(A, [Nᴴ], kind=:lq)` is equivalent to - `lq_null!(A, [Nᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A)` + This requires `isnothing(trunc)` and `right_null!(A, [Nᴴ], kind=:lq)` is equivalent to + `lq_null!(A, [Nᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A; positive=true)` * `kind == :svd`: `N` is computed using the singular value decomposition and will contain the left singular vectors corresponding to the singular values that are smaller than `max(atol, rtol * σ₁)`, where `σ₁` is the largest singular value of `A`. -When `kind` is not provided, the default value is `:lqpos` when `iszero(atol) && iszero(rtol)` +When `kind` is not provided, the default value is `:lq` when `isnothing(trunc)` and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -using the `alg` keyword argument, which should be compatible with the chosen or default value -of `kind`. +using the `alg_lq` and `alg_svd` keyword arguments, which will only be used by the corresponding +factorization backend. If `alg_lq` or `alg_svd` are NamedTuples, a default algorithm is chosen +with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. +`alg_lq` defaults to `(; positive=true)` so that by default a positive LQ decomposition will +be used. !!! note The bang method `right_null!` optionally accepts the output structure and possibly destroys diff --git a/test/chainrules.jl b/test/chainrules.jl index ab895a4a..a48d2aae 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -338,26 +338,26 @@ end config = Zygote.ZygoteRuleConfig() test_rrule(config, left_orth, A; atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) - test_rrule(config, left_orth, A; fkwargs=(; kind=:qrpos), + test_rrule(config, left_orth, A; fkwargs=(; kind=:qr), atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) m >= n && test_rrule(config, left_orth, A; fkwargs=(; kind=:polar), atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) - ΔN = left_orth(A; kind=:qrpos)[1] * randn(rng, T, min(m, n), m - min(m, n)) - test_rrule(config, left_null, A; fkwargs=(; kind=:qrpos), output_tangent=ΔN, + ΔN = left_orth(A; kind=:qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) + test_rrule(config, left_null, A; fkwargs=(; kind=:qr), output_tangent=ΔN, atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) test_rrule(config, right_orth, A; atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) - test_rrule(config, right_orth, A; fkwargs=(; kind=:lqpos), + test_rrule(config, right_orth, A; fkwargs=(; kind=:lq), atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) m <= n && test_rrule(config, right_orth, A; fkwargs=(; kind=:polar), atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind=:lqpos)[2] - test_rrule(config, right_null, A; fkwargs=(; kind=:lqpos), output_tangent=ΔNᴴ, + ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind=:lq)[2] + test_rrule(config, right_null, A; fkwargs=(; kind=:lq), output_tangent=ΔNᴴ, atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) end end diff --git a/test/orthnull.jl b/test/orthnull.jl index 0a93c593..0823d3c5 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -3,6 +3,7 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, I +using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow @testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) @@ -22,6 +23,32 @@ using LinearAlgebra: LinearAlgebra, I @test N' * N ≈ I @test V * V' + N * N' ≈ I + if m > n + nullity = 5 + V, C = @constinferred left_orth(A) + N = @constinferred left_null(A; trunc=(; maxnullity=nullity)) + @test V isa Matrix{T} && size(V) == (m, minmn) + @test C isa Matrix{T} && size(C) == (minmn, n) + @test N isa Matrix{T} && size(N) == (m, nullity) + @test V * C ≈ A + @test V' * V ≈ I + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test N' * N ≈ I + end + + for alg_qr in ((; positive=true), (; positive=false), LAPACK_HouseholderQR()) + V, C = @constinferred left_orth(A; alg_qr) + N = @constinferred left_null(A; alg_qr) + @test V isa Matrix{T} && size(V) == (m, minmn) + @test C isa Matrix{T} && size(C) == (minmn, n) + @test N isa Matrix{T} && size(N) == (m, m - minmn) + @test V * C ≈ A + @test V' * V ≈ I + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test N' * N ≈ I + @test V * V' + N * N' ≈ I + end + Ac = similar(A) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C)) N2 = @constinferred left_null!(copy!(Ac, A), N) @@ -35,8 +62,8 @@ using LinearAlgebra: LinearAlgebra, I @test V2 * V2' + N2 * N2' ≈ I atol = eps(real(T)) - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); atol=atol) - N2 = @constinferred left_null!(copy!(Ac, A), N; atol=atol) + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=(; atol=atol)) + N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=(; atol=atol)) @test V2 !== V @test C2 !== C @test N2 !== C @@ -47,18 +74,21 @@ using LinearAlgebra: LinearAlgebra, I @test V2 * V2' + N2 * N2' ≈ I rtol = eps(real(T)) - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); rtol=rtol) - N2 = @constinferred left_null!(copy!(Ac, A), N; rtol=rtol) - @test V2 !== V - @test C2 !== C - @test N2 !== C - @test V2 * C2 ≈ A - @test V2' * V2 ≈ I - @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test N2' * N2 ≈ I - @test V2 * V2' + N2 * N2' ≈ I + for (trunc_orth, trunc_null) in (((; rtol=rtol), (; rtol=rtol)), + (TruncationKeepAbove(0, rtol), TruncationKeepBelow(0, rtol))) + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc=trunc_orth) + N2 = @constinferred left_null!(copy!(Ac, A), N; trunc=trunc_null) + @test V2 !== V + @test C2 !== C + @test N2 !== C + @test V2 * C2 ≈ A + @test V2' * V2 ≈ I + @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test N2' * N2 ≈ I + @test V2 * V2' + N2 * N2' ≈ I + end - for kind in (:qr, :qrpos, :polar, :svd) # explicit kind kwarg + for kind in (:qr, :polar, :svd) # explicit kind kwarg m < n && kind == :polar && continue V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind) @test V2 === V @@ -76,8 +106,9 @@ using LinearAlgebra: LinearAlgebra, I # with kind and tol kwargs if kind == :svd V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind, - atol=atol) - N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind, atol=atol) + trunc=(; atol=atol)) + N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind, + trunc=(; atol=atol)) @test V2 !== V @test C2 !== C @test N2 !== C @@ -88,8 +119,9 @@ using LinearAlgebra: LinearAlgebra, I @test V2 * V2' + N2 * N2' ≈ I V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind=kind, - rtol=rtol) - N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind, rtol=rtol) + trunc=(; rtol=rtol)) + N2 = @constinferred left_null!(copy!(Ac, A), N; kind=kind, + trunc=(; rtol=rtol)) @test V2 !== V @test C2 !== C @test N2 !== C @@ -100,11 +132,13 @@ using LinearAlgebra: LinearAlgebra, I @test V2 * V2' + N2 * N2' ≈ I else @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind=kind, - atol=atol) + trunc=(; atol=atol)) @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind=kind, - rtol=rtol) - @test_throws ArgumentError left_null!(copy!(Ac, A), N; kind=kind, atol=atol) - @test_throws ArgumentError left_null!(copy!(Ac, A), N; kind=kind, rtol=rtol) + trunc=(; rtol=rtol)) + @test_throws ArgumentError left_null!(copy!(Ac, A), N; kind=kind, + trunc=(; atol=atol)) + @test_throws ArgumentError left_null!(copy!(Ac, A), N; kind=kind, + trunc=(; rtol=rtol)) end end end @@ -141,8 +175,8 @@ end @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I atol = eps(real(T)) - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); atol=atol) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; atol=atol) + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; atol=atol)) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc=(; atol=atol)) @test C2 !== C @test Vᴴ2 !== Vᴴ @test Nᴴ2 !== Nᴴ @@ -153,8 +187,8 @@ end @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I rtol = eps(real(T)) - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); rtol=rtol) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; rtol=rtol) + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc=(; rtol=rtol)) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc=(; rtol=rtol)) @test C2 !== C @test Vᴴ2 !== Vᴴ @test Nᴴ2 !== Nᴴ @@ -164,7 +198,7 @@ end @test Nᴴ2 * Nᴴ2' ≈ I @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I - for kind in (:lq, :lqpos, :polar, :svd) + for kind in (:lq, :polar, :svd) n < m && kind == :polar && continue C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind) @test C2 === C @@ -181,8 +215,9 @@ end if kind == :svd C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind, - atol=atol) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind, atol=atol) + trunc=(; atol=atol)) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind, + trunc=(; atol=atol)) @test C2 !== C @test Vᴴ2 !== Vᴴ @test Nᴴ2 !== Nᴴ @@ -193,8 +228,9 @@ end @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind, - rtol=rtol) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind, rtol=rtol) + trunc=(; rtol=rtol)) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind=kind, + trunc=(; rtol=rtol)) @test C2 !== C @test Vᴴ2 !== Vᴴ @test Nᴴ2 !== Nᴴ @@ -205,13 +241,13 @@ end @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I else @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind, - atol=atol) + trunc=(; atol=atol)) @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind=kind, - rtol=rtol) + trunc=(; rtol=rtol)) @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind=kind, - atol=atol) + trunc=(; atol=atol)) @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind=kind, - rtol=rtol) + trunc=(; rtol=rtol)) end end end