Skip to content

Commit 253dc63

Browse files
authored
Make orthnull more customizable and general (#25)
1 parent c46119e commit 253dc63

File tree

3 files changed

+188
-79
lines changed

3 files changed

+188
-79
lines changed

src/implementations/orthnull.jl

Lines changed: 126 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# Inputs
22
# ------
3-
copy_input(::typeof(left_orth), A::AbstractMatrix) = copy_input(qr_compact, A) # do we ever need anything else
4-
copy_input(::typeof(right_orth), A::AbstractMatrix) = copy_input(lq_compact, A) # do we ever need anything else
5-
copy_input(::typeof(left_null), A::AbstractMatrix) = copy_input(qr_null, A) # do we ever need anything else
6-
copy_input(::typeof(right_null), A::AbstractMatrix) = copy_input(lq_null, A) # do we ever need anything else
3+
copy_input(::typeof(left_orth), A) = copy_input(qr_compact, A) # do we ever need anything else
4+
copy_input(::typeof(right_orth), A) = copy_input(lq_compact, A) # do we ever need anything else
5+
copy_input(::typeof(left_null), A) = copy_input(qr_null, A) # do we ever need anything else
6+
copy_input(::typeof(right_null), A) = copy_input(lq_null, A) # do we ever need anything else
77

88
function check_input(::typeof(left_orth!), A::AbstractMatrix, VC)
99
m, n = size(A)
@@ -81,71 +81,113 @@ end
8181

8282
# Implementation of orth functions
8383
# --------------------------------
84-
function left_orth!(A::AbstractMatrix, VC; trunc=nothing,
84+
function left_orth!(A, VC; trunc=nothing,
8585
kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true),
8686
alg_polar=(;), alg_svd=(;))
8787
check_input(left_orth!, A, VC)
8888
if !isnothing(trunc) && kind != :svd
8989
throw(ArgumentError("truncation not supported for left_orth with kind=$kind"))
9090
end
9191
if kind == :qr
92-
alg_qr′ = select_algorithm(qr_compact!, A, alg_qr)
93-
return qr_compact!(A, VC, alg_qr′)
92+
return left_orth_qr!(A, VC, alg_qr)
9493
elseif kind == :polar
95-
size(A, 1) >= size(A, 2) ||
96-
throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`"))
97-
alg_polar′ = select_algorithm(left_polar!, A, alg_polar)
98-
return left_polar!(A, VC, alg_polar′)
99-
elseif kind == :svd && isnothing(trunc)
100-
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
101-
V, C = VC
102-
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
103-
U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd′)
104-
return U, lmul!(S, Vᴴ)
94+
return left_orth_polar!(A, VC, alg_polar)
10595
elseif kind == :svd
106-
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
107-
alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc)
108-
V, C = VC
109-
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
110-
U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_svd_trunc)
111-
return U, lmul!(S, Vᴴ)
96+
return left_orth_svd!(A, VC, alg_svd, trunc)
11297
else
11398
throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`"))
11499
end
115100
end
101+
function left_orth_qr!(A, VC, alg)
102+
alg′ = select_algorithm(qr_compact!, A, alg)
103+
return qr_compact!(A, VC, alg′)
104+
end
105+
function left_orth_polar!(A, VC, alg)
106+
alg′ = select_algorithm(left_polar!, A, alg)
107+
return left_polar!(A, VC, alg′)
108+
end
109+
function left_orth_svd!(A, VC, alg, trunc::Nothing=nothing)
110+
alg′ = select_algorithm(svd_compact!, A, alg)
111+
U, S, Vᴴ = svd_compact!(A, alg′)
112+
V, C = VC
113+
return copy!(V, U), mul!(C, S, Vᴴ)
114+
end
115+
function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc::Nothing=nothing)
116+
alg′ = select_algorithm(svd_compact!, A, alg)
117+
V, C = VC
118+
S = Diagonal(initialize_output(svd_vals!, A, alg′))
119+
U, S, Vᴴ = svd_compact!(A, (V, S, C), alg′)
120+
return U, lmul!(S, Vᴴ)
121+
end
122+
function left_orth_svd!(A, VC, alg, trunc)
123+
alg′ = select_algorithm(svd_compact!, A, alg)
124+
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
125+
U, S, Vᴴ = svd_trunc!(A, alg_trunc)
126+
V, C = VC
127+
return copy!(V, U), mul!(C, S, Vᴴ)
128+
end
129+
function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc)
130+
alg′ = select_algorithm(svd_compact!, A, alg)
131+
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
132+
V, C = VC
133+
S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg))
134+
U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_trunc)
135+
return U, lmul!(S, Vᴴ)
136+
end
116137

