Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 126 additions & 70 deletions src/implementations/orthnull.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Inputs
# ------
copy_input(::typeof(left_orth), A::AbstractMatrix) = copy_input(qr_compact, A) # do we ever need anything else
copy_input(::typeof(right_orth), A::AbstractMatrix) = copy_input(lq_compact, A) # do we ever need anything else
copy_input(::typeof(left_null), A::AbstractMatrix) = copy_input(qr_null, A) # do we ever need anything else
copy_input(::typeof(right_null), A::AbstractMatrix) = copy_input(lq_null, A) # do we ever need anything else
copy_input(::typeof(left_orth), A) = copy_input(qr_compact, A) # do we ever need anything else
copy_input(::typeof(right_orth), A) = copy_input(lq_compact, A) # do we ever need anything else
copy_input(::typeof(left_null), A) = copy_input(qr_null, A) # do we ever need anything else
copy_input(::typeof(right_null), A) = copy_input(lq_null, A) # do we ever need anything else

function check_input(::typeof(left_orth!), A::AbstractMatrix, VC)
m, n = size(A)
Expand Down Expand Up @@ -81,71 +81,113 @@

# Implementation of orth functions
# --------------------------------
function left_orth!(A::AbstractMatrix, VC; trunc=nothing,
function left_orth!(A, VC; trunc=nothing,
kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true),
alg_polar=(;), alg_svd=(;))
check_input(left_orth!, A, VC)
if !isnothing(trunc) && kind != :svd
throw(ArgumentError("truncation not supported for left_orth with kind=$kind"))
end
if kind == :qr
alg_qr′ = select_algorithm(qr_compact!, A, alg_qr)
return qr_compact!(A, VC, alg_qr′)
return left_orth_qr!(A, VC, alg_qr)
elseif kind == :polar
size(A, 1) >= size(A, 2) ||
throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`"))
alg_polar′ = select_algorithm(left_polar!, A, alg_polar)
return left_polar!(A, VC, alg_polar′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
V, C = VC
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd′)
return U, lmul!(S, Vᴴ)
return left_orth_polar!(A, VC, alg_polar)
elseif kind == :svd
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc)
V, C = VC
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_svd_trunc)
return U, lmul!(S, Vᴴ)
return left_orth_svd!(A, VC, alg_svd, trunc)
else
throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`"))
end
end
function left_orth_qr!(A, VC, alg)
alg′ = select_algorithm(qr_compact!, A, alg)
return qr_compact!(A, VC, alg′)
end
function left_orth_polar!(A, VC, alg)
alg′ = select_algorithm(left_polar!, A, alg)
return left_polar!(A, VC, alg′)
end
function left_orth_svd!(A, VC, alg, trunc::Nothing=nothing)
alg′ = select_algorithm(svd_compact!, A, alg)
U, S, Vᴴ = svd_compact!(A, alg′)
V, C = VC
return copy!(V, U), mul!(C, S, Vᴴ)
end
function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc::Nothing=nothing)
alg′ = select_algorithm(svd_compact!, A, alg)
V, C = VC
S = Diagonal(initialize_output(svd_vals!, A, alg′))
Copy link
Member

@Jutho Jutho May 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am open to omitting the AbstractMatrix restriction above, but then his part of the implementation (namely the use of Diagonal) is again very much (Abstract)Matrix specific, whereas left_orth_qr! and left_orth_polar! are quite generic and probably work for other types as well.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a more general implementation can be added that is less economic with its memory usage:

function left_orth_svd!(A, VC, alg, trunc::Nothing=nothing)
    alg′ = select_algorithm(svd_compact!, A, alg)
    U, S, Vᴴ = svd_compact!(A, alg′)
    V, C = VC
    return copy!(V, U), mul!(C, S, Vᴴ)
end

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I wasn't being careful about the type restriction in this case. I think that's a good suggestion to have a more generic version.

Another thing I can think of is replacing Diagonal with diagonal, and then diagonal could be an interface function that types can overload to construct a diagonal matrix-like object. But I like your suggestion better.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, internally there is already the diagview to have a uniform interface to extracting a view of the diagonal of both Matrix and Diagonal, but that could also be specialized for other types. So a type-agnostic construction of a diagonal matrix could also be useful. With diagonal as name, it might get confusing though, as LinearAlgebra.diag is exactly the opposite, namely extracting the diagonal of a matrix and returning it as a vector. But I wouldn't have another naming suggestion.

