Skip to content

Commit c52614b

Browse files
authored
Refactor orthogonalization and nullspace interface (#79)
* `left_orth` and `right_orth` with `@functiondef` * refactor truncationintersection for type stability * factor out `linearmap.jl` * refactor orth algorithm selection * add algorithm traits * refactor left_orth and right_orth implementations * refactor null algorithm selection * refactor null implementation * more algorithm traits * `left_null` and `right_null` with `@functiondef` * reorganize algorithm unions * refactor null truncation * update docstrings * disambiguate alg selection * update tests * more docs updates * mark randomized SVD as unusable for nullspaces * fix JET complaint * docstring reorganization * improve truncation * headers * fix merge * work out alternate proposition * unpack algorithms * rework traits * also include null implementation * maybeblasmat * fix docs build * some cleanup * fix type stability again * update gpu tests * address some review comments * some AMD fixes * some more AMD fixes * more more AMD fixes * move lqviatransposedqr * no randomized svd for null * no initialization for orthnull with SVD * fix docstring * more more more AMD fixes * Revert "no randomized svd for null" This reverts commit 185d939. * update syntax for leftorth * update algselector docstrings * migrate logic from constructors to functions * also update docs * remove unnecessary overloads * apply suggestions
1 parent 4fbc3bf commit c52614b

File tree

21 files changed

+1349
-914
lines changed

21 files changed

+1349
-914
lines changed

docs/src/user_interface/decompositions.md

Lines changed: 187 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,23 +169,205 @@ PolarNewton
169169

170170
## Orthogonal Subspaces
171171

172-
Often it is useful to compute orthogonal bases for a particular subspace defined by a matrix.
173-
Given a matrix `A` we can compute an orthonormal basis for its image or coimage, and factorize the matrix accordingly.
172+
Often it is useful to compute orthogonal bases for particular subspaces defined by a matrix.
173+
Given a matrix `A`, we can compute an orthonormal basis for its image or coimage, and factorize the matrix accordingly.
174174
These bases are accessible through [`left_orth`](@ref) and [`right_orth`](@ref) respectively.
175-
This is implemented through a combination of the decompositions mentioned above, and serves as a convenient interface to these operations.
175+
176+
### Overview
177+
178+
The [`left_orth`](@ref) function computes an orthonormal basis `V` for the image (column space) of `A`, along with a corestriction matrix `C` such that `A = V * C`.
179+
The resulting `V` has orthonormal columns (`V' * V ≈ I` or `isisometric(V)`).
180+
181+
Similarly, [`right_orth`](@ref) computes an orthonormal basis for the coimage (row space) of `A`, i.e., the image of `A'`.
182+
It returns matrices `C` and `Vᴴ` such that `A = C * Vᴴ`, where `V = (Vᴴ)'` has orthonormal columns (`isisometric(Vᴴ; side = :right)`).
183+
184+
These functions serve as high-level interfaces that automatically select the most appropriate decomposition based on the specified options, making them convenient for users who want orthonormalization without worrying about the underlying implementation details.
176185

177186
```@docs; canonical=false
178187
left_orth
179188
right_orth
180189
```
181190

191+
### Algorithm Selection
192+
193+
Both functions support multiple decomposition drivers, which can be selected through the `alg` keyword argument:
194+
195+
**For `left_orth`:**
196+
- `alg = :qr` (default without truncation): Uses QR decomposition via [`qr_compact`](@ref)
197+
- `alg = :polar`: Uses polar decomposition via [`left_polar`](@ref)
198+
- `alg = :svd` (default with truncation): Uses SVD via [`svd_compact`](@ref) or [`svd_trunc`](@ref)
199+
200+
**For `right_orth`:**
201+
- `alg = :lq` (default without truncation): Uses LQ decomposition via [`lq_compact`](@ref)
202+
- `alg = :polar`: Uses polar decomposition via [`right_polar`](@ref)
203+
- `alg = :svd` (default with truncation): Uses SVD via [`svd_compact`](@ref) or [`svd_trunc`](@ref)
204+
205+
When `alg` is not specified, the function automatically selects `:qr`/`:lq` for exact orthogonalization, or `:svd` when a truncation strategy is provided.
206+
207+
### Extending with Custom Algorithms
208+
209+
To register a custom algorithm type for use with these functions, you need to define the appropriate conversion function, for example:
210+
211+
```julia
212+
# For left_orth
213+
MatrixAlgebraKit.left_orth_alg(alg::MyCustomAlgorithm) = LeftOrthAlgorithm{:qr}(alg)
214+
215+
# For right_orth
216+
MatrixAlgebraKit.right_orth_alg(alg::MyCustomAlgorithm) = RightOrthAlgorithm{:lq}(alg)
217+
```
218+
219+
The type parameter (`:qr`, `:lq`, `:polar`, or `:svd`) indicates which factorization backend will be used.
220+
The wrapper algorithm types handle the dispatch to the appropriate implementation:
221+
222+
```@docs; canonical=false
223+
left_orth_alg
224+
right_orth_alg
225+
LeftOrthAlgorithm
226+
RightOrthAlgorithm
227+
```
228+
229+
### Examples
230+
231+
Basic orthogonalization:
232+
233+
```jldoctest orthnull; output=false
234+
using MatrixAlgebraKit
235+
using LinearAlgebra
236+
237+
A = [1.0 2.0; 3.0 4.0; 5.0 6.0]
238+
V, C = left_orth(A)
239+
(V' * V) ≈ I && A ≈ V * C
240+
241+
# output
242+
true
243+
```
244+
245+
Using different algorithms:
246+
247+
```jldoctest orthnull; output=false
248+
A = randn(4, 3)
249+
V1, C1 = left_orth(A; alg = :qr)
250+
V2, C2 = left_orth(A; alg = :polar)
251+
V3, C3 = left_orth(A; alg = :svd)
252+
A ≈ V1 * C1 ≈ V2 * C2 ≈ V3 * C3
253+
254+
# output
255+
true
256+
```
257+
258+
With truncation:
259+
260+
```jldoctest orthnull; output=false
261+
A = [1.0 0.0; 0.0 1e-10; 0.0 0.0]
262+
V, C = left_orth(A; trunc = (atol = 1e-8,))
263+
size(V, 2) == 1 # Only one column retained
264+
265+
# output
266+
true
267+
```
268+
269+
182270
## Null Spaces
183271

184272
Similarly, it can be convenient to obtain an orthogonal basis for the kernel or cokernel of a matrix.
185-
These are the compliments of the coimage and image, respectively, and can be computed using the [`left_null`](@ref) and [`right_null`](@ref) functions.
186-
Again, this is typically implemented through a combination of the decompositions mentioned above, and serves as a convenient interface to these operations.
273+
These are the orthogonal complements of the coimage and image, respectively, and can be computed using the [`left_null`](@ref) and [`right_null`](@ref) functions.
274+
275+
### Overview
276+
277+
The [`left_null`](@ref) function computes an orthonormal basis `N` for the cokernel (left nullspace) of `A`, which is the nullspace of `A'`.
278+
This means `A' * N ≈ 0` and `N' * N ≈ I`.
279+
280+
Similarly, [`right_null`](@ref) computes an orthonormal basis for the kernel (right nullspace) of `A`.
281+
It returns `Nᴴ` such that `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`, where `N = (Nᴴ)'` has orthonormal columns.
282+
283+
These functions automatically handle rank determination and provide convenient access to nullspace computation without requiring detailed knowledge of the underlying decomposition methods.
187284

188285
```@docs; canonical=false
189286
left_null
190287
right_null
191288
```
289+
290+
### Algorithm Selection
291+
292+
Both functions support multiple decomposition drivers, which can be selected through the `alg` keyword argument:
293+
294+
**For `left_null`:**
295+
- `alg = :qr` (default without truncation): Uses QR-based nullspace computation via [`qr_null`](@ref)
296+
- `alg = :svd` (default with truncation): Uses SVD via [`svd_full`](@ref) with appropriate truncation
297+
298+
**For `right_null`:**
299+
- `alg = :lq` (default without truncation): Uses LQ-based nullspace computation via [`lq_null`](@ref)
300+
- `alg = :svd` (default with truncation): Uses SVD via [`svd_full`](@ref) with appropriate truncation
301+
302+
When `alg` is not specified, the function automatically selects `:qr`/`:lq` for exact nullspace computation, or `:svd` when a truncation strategy is provided to handle numerical rank determination.
303+
304+
!!! note
305+
For nullspace functions, [`notrunc`](@ref) has special meaning when used with the default QR/LQ algorithms.
306+
It indicates that the nullspace should be computed from the exact zeros determined by the additional rows/columns of the extended matrix, without any tolerance-based truncation.
307+
308+
### Extending with Custom Algorithms
309+
310+
To register a custom algorithm type for use with these functions, you need to define the appropriate conversion function:
311+
312+
```julia
313+
# For left_null
314+
MatrixAlgebraKit.left_null_alg(alg::MyCustomAlgorithm) = LeftNullAlgorithm{:qr}(alg)
315+
316+
# For right_null
317+
MatrixAlgebraKit.right_null_alg(alg::MyCustomAlgorithm) = RightNullAlgorithm{:lq}(alg)
318+
```
319+
320+
The type parameter (`:qr`, `:lq`, or `:svd`) indicates which factorization backend will be used.
321+
The wrapper algorithm types handle the dispatch to the appropriate implementation:
322+
323+
```@docs; canonical=false
324+
LeftNullAlgorithm
325+
RightNullAlgorithm
326+
left_null_alg
327+
right_null_alg
328+
```
329+
330+
### Examples
331+
332+
Basic nullspace computation:
333+
334+
```jldoctest orthnull; output=false
335+
A = [1.0 2.0 3.0; 4.0 5.0 6.0] # Rank 2 matrix
336+
N = left_null(A)
337+
size(N) == (2, 0)
338+
339+
# output
340+
true
341+
```
342+
343+
```jldoctest orthnull; output=false
344+
Nᴴ = right_null(A)
345+
size(Nᴴ) == (1, 3) && norm(A * Nᴴ') < 1e-14 && isisometric(Nᴴ; side = :right)
346+
347+
# output
348+
true
349+
```
350+
351+
Computing nullspace with rank detection:
352+
353+
```jldoctest orthnull; output=false
354+
A = [1.0 2.0; 2.0 4.0; 3.0 6.0] # Rank 1 matrix (second column = 2*first)
355+
N = left_null(A; alg = :svd, trunc = (atol = 1e-10,))
356+
size(N) == (3, 2) && norm(A' * N) < 1e-12 && isisometric(N)
357+
358+
# output
359+
true
360+
```
361+
362+
Using different algorithms:
363+
364+
```jldoctest orthnull; output=false
365+
A = [1.0 0.0 0.0; 0.0 1.0 0.0]
366+
N1 = right_null(A; alg = :lq)
367+
N2 = right_null(A; alg = :svd)
368+
norm(A * N1') < 1e-14 && norm(A * N2') < 1e-14 &&
369+
isisometric(N1; side = :right) && isisometric(N2; side = :right)
370+
371+
# output
372+
true
373+
```

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using MatrixAlgebraKit
44
using MatrixAlgebraKit: @algdef, Algorithm, check_input
55
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
66
using MatrixAlgebraKit: diagview, sign_safe
7-
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
7+
using MatrixAlgebraKit: LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
88
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
1010
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!
@@ -161,7 +161,9 @@ function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
161161
return A, B
162162
end
163163

164-
function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.left_null!), US::Tuple{TU, TS}, strategy::MatrixAlgebraKit.TruncationStrategy) where {TU <: ROCArray, TS}
164+
function MatrixAlgebraKit.truncate(
165+
::typeof(left_null!), US::Tuple{TU, TS}, strategy::TruncationStrategy
166+
) where {TU <: ROCMatrix, TS}
165167
# TODO: avoid allocation?
166168
U, S = US
167169
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
@@ -170,5 +172,32 @@ function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.left_null!), US::Tu
170172
Utrunc = U[:, trunc_cols]
171173
return Utrunc, ind
172174
end
175+
function MatrixAlgebraKit.truncate(
176+
::typeof(right_null!), SVᴴ::Tuple{TS, TVᴴ}, strategy::TruncationStrategy
177+
) where {TS, TVᴴ <: ROCMatrix}
178+
# TODO: avoid allocation?
179+
S, Vᴴ = SVᴴ
180+
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 2) - size(S, 1))))
181+
ind = MatrixAlgebraKit.findtruncated(extended_S, strategy)
182+
trunc_rows = collect(1:size(Vᴴ, 1))[ind]
183+
Vᴴtrunc = Vᴴ[trunc_rows, :]
184+
return Vᴴtrunc, ind
185+
end
186+
187+
# disambiguate:
188+
function MatrixAlgebraKit.truncate(
189+
::typeof(left_null!), (U, S)::Tuple{TU, TS}, ::NoTruncation
190+
) where {TU <: ROCMatrix, TS}
191+
m, n = size(S)
192+
ind = (n + 1):m
193+
return U[:, ind], ind
194+
end
195+
function MatrixAlgebraKit.truncate(
196+
::typeof(right_null!), (S, Vᴴ)::Tuple{TS, TVᴴ}, ::NoTruncation
197+
) where {TS, TVᴴ <: ROCMatrix}
198+
m, n = size(S)
199+
ind = (m + 1):n
200+
return Vᴴ[ind, :], ind
201+
end
173202