117-
function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing,
138+
function right_orth!(A, CVᴴ; trunc=nothing,
118139
kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true),
119140
alg_polar=(;), alg_svd=(;))
120141
check_input(right_orth!, A, CVᴴ)
121142
if !isnothing(trunc) && kind != :svd
122143
throw(ArgumentError("truncation not supported for right_orth with kind=$kind"))
123144
end
124145
if kind == :lq
125-
alg_lq′ = select_algorithm(lq_compact!, A, alg_lq)
126-
return lq_compact!(A, CVᴴ, alg_lq′)
146+
return right_orth_lq!(A, CVᴴ, alg_lq)
127147
elseif kind == :polar
128-
size(A, 2) >= size(A, 1) ||
129-
throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`"))
130-
alg_polar′ = select_algorithm(right_polar!, A, alg_polar)
131-
return right_polar!(A, CVᴴ, alg_polar′)
132-
elseif kind == :svd && isnothing(trunc)
133-
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
134-
C, Vᴴ = CVᴴ
135-
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
136-
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′)
137-
return rmul!(U, S), Vᴴ
148+
return right_orth_polar!(A, CVᴴ, alg_polar)
138149
elseif kind == :svd
139-
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
140-
alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc)
141-
C, Vᴴ = CVᴴ
142-
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
143-
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_svd_trunc)
144-
return rmul!(U, S), Vᴴ
150+
return right_orth_svd!(A, CVᴴ, alg_svd, trunc)
145151
else
146152
throw(ArgumentError("`right_orth!` received unknown value `kind = $kind`"))
147153
end
148154
end
155+
function right_orth_lq!(A, CVᴴ, alg)
156+
alg′ = select_algorithm(lq_compact!, A, alg)
157+
return lq_compact!(A, CVᴴ, alg′)
158+
end
159+
function right_orth_polar!(A, CVᴴ, alg)
160+
alg′ = select_algorithm(right_polar!, A, alg)
161+
return right_polar!(A, CVᴴ, alg′)
162+
end
163+
function right_orth_svd!(A, CVᴴ, alg, trunc::Nothing=nothing)
164+
alg′ = select_algorithm(svd_compact!, A, alg)
165+
U, S, Vᴴ′ = svd_compact!(A, alg′)
166+
C, Vᴴ = CVᴴ
167+
return mul!(C, U, S), copy!(Vᴴ, Vᴴ′)
168+
end
169+
function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc::Nothing=nothing)
170+
alg′ = select_algorithm(svd_compact!, A, alg)
171+
C, Vᴴ = CVᴴ
172+
S = Diagonal(initialize_output(svd_vals!, A, alg′))
173+
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg′)
174+
return rmul!(U, S), Vᴴ
175+
end
176+
function right_orth_svd!(A, CVᴴ, alg, trunc)
177+
alg′ = select_algorithm(svd_compact!, A, alg)
178+
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
179+
U, S, Vᴴ′ = svd_trunc!(A, alg_trunc)
180+
C, Vᴴ = CVᴴ
181+
return mul!(C, U, S), copy!(Vᴴ, Vᴴ′)
182+
end
183+
function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc)
184+
alg′ = select_algorithm(svd_compact!, A, alg)
185+
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
186+
C, Vᴴ = CVᴴ
187+
S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg))
188+
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_trunc)
189+
return rmul!(U, S), Vᴴ
190+
end
149191

150192
# Implementation of null functions
151193
# --------------------------------
@@ -159,56 +201,70 @@ function null_truncation_strategy(; atol=nothing, rtol=nothing, maxnullity=nothi
159201
return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev=false) : trunc
160202
end
161203

162-
function left_null!(A::AbstractMatrix, N; trunc=nothing,
204+
function left_null!(A, N; trunc=nothing,
163205
kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true),
164206
alg_svd=(;))
165207
check_input(left_null!, A, N)
166208
if !isnothing(trunc) && kind != :svd
167209
throw(ArgumentError("truncation not supported for left_null with kind=$kind"))
168210
end
169211
if kind == :qr
170-
alg_qr′ = select_algorithm(qr_null!, A, alg_qr)
171-
return qr_null!(A, N, alg_qr′)
172-
elseif kind == :svd && isnothing(trunc)
173-
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
174-
U, _, _ = svd_full!(A, alg_svd′)
175-
(m, n) = size(A)
176-
return copy!(N, view(U, 1:m, (n + 1):m))
212+
left_null_qr!(A, N, alg_qr)
177213
elseif kind == :svd
178-
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
179-
U, S, _ = svd_full!(A, alg_svd′)
180-
trunc′ = trunc isa TruncationStrategy ? trunc :
181-
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
182-
throw(ArgumentError("Unknown truncation strategy: $trunc"))
183-
return truncate!(left_null!, (U, S), trunc′)
214+
left_null_svd!(A, N, alg_svd, trunc)
184215
else
185216
throw(ArgumentError("`left_null!` received unknown value `kind = $kind`"))
186217
end
187218
end
219+
function left_null_qr!(A, N, alg)
220+
alg′ = select_algorithm(qr_null!, A, alg)
221+
return qr_null!(A, N, alg′)
222+
end
223+
function left_null_svd!(A, N, alg, trunc::Nothing=nothing)
224+
alg′ = select_algorithm(svd_full!, A, alg)
225+
U, _, _ = svd_full!(A, alg′)
226+
(m, n) = size(A)
227+
return copy!(N, view(U, 1:m, (n + 1):m))
228+
end
229+
function left_null_svd!(A, N, alg, trunc)
230+
alg′ = select_algorithm(svd_full!, A, alg)
231+
U, S, _ = svd_full!(A, alg′)
232+
trunc′ = trunc isa TruncationStrategy ? trunc :
233+
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
234+
throw(ArgumentError("Unknown truncation strategy: $trunc"))
235+
return truncate!(left_null!, (U, S), trunc′)
236+
end
188237

189-
function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing,
238+
function right_null!(A, Nᴴ; trunc=nothing,
190239
kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true),
191240
alg_svd=(;))
192241
check_input(right_null!, A, Nᴴ)
193242
if !isnothing(trunc) && kind != :svd
194243
throw(ArgumentError("truncation not supported for right_null with kind=$kind"))
195244
end
196245
if kind == :lq
197-
alg_lq′ = select_algorithm(lq_null!, A, alg_lq)
198-
return lq_null!(A, Nᴴ, alg_lq′)
199-
elseif kind == :svd && isnothing(trunc)
200-
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
201-
_, _, Vᴴ = svd_full!(A, alg_svd′)
202-
(m, n) = size(A)
203-
return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n))
246+
return right_null_lq!(A, Nᴴ, alg_lq)
204247
elseif kind == :svd
205-
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
206-
_, S, Vᴴ = svd_full!(A, alg_svd′)
207-
trunc′ = trunc isa TruncationStrategy ? trunc :
208-
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
209-
throw(ArgumentError("Unknown truncation strategy: $trunc"))
210-
return truncate!(right_null!, (S, Vᴴ), trunc′)
248+
return right_null_svd!(A, Nᴴ, alg_svd, trunc)
211249
else
212250
throw(ArgumentError("`right_null!` received unknown value `kind = $kind`"))
213251
end
214252
end
253+
function right_null_lq!(A, Nᴴ, alg)
254+
alg′ = select_algorithm(lq_null!, A, alg)
255+
return lq_null!(A, Nᴴ, alg′)
256+
end
257+
function right_null_svd!(A, Nᴴ, alg, trunc::Nothing=nothing)
258+
alg′ = select_algorithm(svd_full!, A, alg)
259+
_, _, Vᴴ = svd_full!(A, alg′)
260+
(m, n) = size(A)
261+
return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n))
262+
end
263+
function right_null_svd!(A, Nᴴ, alg, trunc)
264+
alg′ = select_algorithm(svd_full!, A, alg)
265+
_, S, Vᴴ = svd_full!(A, alg′)
266+
trunc′ = trunc isa TruncationStrategy ? trunc :
267+
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
268+
throw(ArgumentError("Unknown truncation strategy: $trunc"))
269+
return truncate!(right_null!, (S, Vᴴ), trunc′)
270+
end

src/interface/orthnull.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,10 @@ See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), [
6969
"""
7070
function left_orth end
7171
function left_orth! end
72-
function left_orth!(A::AbstractMatrix; kwargs...)
72+
function left_orth!(A; kwargs...)
7373
return left_orth!(A, initialize_output(left_orth!, A); kwargs...)
7474
end
75-
function left_orth(A::AbstractMatrix; kwargs...)
75+
function left_orth(A; kwargs...)
7676
return left_orth!(copy_input(left_orth, A); kwargs...)
7777
end
7878

@@ -128,10 +128,10 @@ See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), [`r
128128
"""
129129
function right_orth end
130130
function right_orth! end
131-
function right_orth!(A::AbstractMatrix; kwargs...)
131+
function right_orth!(A; kwargs...)
132132
return right_orth!(A, initialize_output(right_orth!, A); kwargs...)
133133
end
134-
function right_orth(A::AbstractMatrix; kwargs...)
134+
function right_orth(A; kwargs...)
135135
return right_orth!(copy_input(right_orth, A); kwargs...)
136136
end
137137

@@ -180,10 +180,10 @@ See also [`right_null(!)`](@ref right_null), [`left_orth(!)`](@ref left_orth), [
180180
"""
181181
function left_null end
182182
function left_null! end
183-
function left_null!(A::AbstractMatrix; kwargs...)
183+
function left_null!(A; kwargs...)
184184
return left_null!(A, initialize_output(left_null!, A); kwargs...)
185185
end
186-
function left_null(A::AbstractMatrix; kwargs...)
186+
function left_null(A; kwargs...)
187187
return left_null!(copy_input(left_null, A); kwargs...)
188188
end
189189