I am also not sure if we generally want that the output type of svd_vals can be used to construct the S output of svd_compact. In TensorKit tensors, for example, a DiagonalTensorMap would still store all the singular values in a single list (Vector), but with internal structure such that it is known which parts of this vector are associated with which sectors/quantum numbers. svd_vals would than rather return that information as a Dict where for every sector (being the keys into the dict) there is a separate vector with only the singular values for that sector.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it definitely would require some work to make the design based on diagview/diagonal work for exotic types like that.

I've updated the code to add a non-AbstractMatrix codepath based on your suggestion above. I also added a test that defines a simple non-AbstractMatrix to test that codepath.

U, S, Vᴴ = svd_compact!(A, (V, S, C), alg′)
return U, lmul!(S, Vᴴ)
end
function left_orth_svd!(A, VC, alg, trunc)
alg′ = select_algorithm(svd_compact!, A, alg)
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
U, S, Vᴴ = svd_trunc!(A, alg_trunc)
V, C = VC
return copy!(V, U), mul!(C, S, Vᴴ)

Check warning on line 127 in src/implementations/orthnull.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/orthnull.jl#L122-L127

Added lines #L122 - L127 were not covered by tests
end
function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc)
alg′ = select_algorithm(svd_compact!, A, alg)
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
V, C = VC
S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg))
U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_trunc)
return U, lmul!(S, Vᴴ)
end

