diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 47626e12..f53055db 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -1,9 +1,9 @@ # Inputs # ------ -copy_input(::typeof(left_orth), A::AbstractMatrix) = copy_input(qr_compact, A) # do we ever need anything else -copy_input(::typeof(right_orth), A::AbstractMatrix) = copy_input(lq_compact, A) # do we ever need anything else -copy_input(::typeof(left_null), A::AbstractMatrix) = copy_input(qr_null, A) # do we ever need anything else -copy_input(::typeof(right_null), A::AbstractMatrix) = copy_input(lq_null, A) # do we ever need anything else +copy_input(::typeof(left_orth), A) = copy_input(qr_compact, A) # do we ever need anything else +copy_input(::typeof(right_orth), A) = copy_input(lq_compact, A) # do we ever need anything else +copy_input(::typeof(left_null), A) = copy_input(qr_null, A) # do we ever need anything else +copy_input(::typeof(right_null), A) = copy_input(lq_null, A) # do we ever need anything else function check_input(::typeof(left_orth!), A::AbstractMatrix, VC) m, n = size(A) @@ -81,7 +81,7 @@ end # Implementation of orth functions # -------------------------------- -function left_orth!(A::AbstractMatrix, VC; trunc=nothing, +function left_orth!(A, VC; trunc=nothing, kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true), alg_polar=(;), alg_svd=(;)) check_input(left_orth!, A, VC) @@ -89,32 +89,53 @@ function left_orth!(A::AbstractMatrix, VC; trunc=nothing, throw(ArgumentError("truncation not supported for left_orth with kind=$kind")) end if kind == :qr - alg_qr′ = select_algorithm(qr_compact!, A, alg_qr) - return qr_compact!(A, VC, alg_qr′) + return left_orth_qr!(A, VC, alg_qr) elseif kind == :polar - size(A, 1) >= size(A, 2) || - throw(ArgumentError("`left_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m >= n`")) - alg_polar′ = select_algorithm(left_polar!, A, alg_polar) - return left_polar!(A, VC, alg_polar′) - elseif kind == :svd && isnothing(trunc) - alg_svd′ = select_algorithm(svd_compact!, A, alg_svd) - V, C = VC - S = Diagonal(initialize_output(svd_vals!, A, alg_svd′)) - U, S, Vᴴ = svd_compact!(A, (V, S, C), alg_svd′) - return U, lmul!(S, Vᴴ) + return left_orth_polar!(A, VC, alg_polar) elseif kind == :svd - alg_svd′ = select_algorithm(svd_compact!, A, alg_svd) - alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc) - V, C = VC - S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg)) - U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_svd_trunc) - return U, lmul!(S, Vᴴ) + return left_orth_svd!(A, VC, alg_svd, trunc) else throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`")) end end +function left_orth_qr!(A, VC, alg) + alg′ = select_algorithm(qr_compact!, A, alg) + return qr_compact!(A, VC, alg′) +end +function left_orth_polar!(A, VC, alg) + alg′ = select_algorithm(left_polar!, A, alg) + return left_polar!(A, VC, alg′) +end +function left_orth_svd!(A, VC, alg, trunc::Nothing=nothing) + alg′ = select_algorithm(svd_compact!, A, alg) + U, S, Vᴴ = svd_compact!(A, alg′) + V, C = VC + return copy!(V, U), mul!(C, S, Vᴴ) +end +function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc::Nothing=nothing) + alg′ = select_algorithm(svd_compact!, A, alg) + V, C = VC + S = Diagonal(initialize_output(svd_vals!, A, alg′)) + U, S, Vᴴ = svd_compact!(A, (V, S, C), alg′) + return U, lmul!(S, Vᴴ) +end +function left_orth_svd!(A, VC, alg, trunc) + alg′ = select_algorithm(svd_compact!, A, alg) + alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) + U, S, Vᴴ = svd_trunc!(A, alg_trunc) + V, C = VC + return copy!(V, U), mul!(C, S, Vᴴ) +end +function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc) + alg′ = select_algorithm(svd_compact!, A, alg) + alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) + V, C = VC + S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg)) + U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_trunc) + return U, lmul!(S, Vᴴ) +end -function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing, +function right_orth!(A, CVᴴ; trunc=nothing, kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true), alg_polar=(;), alg_svd=(;)) check_input(right_orth!, A, CVᴴ) @@ -122,30 +143,51 @@ function right_orth!(A::AbstractMatrix, CVᴴ; trunc=nothing, throw(ArgumentError("truncation not supported for right_orth with kind=$kind")) end if kind == :lq - alg_lq′ = select_algorithm(lq_compact!, A, alg_lq) - return lq_compact!(A, CVᴴ, alg_lq′) + return right_orth_lq!(A, CVᴴ, alg_lq) elseif kind == :polar - size(A, 2) >= size(A, 1) || - throw(ArgumentError("`right_orth!` with `kind = :polar` only possible for `(m, n)` matrix with `m <= n`")) - alg_polar′ = select_algorithm(right_polar!, A, alg_polar) - return right_polar!(A, CVᴴ, alg_polar′) - elseif kind == :svd && isnothing(trunc) - alg_svd′ = select_algorithm(svd_compact!, A, alg_svd) - C, Vᴴ = CVᴴ - S = Diagonal(initialize_output(svd_vals!, A, alg_svd′)) - U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg_svd′) - return rmul!(U, S), Vᴴ + return right_orth_polar!(A, CVᴴ, alg_polar) elseif kind == :svd - alg_svd′ = select_algorithm(svd_compact!, A, alg_svd) - alg_svd_trunc = select_algorithm(svd_trunc!, A, alg_svd′; trunc) - C, Vᴴ = CVᴴ - S = Diagonal(initialize_output(svd_vals!, A, alg_svd_trunc.alg)) - U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_svd_trunc) - return rmul!(U, S), Vᴴ + return right_orth_svd!(A, CVᴴ, alg_svd, trunc) else throw(ArgumentError("`right_orth!` received unknown value `kind = $kind`")) end end +function right_orth_lq!(A, CVᴴ, alg) + alg′ = select_algorithm(lq_compact!, A, alg) + return lq_compact!(A, CVᴴ, alg′) +end +function right_orth_polar!(A, CVᴴ, alg) + alg′ = select_algorithm(right_polar!, A, alg) + return right_polar!(A, CVᴴ, alg′) +end +function right_orth_svd!(A, CVᴴ, alg, trunc::Nothing=nothing) + alg′ = select_algorithm(svd_compact!, A, alg) + U, S, Vᴴ′ = svd_compact!(A, alg′) + C, Vᴴ = CVᴴ + return mul!(C, U, S), copy!(Vᴴ, Vᴴ′) +end +function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc::Nothing=nothing) + alg′ = select_algorithm(svd_compact!, A, alg) + C, Vᴴ = CVᴴ + S = Diagonal(initialize_output(svd_vals!, A, alg′)) + U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg′) + return rmul!(U, S), Vᴴ +end +function right_orth_svd!(A, CVᴴ, alg, trunc) + alg′ = select_algorithm(svd_compact!, A, alg) + alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) + U, S, Vᴴ′ = svd_trunc!(A, alg_trunc) + C, Vᴴ = CVᴴ + return mul!(C, U, S), copy!(Vᴴ, Vᴴ′) +end +function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc) + alg′ = select_algorithm(svd_compact!, A, alg) + alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) + C, Vᴴ = CVᴴ + S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg)) + U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_trunc) + return rmul!(U, S), Vᴴ +end # Implementation of null functions # -------------------------------- @@ -159,7 +201,7 @@ function null_truncation_strategy(; atol=nothing, rtol=nothing, maxnullity=nothi return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev=false) : trunc end -function left_null!(A::AbstractMatrix, N; trunc=nothing, +function left_null!(A, N; trunc=nothing, kind=isnothing(trunc) ? :qr : :svd, alg_qr=(; positive=true), alg_svd=(;)) check_input(left_null!, A, N) @@ -167,26 +209,33 @@ function left_null!(A::AbstractMatrix, N; trunc=nothing, throw(ArgumentError("truncation not supported for left_null with kind=$kind")) end if kind == :qr - alg_qr′ = select_algorithm(qr_null!, A, alg_qr) - return qr_null!(A, N, alg_qr′) - elseif kind == :svd && isnothing(trunc) - alg_svd′ = select_algorithm(svd_full!, A, alg_svd) - U, _, _ = svd_full!(A, alg_svd′) - (m, n) = size(A) - return copy!(N, view(U, 1:m, (n + 1):m)) + left_null_qr!(A, N, alg_qr) elseif kind == :svd - alg_svd′ = select_algorithm(svd_full!, A, alg_svd) - U, S, _ = svd_full!(A, alg_svd′) - trunc′ = trunc isa TruncationStrategy ? trunc : - trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : - throw(ArgumentError("Unknown truncation strategy: $trunc")) - return truncate!(left_null!, (U, S), trunc′) + left_null_svd!(A, N, alg_svd, trunc) else throw(ArgumentError("`left_null!` received unknown value `kind = $kind`")) end end +function left_null_qr!(A, N, alg) + alg′ = select_algorithm(qr_null!, A, alg) + return qr_null!(A, N, alg′) +end +function left_null_svd!(A, N, alg, trunc::Nothing=nothing) + alg′ = select_algorithm(svd_full!, A, alg) + U, _, _ = svd_full!(A, alg′) + (m, n) = size(A) + return copy!(N, view(U, 1:m, (n + 1):m)) +end +function left_null_svd!(A, N, alg, trunc) + alg′ = select_algorithm(svd_full!, A, alg) + U, S, _ = svd_full!(A, alg′) + trunc′ = trunc isa TruncationStrategy ? trunc : + trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : + throw(ArgumentError("Unknown truncation strategy: $trunc")) + return truncate!(left_null!, (U, S), trunc′) +end -function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing, +function right_null!(A, Nᴴ; trunc=nothing, kind=isnothing(trunc) ? :lq : :svd, alg_lq=(; positive=true), alg_svd=(;)) check_input(right_null!, A, Nᴴ) @@ -194,21 +243,28 @@ function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing, throw(ArgumentError("truncation not supported for right_null with kind=$kind")) end if kind == :lq - alg_lq′ = select_algorithm(lq_null!, A, alg_lq) - return lq_null!(A, Nᴴ, alg_lq′) - elseif kind == :svd && isnothing(trunc) - alg_svd′ = select_algorithm(svd_full!, A, alg_svd) - _, _, Vᴴ = svd_full!(A, alg_svd′) - (m, n) = size(A) - return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n)) + return right_null_lq!(A, Nᴴ, alg_lq) elseif kind == :svd - alg_svd′ = select_algorithm(svd_full!, A, alg_svd) - _, S, Vᴴ = svd_full!(A, alg_svd′) - trunc′ = trunc isa TruncationStrategy ? trunc : - trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : - throw(ArgumentError("Unknown truncation strategy: $trunc")) - return truncate!(right_null!, (S, Vᴴ), trunc′) + return right_null_svd!(A, Nᴴ, alg_svd, trunc) else throw(ArgumentError("`right_null!` received unknown value `kind = $kind`")) end end +function right_null_lq!(A, Nᴴ, alg) + alg′ = select_algorithm(lq_null!, A, alg) + return lq_null!(A, Nᴴ, alg′) +end +function right_null_svd!(A, Nᴴ, alg, trunc::Nothing=nothing) + alg′ = select_algorithm(svd_full!, A, alg) + _, _, Vᴴ = svd_full!(A, alg′) + (m, n) = size(A) + return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n)) +end +function right_null_svd!(A, Nᴴ, alg, trunc) + alg′ = select_algorithm(svd_full!, A, alg) + _, S, Vᴴ = svd_full!(A, alg′) + trunc′ = trunc isa TruncationStrategy ? trunc : + trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : + throw(ArgumentError("Unknown truncation strategy: $trunc")) + return truncate!(right_null!, (S, Vᴴ), trunc′) +end diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 4b00ec66..efc678ee 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -69,10 +69,10 @@ See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), [ """ function left_orth end function left_orth! end -function left_orth!(A::AbstractMatrix; kwargs...) +function left_orth!(A; kwargs...) return left_orth!(A, initialize_output(left_orth!, A); kwargs...) end -function left_orth(A::AbstractMatrix; kwargs...) +function left_orth(A; kwargs...) return left_orth!(copy_input(left_orth, A); kwargs...) end @@ -128,10 +128,10 @@ See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), [`r """ function right_orth end function right_orth! end -function right_orth!(A::AbstractMatrix; kwargs...) +function right_orth!(A; kwargs...) return right_orth!(A, initialize_output(right_orth!, A); kwargs...) end -function right_orth(A::AbstractMatrix; kwargs...) +function right_orth(A; kwargs...) return right_orth!(copy_input(right_orth, A); kwargs...) end @@ -180,10 +180,10 @@ See also [`right_null(!)`](@ref right_null), [`left_orth(!)`](@ref left_orth), [ """ function left_null end function left_null! end -function left_null!(A::AbstractMatrix; kwargs...) +function left_null!(A; kwargs...) return left_null!(A, initialize_output(left_null!, A); kwargs...) end -function left_null(A::AbstractMatrix; kwargs...) +function left_null(A; kwargs...) return left_null!(copy_input(left_null, A); kwargs...) end @@ -230,9 +230,9 @@ See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth), [`r """ function right_null end function right_null! end -function right_null!(A::AbstractMatrix; kwargs...) +function right_null!(A; kwargs...) return right_null!(A, initialize_output(right_null!, A); kwargs...) end -function right_null(A::AbstractMatrix; kwargs...) +function right_null(A; kwargs...) return right_null!(copy_input(right_null, A); kwargs...) end diff --git a/test/orthnull.jl b/test/orthnull.jl index 0823d3c5..7f497eaf 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -2,8 +2,53 @@ using MatrixAlgebraKit using Test using TestExtras using StableRNGs -using LinearAlgebra: LinearAlgebra, I +using LinearAlgebra: LinearAlgebra, I, mul! using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow +using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_svd_algorithm, + initialize_output + +# Used to test non-AbstractMatrix codepaths. +struct LinearMap{P<:AbstractMatrix} + parent::P +end +Base.parent(A::LinearMap) = getfield(A, :parent) +function Base.copy!(dest::LinearMap, src::LinearMap) + copy!(parent(dest), parent(src)) + return dest +end +function LinearAlgebra.mul!(C::LinearMap, A::LinearMap, B::LinearMap) + mul!(parent(C), parent(A), parent(B)) + return C +end + +function MatrixAlgebraKit.copy_input(::typeof(qr_compact), A::LinearMap) + return LinearMap(copy_input(qr_compact, parent(A))) +end +function MatrixAlgebraKit.copy_input(::typeof(lq_compact), A::LinearMap) + return LinearMap(copy_input(lq_compact, parent(A))) +end +function MatrixAlgebraKit.initialize_output(::typeof(left_orth!), A::LinearMap) + return LinearMap.(initialize_output(left_orth!, parent(A))) +end +function MatrixAlgebraKit.initialize_output(::typeof(right_orth!), A::LinearMap) + return LinearMap.(initialize_output(right_orth!, parent(A))) +end +function MatrixAlgebraKit.check_input(::typeof(left_orth!), A::LinearMap, VC) + return check_input(left_orth!, parent(A), parent.(VC)) +end +function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC) + return check_input(right_orth!, parent(A), parent.(VC)) +end +function MatrixAlgebraKit.default_svd_algorithm(A::LinearMap) + return default_svd_algorithm(parent(A)) +end +function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), A::LinearMap, + alg::LAPACK_SVDAlgorithm) + return LinearMap.(initialize_output(svd_compact!, parent(A), alg)) +end +function MatrixAlgebraKit.svd_compact!(A::LinearMap, USVᴴ, alg::LAPACK_SVDAlgorithm) + return LinearMap.(svd_compact!(parent(A), parent.(USVᴴ), alg)) +end @testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) @@ -23,6 +68,10 @@ using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow @test N' * N ≈ I @test V * V' + N * N' ≈ I + M = LinearMap(A) + VM, CM = @constinferred left_orth(M; kind=:svd) + @test parent(VM) * parent(CM) ≈ A + if m > n nullity = 5 V, C = @constinferred left_orth(A) @@ -162,6 +211,10 @@ end @test Nᴴ * Nᴴ' ≈ I @test Vᴴ' * Vᴴ + Nᴴ' * Nᴴ ≈ I + M = LinearMap(A) + CM, VMᴴ = @constinferred right_orth(M; kind=:svd) + @test parent(CM) * parent(VMᴴ) ≈ A + Ac = similar(A) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ)