Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ implementing the function `f` on inputs of type `A`.
"""
function select_algorithm end

function _select_algorithm(f, A::AbstractMatrix, alg::AbstractAlgorithm)
return alg
end
function _select_algorithm(f, A::AbstractMatrix, alg::NamedTuple)
return select_algorithm(f, A; alg...)
end

@doc """
copy_input(f, A)

Expand Down
156 changes: 76 additions & 80 deletions src/implementations/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,76 +81,66 @@

# Implementation of orth functions
# --------------------------------
function left_orth!(A::AbstractMatrix, VC; kwargs...)
function left_orth!(A::AbstractMatrix, VC; trunc=nothing,
kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true),
alg_polar=(;), alg_svd=(;))
check_input(left_orth!, A, VC)
atol = get(kwargs, :atol, 0)
rtol = get(kwargs, :rtol, 0)
kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :qrpos : :svd)
if !(iszero(atol) && iszero(rtol)) && kind != :svd
throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind"))
if !isnothing(trunc) && kind != :svd
throw(ArgumentError("truncation not supported for left_orth with kind=$kind"))
end
if kind == :qr
alg = get(kwargs, :alg, select_algorithm(qr_compact!, A))
return qr_compact!(A, VC, alg)
elseif kind == :qrpos
alg = get(kwargs, :alg, select_algorithm(qr_compact!, A; positive=true))
return qr_compact!(A, VC, alg)
alg_qr′ = _select_algorithm(qr_compact!, A, alg_qr)
return qr_compact!(A, VC, alg_qr′)
elseif kind == :polar
size(A, 1) >= size(A, 2) ||
throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`"))
alg = get(kwargs, :alg, select_algorithm(left_polar!, A))
return left_polar!(A, VC, alg)
elseif kind == :svd && iszero(atol) && iszero(rtol)
alg = get(kwargs, :alg, select_algorithm(svd_compact!, A))
alg_polar′ = _select_algorithm(left_polar!, A, alg_polar)
return left_polar!(A, VC, alg_polar′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
V, C = VC
S = Diagonal(initialize_output(svd_vals!, A, alg))
U, S, Vᴴ = svd_compact!(A, (V, S, C), alg)
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd′)
return U, lmul!(S, Vᴴ)
elseif kind == :svd
alg_svd = select_algorithm(svd_compact!, A)
trunc = TruncationKeepAbove(atol, rtol)
alg = get(kwargs, :alg, TruncatedAlgorithm(alg_svd, trunc))
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′)
V, C = VC
S = Diagonal(initialize_output(svd_vals!, A, alg_svd))
U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg)
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_svd_trunc)
return U, lmul!(S, Vᴴ)
else
throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`"))
end
end