function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing,
function right_orth!(A, CVᴴ; trunc=nothing,
kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true),
alg_polar=(;), alg_svd=(;))
check_input(right_orth!, A, CVᴴ)
if !isnothing(trunc) && kind != :svd
throw(ArgumentError("truncation not supported for right_orth with kind=$kind"))
end
if kind == :lq
alg_lq′ = select_algorithm(lq_compact!, A, alg_lq)
return lq_compact!(A, CVᴴ, alg_lq′)
return right_orth_lq!(A, CVᴴ, alg_lq)
elseif kind == :polar
size(A, 2) >= size(A, 1) ||
throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`"))
alg_polar′ = select_algorithm(right_polar!, A, alg_polar)
return right_polar!(A, CVᴴ, alg_polar′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
C, Vᴴ = CVᴴ
S = Diagonal(initialize_output(svd_vals!, A, alg_svd′))
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′)
return rmul!(U, S), Vᴴ
return right_orth_polar!(A, CVᴴ, alg_polar)
elseif kind == :svd
alg_svd′ = select_algorithm(svd_compact!, A, alg_svd)
alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc)
C, Vᴴ = CVᴴ
S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg))
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_svd_trunc)
return rmul!(U, S), Vᴴ
return right_orth_svd!(A, CVᴴ, alg_svd, trunc)
else
throw(ArgumentError("`right_orth!` received unknown value `kind = $kind`"))
end
end
function right_orth_lq!(A, CVᴴ, alg)
alg′ = select_algorithm(lq_compact!, A, alg)
return lq_compact!(A, CVᴴ, alg′)
end
function right_orth_polar!(A, CVᴴ, alg)
alg′ = select_algorithm(right_polar!, A, alg)
return right_polar!(A, CVᴴ, alg′)
end
function right_orth_svd!(A, CVᴴ, alg, trunc::Nothing=nothing)
alg′ = select_algorithm(svd_compact!, A, alg)
U, S, Vᴴ′ = svd_compact!(A, alg′)
C, Vᴴ = CVᴴ
return mul!(C, U, S), copy!(Vᴴ, Vᴴ′)
end
function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc::Nothing=nothing)
alg′ = select_algorithm(svd_compact!, A, alg)
C, Vᴴ = CVᴴ
S = Diagonal(initialize_output(svd_vals!, A, alg′))
U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg′)
return rmul!(U, S), Vᴴ
end
function right_orth_svd!(A, CVᴴ, alg, trunc)
alg′ = select_algorithm(svd_compact!, A, alg)
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
U, S, Vᴴ′ = svd_trunc!(A, alg_trunc)
C, Vᴴ = CVᴴ
return mul!(C, U, S), copy!(Vᴴ, Vᴴ′)

Check warning on line 181 in src/implementations/orthnull.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/orthnull.jl#L176-L181

Added lines #L176 - L181 were not covered by tests
end
function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc)
alg′ = select_algorithm(svd_compact!, A, alg)
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
C, Vᴴ = CVᴴ
S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg))
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_trunc)
return rmul!(U, S), Vᴴ
end

# Implementation of null functions
# --------------------------------
Expand All @@ -159,56 +201,70 @@
return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev=false) : trunc
end

function left_null!(A::AbstractMatrix, N; trunc=nothing,
function left_null!(A, N; trunc=nothing,
kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true),
alg_svd=(;))
check_input(left_null!, A, N)
if !isnothing(trunc) && kind != :svd
throw(ArgumentError("truncation not supported for left_null with kind=$kind"))
end
if kind == :qr
alg_qr′ = select_algorithm(qr_null!, A, alg_qr)
return qr_null!(A, N, alg_qr′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
U, _, _ = svd_full!(A, alg_svd′)
(m, n) = size(A)
return copy!(N, view(U, 1:m, (n + 1):m))
left_null_qr!(A, N, alg_qr)
elseif kind == :svd
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
U, S, _ = svd_full!(A, alg_svd′)
trunc′ = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return truncate!(left_null!, (U, S), trunc′)
left_null_svd!(A, N, alg_svd, trunc)
else
throw(ArgumentError("`left_null!` received unknown value `kind = $kind`"))
end
end
function left_null_qr!(A, N, alg)
alg′ = select_algorithm(qr_null!, A, alg)
return qr_null!(A, N, alg′)
end
function left_null_svd!(A, N, alg, trunc::Nothing=nothing)
alg′ = select_algorithm(svd_full!, A, alg)
U, _, _ = svd_full!(A, alg′)
(m, n) = size(A)
return copy!(N, view(U, 1:m, (n + 1):m))
end
function left_null_svd!(A, N, alg, trunc)
alg′ = select_algorithm(svd_full!, A, alg)
U, S, _ = svd_full!(A, alg′)
trunc′ = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return truncate!(left_null!, (U, S), trunc′)
end

function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing,
function right_null!(A, Nᴴ; trunc=nothing,
kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true),
alg_svd=(;))
check_input(right_null!, A, Nᴴ)
if !isnothing(trunc) && kind != :svd
throw(ArgumentError("truncation not supported for right_null with kind=$kind"))
end
if kind == :lq
alg_lq′ = select_algorithm(lq_null!, A, alg_lq)
return lq_null!(A, Nᴴ, alg_lq′)
elseif kind == :svd && isnothing(trunc)
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
_, _, Vᴴ = svd_full!(A, alg_svd′)
(m, n) = size(A)
return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n))
return right_null_lq!(A, Nᴴ, alg_lq)
elseif kind == :svd
alg_svd′ = select_algorithm(svd_full!, A, alg_svd)
_, S, Vᴴ = svd_full!(A, alg_svd′)
trunc′ = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return truncate!(right_null!, (S, Vᴴ), trunc′)
return right_null_svd!(A, Nᴴ, alg_svd, trunc)
else
throw(ArgumentError("`right_null!` received unknown value `kind = $kind`"))
end
end
function right_null_lq!(A, Nᴴ, alg)
alg′ = select_algorithm(lq_null!, A, alg)
return lq_null!(A, Nᴴ, alg′)
end
function right_null_svd!(A, Nᴴ, alg, trunc::Nothing=nothing)
alg′ = select_algorithm(svd_full!, A, alg)
_, _, Vᴴ = svd_full!(A, alg′)
(m, n) = size(A)
return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n))
end
function right_null_svd!(A, Nᴴ, alg, trunc)
alg′ = select_algorithm(svd_full!, A, alg)
_, S, Vᴴ = svd_full!(A, alg′)
trunc′ = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return truncate!(right_null!, (S, Vᴴ), trunc′)
end
16 changes: 8 additions & 8 deletions src/interface/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), [
"""
function left_orth end
function left_orth! end
function left_orth!(A::AbstractMatrix; kwargs...)
function left_orth!(A; kwargs...)
return left_orth!(A, initialize_output(left_orth!, A); kwargs...)
end
function left_orth(A::AbstractMatrix; kwargs...)
function left_orth(A; kwargs...)
return left_orth!(copy_input(left_orth, A); kwargs...)
end

Expand Down Expand Up @@ -128,10 +128,10 @@ See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), [`r
"""
function right_orth end
function right_orth! end
function right_orth!(A::AbstractMatrix; kwargs...)
function right_orth!(A; kwargs...)
return right_orth!(A, initialize_output(right_orth!, A); kwargs...)
end
function right_orth(A::AbstractMatrix; kwargs...)
function right_orth(A; kwargs...)
return right_orth!(copy_input(right_orth, A); kwargs...)
end

Expand Down Expand Up @@ -180,10 +180,10 @@ See also [`right_null(!)`](@ref right_null), [`left_orth(!)`](@ref left_orth), [
"""
function left_null end
function left_null! end
function left_null!(A::AbstractMatrix; kwargs...)
function left_null!(A; kwargs...)
return left_null!(A, initialize_output(left_null!, A); kwargs...)
end
function left_null(A::AbstractMatrix; kwargs...)
function left_null(A; kwargs...)
return left_null!(copy_input(left_null, A); kwargs...)
end

