Skip to content

Commit 1b44845

Browse files
authored
More general truncation and algorithm selection in orth/null (#19)
1 parent 2364e25 commit 1b44845

File tree

5 files changed

+240
-190
lines changed

5 files changed

+240
-190
lines changed

src/algorithms.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ implementing the function `f` on inputs of type `A`.
6161
"""
6262
function select_algorithm end
6363

64+
function _select_algorithm(f, A::AbstractMatrix, alg::AbstractAlgorithm)
65+
return alg
66+
end
67+
function _select_algorithm(f, A::AbstractMatrix, alg::NamedTuple)
68+
return select_algorithm(f, A; alg...)
69+
end
70+
6471
@doc """
6572
copy_input(f, A)
6673

src/implementations/orthnull.jl

Lines changed: 76 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -81,76 +81,66 @@ end
8181

8282
# Implementation of orth functions
8383
# --------------------------------
84-
function left_orth!(A::AbstractMatrix, VC; kwargs...)
84+
function left_orth!(A::AbstractMatrix, VC; trunc=nothing,
85+
kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true),
86+
alg_polar=(;), alg_svd=(;))
8587
check_input(left_orth!, A, VC)
86-
atol = get(kwargs, :atol, 0)
87-
rtol = get(kwargs, :rtol, 0)
88-
kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :qrpos : :svd)
89-
if !(iszero(atol) && iszero(rtol)) && kind != :svd
90-
throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind"))
88+
if !isnothing(trunc) && kind != :svd
89+
throw(ArgumentError("truncation not supported for left_orth with kind=$kind"))
9190
end
9291
if kind == :qr
93-
alg = get(kwargs, :alg, select_algorithm(qr_compact!, A))
94-
return qr_compact!(A, VC, alg)
95-
elseif kind == :qrpos
96-
alg = get(kwargs, :alg, select_algorithm(qr_compact!, A; positive=true))
97-
return qr_compact!(A, VC, alg)
92+
alg_qr′ = _select_algorithm(qr_compact!, A, alg_qr)
93+
return qr_compact!(A, VC, alg_qr′)
9894
elseif kind == :polar
9995
size(A, 1) >= size(A, 2) ||
10096
throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`"))
101-
alg = get(kwargs, :alg, select_algorithm(left_polar!, A))
102-
return left_polar!(A, VC, alg)
103-
elseif kind == :svd && iszero(atol) && iszero(rtol)
104-
alg = get(kwargs, :alg, select_algorithm(svd_compact!, A))
97+
alg_polar′ = _select_algorithm(left_polar!, A, alg_polar)
98+
return left_polar!(A, VC, alg_polar′)
99+
elseif kind == :svd && isnothing(trunc)
100+
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
105101
V, C = VC
106-
S = Diagonal(initialize_output(svd_vals!, A, alg))
107-
U, S, Vᴴ = svd_compact!(A, (V, S, C), alg)
102+
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
103+
U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd′)
108104
return U, lmul!(S, Vᴴ)
109105
elseif kind == :svd
110-
alg_svd = select_algorithm(svd_compact!, A)
111-
trunc = TruncationKeepAbove(atol, rtol)
112-
alg = get(kwargs, :alg, TruncatedAlgorithm(alg_svd, trunc))
106+
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
107+
alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′)
113108
V, C = VC
114-
S = Diagonal(initialize_output(svd_vals!, A, alg_svd))
115-
U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg)
109+
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
110+
U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_svd_trunc)
116111
return U, lmul!(S, Vᴴ)
117112
else
118113
throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`"))
119114
end
120115
end
121116

