diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 049438a7..f882a94c 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -84,5 +84,5 @@ end function eig_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm) D, V = eig_full!(A, DV, alg.alg) - return truncate!(eig_trunc!, (D, V), alg.trunc) + return truncate!(eig_trunc!, (D, V), alg) end diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 53837ae5..a47319f7 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -86,5 +86,5 @@ end function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm) D, V = eigh_full!(A, DV, alg.alg) - return truncate!(eigh_trunc!, (D, V), alg.trunc) + return truncate!(eigh_trunc!, (D, V), alg) end diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index f53055db..44bed5e7 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -232,7 +232,7 @@ function left_null_svd!(A, N, alg, trunc) trunc′ = trunc isa TruncationStrategy ? trunc : trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : throw(ArgumentError("Unknown truncation strategy: $trunc")) - return truncate!(left_null!, (U, S), trunc′) + return truncate!(left_null!, (U, S), TruncatedAlgorithm(alg′, trunc′)) end function right_null!(A, Nᴴ; trunc=nothing, @@ -266,5 +266,5 @@ function right_null_svd!(A, Nᴴ, alg, trunc) trunc′ = trunc isa TruncationStrategy ? trunc : trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : throw(ArgumentError("Unknown truncation strategy: $trunc")) - return truncate!(right_null!, (S, Vᴴ), trunc′) + return truncate!(right_null!, (S, Vᴴ), TruncatedAlgorithm(alg′, trunc′)) end diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 17c621aa..d9476e92 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -170,5 +170,5 @@ end function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm) USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg) - return truncate!(svd_trunc!, USVᴴ′, alg.trunc) + return truncate!(svd_trunc!, USVᴴ′, alg) end diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 1898a010..b34963df 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -127,33 +127,65 @@ function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationIntersection) return TruncationIntersection((trunc1, trunc2.components...)) end +""" + TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm) + +Generic wrapper type for algorithms that consist of first using `alg`, followed by a +truncation through `trunc`. +""" +struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm + alg::A + trunc::T +end + # truncate! # --------- # Generic implementation: `findtruncated` followed by indexing @doc """ truncate!(f, out, strategy::TruncationStrategy) + truncate!(f, out, alg::AbstractAlgorithm) Generic interface for post-truncating a decomposition, specified in `out`. """ truncate! + # TODO: should we return a view? +function truncate!(::typeof(svd_trunc!), USVᴴ, alg::TruncatedAlgorithm) + return truncate!(svd_trunc!, USVᴴ, alg.trunc) +end function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy) ind = findtruncated(diagview(S), strategy) return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :] end + +function truncate!(::typeof(eig_trunc!), DV, alg::TruncatedAlgorithm) + return truncate!(eig_trunc!, DV, alg.trunc) +end function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy) ind = findtruncated(diagview(D), strategy) return Diagonal(diagview(D)[ind]), V[:, ind] end + +function truncate!(::typeof(eigh_trunc!), DV, alg::TruncatedAlgorithm) + return truncate!(eigh_trunc!, DV, alg.trunc) +end function truncate!(::typeof(eigh_trunc!), (D, V), strategy::TruncationStrategy) ind = findtruncated(diagview(D), strategy) return Diagonal(diagview(D)[ind]), V[:, ind] end + +function truncate!(::typeof(left_null!), US, alg::TruncatedAlgorithm) + return truncate!(left_null!, US, alg.trunc) +end function truncate!(::typeof(left_null!), (U, S), strategy::TruncationStrategy) # TODO: avoid allocation? extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2)))) ind = findtruncated(extended_S, strategy) return U[:, ind] end + +function truncate!(::typeof(right_null!), SVᴴ, alg::TruncatedAlgorithm) + return truncate!(right_null!, SVᴴ, alg.trunc) +end function truncate!(::typeof(right_null!), (S, Vᴴ), strategy::TruncationStrategy) # TODO: avoid allocation? extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 2) - size(S, 1)))) @@ -196,14 +228,3 @@ function findtruncated(values::AbstractVector, strategy::TruncationIntersection) inds = map(Base.Fix1(findtruncated, values), strategy.components) return intersect(inds...) end - -""" - TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm) - -Generic wrapper type for algorithms that consist of first using `alg`, followed by a -truncation through `trunc`. -""" -struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm - alg::A - trunc::T -end