@@ -230,9 +230,9 @@ See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth), [`r
230230
"""
231231
function right_null end
232232
function right_null! end
233-
function right_null!(A::AbstractMatrix; kwargs...)
233+
function right_null!(A; kwargs...)
234234
return right_null!(A, initialize_output(right_null!, A); kwargs...)
235235
end
236-
function right_null(A::AbstractMatrix; kwargs...)
236+
function right_null(A; kwargs...)
237237
return right_null!(copy_input(right_null, A); kwargs...)
238238
end

test/orthnull.jl

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,53 @@ using MatrixAlgebraKit
22
using Test
33
using TestExtras
44
using StableRNGs
5-
using LinearAlgebra: LinearAlgebra, I
5+
using LinearAlgebra: LinearAlgebra, I, mul!
66
using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow
7+
using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_svd_algorithm,
8+
initialize_output
9+
10+
# Used to test non-AbstractMatrix codepaths.
11+
struct LinearMap{P<:AbstractMatrix}
12+
parent::P
13+
end
14+
Base.parent(A::LinearMap) = getfield(A, :parent)
15+
function Base.copy!(dest::LinearMap, src::LinearMap)
16+
copy!(parent(dest), parent(src))
17+
return dest
18+
end
19+
function LinearAlgebra.mul!(C::LinearMap, A::LinearMap, B::LinearMap)
20+
mul!(parent(C), parent(A), parent(B))
21+
return C
22+
end
23+
24+
function MatrixAlgebraKit.copy_input(::typeof(qr_compact), A::LinearMap)
25+
return LinearMap(copy_input(qr_compact, parent(A)))
26+
end
27+
function MatrixAlgebraKit.copy_input(::typeof(lq_compact), A::LinearMap)
28+
return LinearMap(copy_input(lq_compact, parent(A)))
29+
end
30+
function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), A::LinearMap)
31+
return LinearMap.(initialize_output(left_orth!, parent(A)))
32+
end
33+
function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), A::LinearMap)
34+
return LinearMap.(initialize_output(right_orth!, parent(A)))
35+
end
36+
function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC)
37+
return check_input(left_orth!, parent(A), parent.(VC))
38+
end
39+
function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC)
40+
return check_input(right_orth!, parent(A), parent.(VC))
41+
end
42+
function MatrixAlgebraKit.default_svd_algorithm(A::LinearMap)
43+
return default_svd_algorithm(parent(A))
44+
end
45+
function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), A::LinearMap,
46+
alg::LAPACK_SVDAlgorithm)
47+
return LinearMap.(initialize_output(svd_compact!, parent(A), alg))
48+
end
49+
function MatrixAlgebraKit.svd_compact!(A::LinearMap, USVᴴ, alg::LAPACK_SVDAlgorithm)
50+
return LinearMap.(svd_compact!(parent(A), parent.(USVᴴ), alg))
51+
end
752

853
@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32,
954
ComplexF64)
@@ -23,6 +68,10 @@ using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow
2368
@test N' * N ≈ I
2469
@test V * V' + N * N' ≈ I
2570
71+
M = LinearMap(A)
72+
VM, CM = @constinferred left_orth(M; kind=:svd)
73+
@test parent(VM) * parent(CM) ≈ A
74+
2675
if m > n
2776
nullity = 5
2877
V, C = @constinferred left_orth(A)
@@ -162,6 +211,10 @@ end
162211
@test Nᴴ * Nᴴ' ≈ I
163212
@test Vᴴ' * Vᴴ + Nᴴ' * Nᴴ ≈ I
164213
214+
M = LinearMap(A)
215+
CM, VMᴴ = @constinferred right_orth(M; kind=:svd)
216+
@test parent(CM) * parent(VMᴴ) ≈ A
217+
165218
Ac = similar(A)
166219
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ))
167220
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ)

0 commit comments

Comments
 (0)