122-
function right_orth!(A::AbstractMatrix, CVᴴ; kwargs...)
117+
function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing,
118+
kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true),
119+
alg_polar=(;), alg_svd=(;))
123120
check_input(right_orth!, A, CVᴴ)
124-
atol = get(kwargs, :atol, 0)
125-
rtol = get(kwargs, :rtol, 0)
126-
kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :lqpos : :svd)
127-
if !(iszero(atol) && iszero(rtol)) && kind != :svd
128-
throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind"))
121+
if !isnothing(trunc) && kind != :svd
122+
throw(ArgumentError("truncation not supported for right_orth with kind=$kind"))
129123
end
130124
if kind == :lq
131-
alg = get(kwargs, :alg, select_algorithm(lq_compact!, A))
132-
return lq_compact!(A, CVᴴ, alg)
133-
elseif kind == :lqpos
134-
alg = get(kwargs, :alg, select_algorithm(lq_compact!, A; positive=true))
135-
return lq_compact!(A, CVᴴ, alg)
125+
alg_lq′ = _select_algorithm(lq_compact!, A, alg_lq)
126+
return lq_compact!(A, CVᴴ, alg_lq′)
136127
elseif kind == :polar
137128
size(A, 2) >= size(A, 1) ||
138129
throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`"))
139-
alg = get(kwargs, :alg, select_algorithm(right_polar!, A))
140-
return right_polar!(A, CVᴴ, alg)
141-
elseif kind == :svd && iszero(atol) && iszero(rtol)
142-
alg = get(kwargs, :alg, select_algorithm(svd_compact!, A))
130+
alg_polar′ = _select_algorithm(right_polar!, A, alg_polar)
131+
return right_polar!(A, CVᴴ, alg_polar′)
132+
elseif kind == :svd && isnothing(trunc)
133+
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
143134
C, Vᴴ = CVᴴ
144-
S = Diagonal(initialize_output(svd_vals!, A, alg))
145-
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg)
135+
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
136+
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′)
146137
return rmul!(U, S), Vᴴ
147138
elseif kind == :svd
148-
alg_svd = select_algorithm(svd_compact!, A)
149-
trunc = TruncationKeepAbove(atol, rtol)
150-
alg = get(kwargs, :alg, TruncatedAlgorithm(alg_svd, trunc))
139+
alg_svd′ = _select_algorithm(svd_compact!, A, alg_svd)
140+
alg_svd_trunc = select_algorithm(svd_trunc!, A; trunc, alg=alg_svd′)
151141
C, Vᴴ = CVᴴ
152-
S = Diagonal(initialize_output(svd_vals!, A, alg_svd))
153-
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg)
142+
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
143+
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_svd_trunc)
154144
return rmul!(U, S), Vᴴ
155145
else
156146
throw(ArgumentError("`right_orth!` received unknown value `kind = $kind`"))
@@ -159,59 +149,65 @@ end
159149

160150
# Implementation of null functions
161151
# --------------------------------
162-
function left_null!(A::AbstractMatrix, N; kwargs...)
152+
function null_truncation_strategy(; atol=nothing, rtol=nothing, maxnullity=nothing)
153+
if isnothing(maxnullity) && isnothing(atol) && isnothing(rtol)
154+
return NoTruncation()
155+
end
156+
atol = @something atol 0
157+
rtol = @something rtol 0
158+
trunc = TruncationKeepBelow(atol, rtol)
159+
return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev=false) : trunc
160+
end
161+
162+
function left_null!(A::AbstractMatrix, N; trunc=nothing,
163+
kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true),
164+
alg_svd=(;))
163165
check_input(left_null!, A, N)
164-
atol = get(kwargs, :atol, 0)
165-
rtol = get(kwargs, :rtol, 0)
166-
kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :qrpos : :svd)
167-
if !(iszero(atol) && iszero(rtol)) && kind != :svd
168-
throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind"))
166+
if !isnothing(trunc) && kind != :svd
167+
throw(ArgumentError("truncation not supported for left_null with kind=$kind"))
169168
end
170169
if kind == :qr
171-
alg = get(kwargs, :alg, select_algorithm(qr_null!, A))
172-
return qr_null!(A, N, alg)
173-
elseif kind == :qrpos
174-
alg = get(kwargs, :alg, select_algorithm(qr_null!, A; positive=true))
175-
return qr_null!(A, N, alg)
176-
elseif kind == :svd && iszero(atol) && iszero(rtol)
177-
alg = get(kwargs, :alg, select_algorithm(svd_full!, A))
178-
U, _, _ = svd_full!(A, alg)
170+
alg_qr′ = _select_algorithm(qr_null!, A, alg_qr)
171+
return qr_null!(A, N, alg_qr′)
172+
elseif kind == :svd && isnothing(trunc)
173+
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
174+
U, _, _ = svd_full!(A, alg_svd′)
179175
(m, n) = size(A)
180176
return copy!(N, view(U, 1:m, (n + 1):m))
181177
elseif kind == :svd
182-
alg = get(kwargs, :alg, select_algorithm(svd_full!, A))
183-
U, S, _ = svd_full!(A, alg)
184-
trunc = TruncationKeepBelow(atol, rtol)
185-
return truncate!(left_null!, (U, S), trunc)
178+
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
179+
U, S, _ = svd_full!(A, alg_svd′)
180+
trunc′ = trunc isa TruncationStrategy ? trunc :
181+
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
182+
throw(ArgumentError("Unknown truncation strategy: $trunc"))
183+
return truncate!(left_null!, (U, S), trunc′)
186184
else
187185
throw(ArgumentError("`left_null!` received unknown value `kind = $kind`"))
188186
end
189187
end
190188

191-
function right_null!(A::AbstractMatrix, Nᴴ; kwargs...)
189+
function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing,
190+
kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true),
191+
alg_svd=(;))
192192
check_input(right_null!, A, Nᴴ)
193-
atol = get(kwargs, :atol, 0)
194-
rtol = get(kwargs, :rtol, 0)
195-
kind = get(kwargs, :kind, iszero(atol) && iszero(rtol) ? :lqpos : :svd)
196-
if !(iszero(atol) && iszero(rtol)) && kind != :svd
197-
throw(ArgumentError("nonzero tolerance not supported for left_orth with kind=$kind"))
193+
if !isnothing(trunc) && kind != :svd
194+
throw(ArgumentError("truncation not supported for right_null with kind=$kind"))
198195
end
199196
if kind == :lq
200-
alg = get(kwargs, :alg, select_algorithm(lq_null!, A))
201-
return lq_null!(A, Nᴴ, alg)
202-
elseif kind == :lqpos
203-
alg = get(kwargs, :alg, select_algorithm(lq_null!, A; positive=true))
204-
return lq_null!(A, Nᴴ, alg)
205-
elseif kind == :svd && iszero(atol) && iszero(rtol)
206-
alg = get(kwargs, :alg, select_algorithm(svd_full!, A))
207-
_, _, Vᴴ = svd_full!(A, alg)
197+
alg_lq′ = _select_algorithm(lq_null!, A, alg_lq)
198+
return lq_null!(A, Nᴴ, alg_lq′)
199+
elseif kind == :svd && isnothing(trunc)
200+
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
201+
_, _, Vᴴ = svd_full!(A, alg_svd′)
208202
(m, n) = size(A)
209203
return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n))
210204
elseif kind == :svd
211-
alg = get(kwargs, :alg, select_algorithm(svd_full!, A))
212-
_, S, Vᴴ = svd_full!(A, alg)
213-
trunc = TruncationKeepBelow(atol, rtol)
214-
return truncate!(right_null!, (S, Vᴴ), trunc)
205+
alg_svd′ = _select_algorithm(svd_full!, A, alg_svd)
206+
_, S, Vᴴ = svd_full!(A, alg_svd′)
207+
trunc′ = trunc isa TruncationStrategy ? trunc :
208+
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
209+
throw(ArgumentError("Unknown truncation strategy: $trunc"))
210+
return truncate!(right_null!, (S, Vᴴ), trunc′)
215211
else
216212
throw(ArgumentError("`right_null!` received unknown value `kind = $kind`"))
217213
end

0 commit comments

Comments
 (0)