From 40276bffdd46a16ec13a7d258dd12d4a18f86a85 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 16 May 2025 12:46:51 -0400 Subject: [PATCH] Allow customizing `truncate!` based on the algorithm --- src/implementations/eig.jl | 2 +- src/implementations/eigh.jl | 2 +- src/implementations/orthnull.jl | 4 +-- src/implementations/svd.jl | 2 +- src/implementations/truncation.jl | 43 +++++++++++++++++++++++-------- 5 files changed, 37 insertions(+), 16 deletions(-) diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 049438a7..f882a94c 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -84,5 +84,5 @@ end function eig_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm) D, V = eig_full!(A, DV, alg.alg) - return truncate!(eig_trunc!, (D, V), alg.trunc) + return truncate!(eig_trunc!, (D, V), alg) end diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 53837ae5..a47319f7 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -86,5 +86,5 @@ end function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm) D, V = eigh_full!(A, DV, alg.alg) - return truncate!(eigh_trunc!, (D, V), alg.trunc) + return truncate!(eigh_trunc!, (D, V), alg) end diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 47626e12..8efe8358 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -180,7 +180,7 @@ function left_null!(A::AbstractMatrix, N; trunc=nothing, trunc′ = trunc isa TruncationStrategy ? trunc : trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : throw(ArgumentError("Unknown truncation strategy: $trunc")) - return truncate!(left_null!, (U, S), trunc′) + return truncate!(left_null!, (U, S), TruncatedAlgorithm(alg_svd′, trunc′)) else throw(ArgumentError("`left_null!` received unknown value `kind = $kind`")) end @@ -207,7 +207,7 @@ function right_null!(A::AbstractMatrix, Nᴴ; trunc=nothing, trunc′ = trunc isa TruncationStrategy ? trunc : trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : throw(ArgumentError("Unknown truncation strategy: $trunc")) - return truncate!(right_null!, (S, Vᴴ), trunc′) + return truncate!(right_null!, (S, Vᴴ), TruncatedAlgorithm(alg_svd′, trunc′)) else throw(ArgumentError("`right_null!` received unknown value `kind = $kind`")) end diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 17c621aa..d9476e92 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -170,5 +170,5 @@ end function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm) USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg) - return truncate!(svd_trunc!, USVᴴ′, alg.trunc) + return truncate!(svd_trunc!, USVᴴ′, alg) end diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 1898a010..b34963df 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -127,33 +127,65 @@ function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationIntersection) return TruncationIntersection((trunc1, trunc2.components...)) end +""" + TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm) + +Generic wrapper type for algorithms that consist of first using `alg`, followed by a +truncation through `trunc`. +""" +struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm + alg::A + trunc::T +end + # truncate! # --------- # Generic implementation: `findtruncated` followed by indexing @doc """ truncate!(f, out, strategy::TruncationStrategy) + truncate!(f, out, alg::AbstractAlgorithm) Generic interface for post-truncating a decomposition, specified in `out`. """ truncate! + # TODO: should we return a view? +function truncate!(::typeof(svd_trunc!), USVᴴ, alg::TruncatedAlgorithm) + return truncate!(svd_trunc!, USVᴴ, alg.trunc) +end function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy) ind = findtruncated(diagview(S), strategy) return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :] end + +function truncate!(::typeof(eig_trunc!), DV, alg::TruncatedAlgorithm) + return truncate!(eig_trunc!, DV, alg.trunc) +end function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy) ind = findtruncated(diagview(D), strategy) return Diagonal(diagview(D)[ind]), V[:, ind] end + +function truncate!(::typeof(eigh_trunc!), DV, alg::TruncatedAlgorithm) + return truncate!(eigh_trunc!, DV, alg.trunc) +end function truncate!(::typeof(eigh_trunc!), (D, V), strategy::TruncationStrategy) ind = findtruncated(diagview(D), strategy) return Diagonal(diagview(D)[ind]), V[:, ind] end + +function truncate!(::typeof(left_null!), US, alg::TruncatedAlgorithm) + return truncate!(left_null!, US, alg.trunc) +end function truncate!(::typeof(left_null!), (U, S), strategy::TruncationStrategy) # TODO: avoid allocation? extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2)))) ind = findtruncated(extended_S, strategy) return U[:, ind] end + +function truncate!(::typeof(right_null!), SVᴴ, alg::TruncatedAlgorithm) + return truncate!(right_null!, SVᴴ, alg.trunc) +end function truncate!(::typeof(right_null!), (S, Vᴴ), strategy::TruncationStrategy) # TODO: avoid allocation? extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 2) - size(S, 1)))) @@ -196,14 +228,3 @@ function findtruncated(values::AbstractVector, strategy::TruncationIntersection) inds = map(Base.Fix1(findtruncated, values), strategy.components) return intersect(inds...) end - -""" - TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm) - -Generic wrapper type for algorithms that consist of first using `alg`, followed by a -truncation through `trunc`. -""" -struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm - alg::A - trunc::T -end