174203
end

src/algorithms.jl

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,18 @@ function _show_alg(io::IO, alg::Algorithm)
5353
return print(io, ")")
5454
end
5555

56+
# Algorithm traits
57+
# ----------------
58+
"""
59+
does_truncate(alg::AbstractAlgorithm) -> Bool
60+
61+
Indicate whether or not an algorithm will compute a truncated decomposition
62+
(such that composing the factors only approximates the input up to some tolerance).
63+
"""
64+
does_truncate(alg::AbstractAlgorithm) = false
65+
66+
# Algorithm selection
67+
# -------------------
5668
@doc """
5769
MatrixAlgebraKit.select_algorithm(f, A, alg::AbstractAlgorithm)
5870
MatrixAlgebraKit.select_algorithm(f, A, alg::Symbol; kwargs...)
@@ -83,7 +95,7 @@ function select_algorithm(f::F, A, alg::Alg = nothing; kwargs...) where {F, Alg}
8395
return Algorithm{alg}(; kwargs...)
8496
elseif alg isa Type
8597
return alg(; kwargs...)
86-
elseif alg isa NamedTuple
98+
elseif alg isa NamedTuple || alg isa Base.Pairs
8799
isempty(kwargs) ||
88100
throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified."))
89101
return default_algorithm(f, A; alg...)
@@ -160,6 +172,24 @@ function select_truncation(trunc)
160172
end
161173
end
162174

175+
@doc """
176+
MatrixAlgebraKit.select_null_truncation(trunc)
177+
178+
Construct a [`TruncationStrategy`](@ref) from the given `NamedTuple` of keywords or input strategy, to implement a nullspace selection.
179+
""" select_null_truncation
180+
181+
function select_null_truncation(trunc)
182+
if isnothing(trunc)
183+
return NoTruncation()
184+
elseif trunc isa NamedTuple
185+
return null_truncation_strategy(; trunc...)
186+
elseif trunc isa TruncationStrategy
187+
return trunc
188+
else
189+
return throw(ArgumentError("Unknown truncation strategy: $trunc"))
190+
end
191+
end
192+
163193
@doc """
164194
MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationStrategy)
165195
@@ -200,6 +230,8 @@ struct TruncatedAlgorithm{A, T} <: AbstractAlgorithm
200230
trunc::T
201231
end
202232

233+
does_truncate(::TruncatedAlgorithm) = true
234+
203235
# Utility macros
204236
# --------------
205237

@@ -230,14 +262,14 @@ end
230262

231263
function _arg_expr(::Val{1}, f, f!)
232264
return quote # out of place to inplace
233-
$f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
265+
@inline $f(A; alg = nothing, kwargs...) = $f(A, select_algorithm($f, A, alg; kwargs...))
234266
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)
235267

236268
# fill in arguments
237-
function $f!(A; alg = nothing, kwargs...)
269+
@inline function $f!(A; alg = nothing, kwargs...)
238270
return $f!(A, select_algorithm($f!, A, alg; kwargs...))
239271
end
240-
function $f!(A, out; alg = nothing, kwargs...)
272+
@inline function $f!(A, out; alg = nothing, kwargs...)
241273
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
242274
end
243275
function $f!(A, alg::AbstractAlgorithm)

0 commit comments

Comments
 (0)