function right_orth!(A::AbstractMatrix, CVᴴ; kwargs...)
function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing,
kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true),
alg_polar=(;), alg_svd=(;))
check_input(right_orth!, A, CVᴴ)
atol = get(kwargs, :atol, 0)
rtol = get(kwargs, :rtol, 0)
kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :lqpos : :svd)
if !(iszero(atol) && iszero(rtol)) && kind != :svd
throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind"))
if !isnothing(trunc) && kind != :svd
throw(ArgumentError("truncation not supported for right_orth with kind=$kind"))
end
if kind == :lq
alg = get(kwargs, :alg, select_algorithm(lq_compact!, A))
return lq_compact!(A, CVᴴ, alg)
elseif kind == :lqpos
alg = get(kwargs, :alg, select_algorithm(lq_compact!, A; positive=true))
return lq_compact!(A, CVᴴ, alg)
alg_lq′ = _select_algorithm(lq_compact!, A, alg_lq)
return lq_compact!(A, CVᴴ, alg_lq′)
elseif kind == :polar
size(A, 2) >= size(A, 1) ||
throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`"))
alg = get(kwargs, :alg, select_algorithm(right_polar!, A))
return right_polar!(A, CVᴴ, alg)
elseif kind == :svd && iszero(atol) && iszero(rtol)
alg = get(kwargs, :alg, select_algorithm(svd_compact!, A))
alg_polar′ = _select_algorithm(right_polar!, A, alg_polar)
return right_polar!(A, CVᴴ, alg_polar′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
C, Vᴴ = CVᴴ
S = Diagonal(initialize_output(svd_vals!, A, alg))
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg)
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′)
return rmul!(U, S), Vᴴ
elseif kind == :svd
alg_svd = select_algorithm(svd_compact!, A)
trunc = TruncationKeepAbove(atol, rtol)
alg = get(kwargs, :alg, TruncatedAlgorithm(alg_svd, trunc))
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′)
C, Vᴴ = CVᴴ
S = Diagonal(initialize_output(svd_vals!, A, alg_svd))
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg)
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_svd_trunc)
return rmul!(U, S), Vᴴ
else
throw(ArgumentError("`right_orth!` received unknown value `kind = $kind`"))
Expand All @@ -159,59 +149,65 @@

# Implementation of null functions
# --------------------------------
function left_null!(A::AbstractMatrix, N; kwargs...)
function null_truncation_strategy(; atol=nothing, rtol=nothing, maxnullity=nothing)
if isnothing(maxnullity) && isnothing(atol) && isnothing(rtol)
return NoTruncation()

Check warning on line 154 in src/implementations/orthnull.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/orthnull.jl#L154

Added line #L154 was not covered by tests
end
atol = @something atol 0
rtol = @something rtol 0
trunc = TruncationKeepBelow(atol, rtol)
return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev=false) : trunc
end

function left_null!(A::AbstractMatrix, N; trunc=nothing,
kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true),
alg_svd=(;))
check_input(left_null!, A, N)
atol = get(kwargs, :atol, 0)
rtol = get(kwargs, :rtol, 0)
kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :qrpos : :svd)
if !(iszero(atol) && iszero(rtol)) && kind != :svd
throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind"))
if !isnothing(trunc) && kind != :svd
throw(ArgumentError("truncation not supported for left_null with kind=$kind"))
end
if kind == :qr
alg = get(kwargs, :alg, select_algorithm(qr_null!, A))
return qr_null!(A, N, alg)
elseif kind == :qrpos
alg = get(kwargs, :alg, select_algorithm(qr_null!, A; positive=true))
return qr_null!(A, N, alg)
elseif kind == :svd && iszero(atol) && iszero(rtol)
alg = get(kwargs, :alg, select_algorithm(svd_full!, A))
U, _, _ = svd_full!(A, alg)
alg_qr′ = _select_algorithm(qr_null!, A, alg_qr)
return qr_null!(A, N, alg_qr′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
U, _, _ = svd_full!(A, alg_svd′)
(m, n) = size(A)
return copy!(N, view(U, 1:m, (n + 1):m))
elseif kind == :svd
alg = get(kwargs, :alg, select_algorithm(svd_full!, A))
U, S, _ = svd_full!(A, alg)
trunc = TruncationKeepBelow(atol, rtol)
return truncate!(left_null!, (U, S), trunc)
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
U, S, _ = svd_full!(A, alg_svd′)
trunc′ = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return truncate!(left_null!, (U, S), trunc′)
else
throw(ArgumentError("`left_null!` received unknown value `kind = $kind`"))
end
end

function right_null!(A::AbstractMatrix, Nᴴ; kwargs...)
function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing,
kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true),
alg_svd=(;))
check_input(right_null!, A, Nᴴ)
atol = get(kwargs, :atol, 0)
rtol = get(kwargs, :rtol, 0)
kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :lqpos : :svd)
if !(iszero(atol) && iszero(rtol)) && kind != :svd
throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind"))
if !isnothing(trunc) && kind != :svd
throw(ArgumentError("truncation not supported for right_null with kind=$kind"))
end
if kind == :lq
alg = get(kwargs, :alg, select_algorithm(lq_null!, A))
return lq_null!(A, Nᴴ, alg)
elseif kind == :lqpos
alg = get(kwargs, :alg, select_algorithm(lq_null!, A; positive=true))
return lq_null!(A, Nᴴ, alg)
elseif kind == :svd && iszero(atol) && iszero(rtol)
alg = get(kwargs, :alg, select_algorithm(svd_full!, A))
_, _, Vᴴ = svd_full!(A, alg)
alg_lq′ = _select_algorithm(lq_null!, A, alg_lq)
return lq_null!(A, Nᴴ, alg_lq′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
_, _, Vᴴ = svd_full!(A, alg_svd′)
(m, n) = size(A)
return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n))
elseif kind == :svd
alg = get(kwargs, :alg, select_algorithm(svd_full!, A))
_, S, Vᴴ = svd_full!(A, alg)
trunc = TruncationKeepBelow(atol, rtol)
return truncate!(right_null!, (S, Vᴴ), trunc)
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
_, S, Vᴴ = svd_full!(A, alg_svd′)
trunc′ = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return truncate!(right_null!, (S, Vᴴ), trunc′)
else
throw(ArgumentError("`right_null!` received unknown value `kind = $kind`"))
end
Expand Down
Loading
Loading