Expand Down Expand Up @@ -230,9 +230,9 @@ See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth), [`r
"""
function right_null end
function right_null! end
function right_null!(A::AbstractMatrix; kwargs...)
function right_null!(A; kwargs...)
return right_null!(A, initialize_output(right_null!, A); kwargs...)
end
function right_null(A::AbstractMatrix; kwargs...)
function right_null(A; kwargs...)
return right_null!(copy_input(right_null, A); kwargs...)
end
55 changes: 54 additions & 1 deletion test/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,53 @@ using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using LinearAlgebra: LinearAlgebra, I
using LinearAlgebra: LinearAlgebra, I, mul!
using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow
using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_svd_algorithm,
initialize_output

# Used to test non-AbstractMatrix codepaths.
struct LinearMap{P<:AbstractMatrix}
parent::P
end
Base.parent(A::LinearMap) = getfield(A, :parent)
function Base.copy!(dest::LinearMap, src::LinearMap)
copy!(parent(dest), parent(src))
return dest
end
function LinearAlgebra.mul!(C::LinearMap, A::LinearMap, B::LinearMap)
mul!(parent(C), parent(A), parent(B))
return C
end

function MatrixAlgebraKit.copy_input(::typeof(qr_compact), A::LinearMap)
return LinearMap(copy_input(qr_compact, parent(A)))
end
function MatrixAlgebraKit.copy_input(::typeof(lq_compact), A::LinearMap)
return LinearMap(copy_input(lq_compact, parent(A)))
end
function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), A::LinearMap)
return LinearMap.(initialize_output(left_orth!, parent(A)))
end
function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), A::LinearMap)
return LinearMap.(initialize_output(right_orth!, parent(A)))
end
function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC)
return check_input(left_orth!, parent(A), parent.(VC))
end
function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC)
return check_input(right_orth!, parent(A), parent.(VC))
end
function MatrixAlgebraKit.default_svd_algorithm(A::LinearMap)
return default_svd_algorithm(parent(A))
end
function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), A::LinearMap,
alg::LAPACK_SVDAlgorithm)
return LinearMap.(initialize_output(svd_compact!, parent(A), alg))
end
function MatrixAlgebraKit.svd_compact!(A::LinearMap, USVᴴ, alg::LAPACK_SVDAlgorithm)
return LinearMap.(svd_compact!(parent(A), parent.(USVᴴ), alg))
end

@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32,
ComplexF64)
Expand All @@ -23,6 +68,10 @@ using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow
@test N' * N ≈ I
@test V * V' + N * N' ≈ I

M = LinearMap(A)
VM, CM = @constinferred left_orth(M; kind=:svd)
@test parent(VM) * parent(CM) ≈ A

if m > n
nullity = 5
V, C = @constinferred left_orth(A)
Expand Down Expand Up @@ -162,6 +211,10 @@ end
@test Nᴴ * Nᴴ' ≈ I
@test Vᴴ' * Vᴴ + Nᴴ' * Nᴴ ≈ I

M = LinearMap(A)
CM, VMᴴ = @constinferred right_orth(M; kind=:svd)
@test parent(CM) * parent(VMᴴ) ≈ A

Ac = similar(A)
C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ))
Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ)
Expand Down
Loading