Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,5 @@ end

function eig_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
D, V = eig_full!(A, DV, alg.alg)
return truncate!(eig_trunc!, (D, V), alg.trunc)
return truncate!(eig_trunc!, (D, V), alg)
end
2 changes: 1 addition & 1 deletion src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,5 @@ end

function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
D, V = eigh_full!(A, DV, alg.alg)
return truncate!(eigh_trunc!, (D, V), alg.trunc)
return truncate!(eigh_trunc!, (D, V), alg)
end
4 changes: 2 additions & 2 deletions src/implementations/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ function left_null_svd!(A, N, alg, trunc)
trunc′ = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return truncate!(left_null!, (U, S), trunc′)
return truncate!(left_null!, (U, S), TruncatedAlgorithm(alg′, trunc′))
end

function right_null!(A, Nᴴ; trunc=nothing,
Expand Down Expand Up @@ -266,5 +266,5 @@ function right_null_svd!(A, Nᴴ, alg, trunc)
trunc′ = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return truncate!(right_null!, (S, Vᴴ), trunc′)
return truncate!(right_null!, (S, Vᴴ), TruncatedAlgorithm(alg′, trunc′))
end
2 changes: 1 addition & 1 deletion src/implementations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,5 +170,5 @@ end

function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm)
USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg)
return truncate!(svd_trunc!, USVᴴ′, alg.trunc)
return truncate!(svd_trunc!, USVᴴ′, alg)
end
43 changes: 32 additions & 11 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,33 +127,65 @@ function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationIntersection)
return TruncationIntersection((trunc1, trunc2.components...))
end

"""
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)

Generic wrapper type for algorithms that consist of first using `alg`, followed by a
truncation through `trunc`.
"""
struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm
alg::A
trunc::T
end

# truncate!
# ---------
# Generic implementation: `findtruncated` followed by indexing
@doc """
truncate!(f, out, strategy::TruncationStrategy)
truncate!(f, out, alg::AbstractAlgorithm)

Generic interface for post-truncating a decomposition, specified in `out`.
""" truncate!

# TODO: should we return a view?
function truncate!(::typeof(svd_trunc!), USVᴴ, alg::TruncatedAlgorithm)
return truncate!(svd_trunc!, USVᴴ, alg.trunc)
end
function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy)
ind = findtruncated(diagview(S), strategy)
return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]
end

function truncate!(::typeof(eig_trunc!), DV, alg::TruncatedAlgorithm)
return truncate!(eig_trunc!, DV, alg.trunc)
end
function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy)
ind = findtruncated(diagview(D), strategy)
return Diagonal(diagview(D)[ind]), V[:, ind]
end

function truncate!(::typeof(eigh_trunc!), DV, alg::TruncatedAlgorithm)
return truncate!(eigh_trunc!, DV, alg.trunc)
end
function truncate!(::typeof(eigh_trunc!), (D, V), strategy::TruncationStrategy)
ind = findtruncated(diagview(D), strategy)
return Diagonal(diagview(D)[ind]), V[:, ind]
end

function truncate!(::typeof(left_null!), US, alg::TruncatedAlgorithm)
return truncate!(left_null!, US, alg.trunc)
end
function truncate!(::typeof(left_null!), (U, S), strategy::TruncationStrategy)
# TODO: avoid allocation?
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
ind = findtruncated(extended_S, strategy)
return U[:, ind]
end

function truncate!(::typeof(right_null!), SVᴴ, alg::TruncatedAlgorithm)
return truncate!(right_null!, SVᴴ, alg.trunc)
end
function truncate!(::typeof(right_null!), (S, Vᴴ), strategy::TruncationStrategy)
# TODO: avoid allocation?
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 2) - size(S, 1))))
Expand Down Expand Up @@ -196,14 +228,3 @@ function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
inds = map(Base.Fix1(findtruncated, values), strategy.components)
return intersect(inds...)
end

"""
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)

Generic wrapper type for algorithms that consist of first using `alg`, followed by a
truncation through `trunc`.
"""
struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm
alg::A
trunc::T
end
Loading