@@ -127,33 +127,65 @@ function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationIntersection)
127127 return TruncationIntersection((trunc1, trunc2. components... ))
128128end
129129
130+ """
131+ TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
132+
133+ Generic wrapper type for algorithms that consist of first using `alg`, followed by a
134+ truncation through `trunc`.
135+ """
136+ struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm
137+ alg:: A
138+ trunc:: T
139+ end
140+
130141# truncate!
131142# ---------
132143# Generic implementation: `findtruncated` followed by indexing
133144@doc """
134145 truncate!(f, out, strategy::TruncationStrategy)
146+ truncate!(f, out, alg::AbstractAlgorithm)
135147
136148Generic interface for post-truncating a decomposition, specified in `out`.
137149""" truncate!
150+
138151# TODO : should we return a view?
152+ function truncate!(:: typeof (svd_trunc!), USVᴴ, alg:: TruncatedAlgorithm )
153+ return truncate!(svd_trunc!, USVᴴ, alg. trunc)
154+ end
139155function truncate!(:: typeof (svd_trunc!), (U, S, Vᴴ), strategy:: TruncationStrategy )
140156 ind = findtruncated(diagview(S), strategy)
141157 return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]
142158end
159+
160+ function truncate!(:: typeof (eig_trunc!), DV, alg:: TruncatedAlgorithm )
161+ return truncate!(eig_trunc!, DV, alg. trunc)
162+ end
143163function truncate!(:: typeof (eig_trunc!), (D, V), strategy:: TruncationStrategy )
144164 ind = findtruncated(diagview(D), strategy)
145165 return Diagonal(diagview(D)[ind]), V[:, ind]
146166end
167+
168+ function truncate!(:: typeof (eigh_trunc!), DV, alg:: TruncatedAlgorithm )
169+ return truncate!(eigh_trunc!, DV, alg. trunc)
170+ end
147171function truncate!(:: typeof (eigh_trunc!), (D, V), strategy:: TruncationStrategy )
148172 ind = findtruncated(diagview(D), strategy)
149173 return Diagonal(diagview(D)[ind]), V[:, ind]
150174end
175+
176+ function truncate!(:: typeof (left_null!), US, alg:: TruncatedAlgorithm )
177+ return truncate!(left_null!, US, alg. trunc)
178+ end
151179function truncate!(:: typeof (left_null!), (U, S), strategy:: TruncationStrategy )
152180 # TODO : avoid allocation?
153181 extended_S = vcat(diagview(S), zeros(eltype(S), max(0 , size(S, 1 ) - size(S, 2 ))))
154182 ind = findtruncated(extended_S, strategy)
155183 return U[:, ind]
156184end
185+
186+ function truncate!(:: typeof (right_null!), SVᴴ, alg:: TruncatedAlgorithm )
187+ return truncate!(right_null!, SVᴴ, alg. trunc)
188+ end
157189function truncate!(:: typeof (right_null!), (S, Vᴴ), strategy:: TruncationStrategy )
158190 # TODO : avoid allocation?
159191 extended_S = vcat(diagview(S), zeros(eltype(S), max(0 , size(S, 2 ) - size(S, 1 ))))
@@ -196,14 +228,3 @@ function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
196228 inds = map(Base. Fix1(findtruncated, values), strategy. components)
197229 return intersect(inds... )
198230end
199-
200- """
201- TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
202-
203- Generic wrapper type for algorithms that consist of first using `alg`, followed by a
204- truncation through `trunc`.
205- """
206- struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm
207- alg:: A
208- trunc:: T
209- end
0 commit comments