11# Strategies
22# ----------
3- """
4- notrunc()
5- """
6- notrunc () = NoTruncation ()
73
8- # deprecate
4+ # TODO : deprecate
95const TruncationScheme = TruncationStrategy
106
11- # TODO : add this to MatrixAlgebraKit
12- struct TruncationError{T<: Real } <: TruncationStrategy
13- ϵ:: T
14- p:: Real
15- end
16-
17- """
18- truncerr(epsilon, p)
197"""
20- truncerr (epsilon :: Real , p :: Real = 2 ) = TruncationError (epsilon, p )
8+ TruncationSpace(V::ElementarySpace, by::Function, rev::Bool )
219
22- struct TruncationSpace{S<: ElementarySpace } <: TruncationStrategy
10+ Truncation strategy to keep the first values for each sector when sorted according to `by` and `rev`,
11+ such that the resulting vector space is no greater than `V`.
12+
13+ See also [`truncspace`](@ref).
14+ """
15+ struct TruncationSpace{S<: ElementarySpace ,F} <: TruncationStrategy
2316 space:: S
17+ by:: F
18+ rev:: Bool
2419end
2520
2621"""
27- truncspace(space::ElementarySpace)
22+ truncspace(space::ElementarySpace; by=abs, rev::Bool=true )
2823
29- Truncation strategy to keep the first values such that the resulting space is the infimum of
30- the total space and the provided space .
24+ Truncation strategy to keep the first values for each sector when sorted according to `by` and `rev`,
25+ such that the resulting vector space is no greater than `V` .
3126"""
32- truncspace (space:: ElementarySpace ) = TruncationSpace (space)
27+ function truncspace (space:: ElementarySpace ; by= abs, rev:: Bool = true )
28+ isdual (space) && throw (ArgumentError (" resulting vector space is never dual" ))
29+ return TruncationSpace (space, by, rev)
30+ end
3331
34- # Truncation
35- # ----------
32+ # truncate!
33+ # ---------
3634function truncate! (:: typeof (svd_trunc!),
3735 (U, S, Vᴴ):: Tuple{AbstractTensorMap,AbstractTensorMap,AbstractTensorMap} ,
3836 strategy:: TruncationStrategy )
39- strategy == notrunc () && return (U, S, Vᴴ)
40- ind = findtruncated_sorted (diagview (S), strategy)
37+ ind = findtruncated_svd (diagview (S), strategy)
4138 V_truncated = spacetype (S)(c => length (I) for (c, I) in ind)
4239
4340 Ũ = similar (U, codomain (U) ← V_truncated)
6764function truncate! (:: typeof (left_null!),
6865 (U, S):: Tuple{AbstractTensorMap,AbstractTensorMap} ,
6966 strategy:: MatrixAlgebraKit.TruncationStrategy )
70- strategy == notrunc () && return (U, S)
7167 extended_S = SectorDict (c => vcat (diagview (b),
7268 zeros (eltype (b), max (0 , size (b, 2 ) - size (b, 1 ))))
7369 for (c, b) in blocks (S))
@@ -84,7 +80,6 @@ for f! in (:eig_trunc!, :eigh_trunc!)
8480 @eval function truncate! (:: typeof ($ f!),
8581 (D, V):: Tuple{AbstractTensorMap,AbstractTensorMap} ,
8682 strategy:: TruncationStrategy )
87- strategy == notrunc () && return (D, V)
8883 ind = findtruncated (diagview (D), strategy)
8984 V_truncated = spacetype (D)(c => length (I) for (c, I) in ind)
9085
109104# Find truncation
110105# ---------------
111106# auxiliary functions
112- rtol_to_atol (S, p, atol, rtol) = rtol > 0 ? max (atol, _norm (S, p) * rtol) : atol
107+ rtol_to_atol (S, p, atol, rtol) = rtol > 0 ? max (atol, TensorKit . _norm (S, p) * rtol) : atol
113108
114109function _compute_truncerr (Σdata, truncdim, p= 2 )
115110 I = keytype (Σdata)
@@ -138,99 +133,96 @@ function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}; by=identity,
138133 end
139134end
140135
141- # implementations
142- function findtruncated_sorted (S:: SectorDict , strategy:: TruncationStrategy )
143- return findtruncated (S, strategy)
136+ # findtruncated
137+ # -------------
138+ # Generic fallback
139+ function findtruncated_svd (values:: SectorDict , strategy:: TruncationStrategy )
140+ return findtruncated (values, strategy)
144141end
145142
146- function findtruncated_sorted (S:: SectorDict , strategy:: TruncationKeepAbove )
147- atol = rtol_to_atol (S, strategy. p, strategy. atol, strategy. rtol)
148- findtrunc = Base. Fix2 (findtruncated_sorted, truncbelow (atol))
149- return SectorDict (c => findtrunc (d) for (c, d) in S)
150- end
151- function findtruncated (S:: SectorDict , strategy:: TruncationKeepAbove )
152- atol = rtol_to_atol (S, strategy. p, strategy. atol, strategy. rtol)
153- findtrunc = Base. Fix2 (findtruncated, truncbelow (atol))
154- return SectorDict (c => findtrunc (d) for (c, d) in Sd)
155- end
156-
157- function findtruncated_sorted (S:: SectorDict , strategy:: TruncationKeepBelow )
158- atol = rtol_to_atol (S, strategy. p, strategy. atol, strategy. rtol)
159- findtrunc = Base. Fix2 (findtruncated_sorted, truncabove (atol))
160- return SectorDict (c => findtrunc (d) for (c, d) in Sd)
161- end
162- function findtruncated (S:: SectorDict , strategy:: TruncationKeepBelow )
163- atol = rtol_to_atol (S, strategy. p, strategy. atol, strategy. rtol)
164- findtrunc = Base. Fix2 (findtruncated, truncabove (atol))
165- return SectorDict (c => findtrunc (d) for (c, d) in Sd)
166- end
167-
168- function findtruncated_sorted (Sd:: SectorDict , strategy:: TruncationError )
169- I = keytype (Sd)
170- truncdim = SectorDict {I,Int} (c => length (d) for (c, d) in Sd)
171- while true
172- next = _findnexttruncvalue (Sd, truncdim)
173- isnothing (next) && break
174- σmin, cmin = next
175- truncdim[cmin] -= 1
176- err = _compute_truncerr (Sd, truncdim, strategy. p)
177- if err > strategy. ϵ
178- truncdim[cmin] += 1
179- break
180- end
181- if truncdim[cmin] == 0
182- delete! (truncdim, cmin)
183- end
184- end
185- return SectorDict {I,Base.OneTo{Int}} (c => Base. OneTo (d) for (c, d) in truncdim)
143+ function findtruncated (values:: SectorDict , :: NoTruncation )
144+ return SectorDict (c => Base. OneTo (length (b)) for (c, b) in values)
186145end
187146
188- function findtruncated_sorted (Sd:: SectorDict , strategy:: TruncationKeepSorted )
189- return findtruncated (Sd, strategy)
147+ function findtruncated (values:: SectorDict , strategy:: TruncationByOrder )
148+ perms = SectorDict (c => (sortperm (d; strategy. by, strategy. rev)) for (c, d) in values)
149+ values_sorted = SectorDict (c => d[perms[c]] for (c, d) in values)
150+ inds = findtruncated_svd (values_sorted, truncrank (strategy. howmany))
151+ return SectorDict (c => perms[c][I] for (c, I) in inds)
190152end
191- function findtruncated (Sd:: SectorDict , strategy:: TruncationKeepSorted )
192- permutations = SectorDict (c => (sortperm (d; strategy. by, strategy. rev))
193- for (c, d) in Sd)
194- Sd = SectorDict (c => sort (d; strategy. by, strategy. rev) for (c, d) in Sd)
195-
196- I = keytype (Sd)
197- truncdim = SectorDict {I,Int} (c => length (d) for (c, d) in Sd)
153+ function findtruncated_svd (values:: SectorDict , strategy:: TruncationByOrder )
154+ I = keytype (values)
155+ truncdim = SectorDict {I,Int} (c => length (d) for (c, d) in values)
198156 totaldim = sum (dim (c) * d for (c, d) in truncdim; init= 0 )
199157 while true
200- next = _findnexttruncvalue (Sd , truncdim; strategy. by, strategy. rev)
158+ next = _findnexttruncvalue (values , truncdim; strategy. by, strategy. rev)
201159 isnothing (next) && break
202160 _, cmin = next
203161 truncdim[cmin] -= 1
204162 totaldim -= dim (cmin)
205- if truncdim[cmin] == 0
206- delete! (truncdim, cmin)
207- end
163+ truncdim[cmin] == 0 && delete! (truncdim, cmin)
208164 totaldim <= strategy. howmany && break
209165 end
210- return SectorDict (c => permutations[c][ Base. OneTo (d)] for (c, d) in truncdim)
166+ return SectorDict (c => Base. OneTo (d) for (c, d) in truncdim)
211167end
212168
213- function findtruncated_sorted (Sd:: SectorDict , strategy:: TruncationSpace )
214- I = keytype (Sd)
215- return SectorDict {I,Base.OneTo{Int}} (c => Base. OneTo (min (length (d),
216- dim (strategy. space, c)))
217- for (c, d) in Sd)
169+ function findtruncated (values:: SectorDict , strategy:: TruncationByFilter )
170+ return SectorDict (c => findall (strategy. filter, d) for (c, d) in values)
171+ end
172+
173+ function findtruncated (values:: SectorDict , strategy:: TruncationByValue )
174+ atol = rtol_to_atol (values, strategy. p, strategy. atol, strategy. rtol)
175+ strategy′ = trunctol (; atol, strategy. by, strategy. keep_below)
176+ return SectorDict (c => findtruncated (d, strategy′) for (c, d) in values)
177+ end
178+ function findtruncated_svd (values:: SectorDict , strategy:: TruncationByValue )
179+ atol = rtol_to_atol (values, strategy. p, strategy. atol, strategy. rtol)
180+ strategy′ = trunctol (; atol, strategy. by, strategy. keep_below)
181+ return SectorDict (c => findtruncated_svd (d, strategy′) for (c, d) in values)
182+ end
183+
184+ function findtruncated (values:: SectorDict , strategy:: TruncationByError )
185+ perms = SectorDict (c => sortperm (d; by= abs, rev= true ) for (c, d) in values)
186+ values_sorted = SectorDict (c => d[perms[c]] for (c, d) in Sd)
187+ inds = findtruncated_svd (values_sorted, truncrank (strategy. howmany))
188+ return SectorDict (c => perms[c][I] for (c, I) in inds)
189+ end
190+ function findtruncated_svd (values:: SectorDict , strategy:: TruncationByError )
191+ I = keytype (values)
192+ truncdim = SectorDict {I,Int} (c => length (d) for (c, d) in values)
193+ by (c, v) = abs (v)^ strategy. p * dim (c)
194+ Nᵖ = sum (((c, v),) -> sum (Base. Fix1 (by, c), v), values)
195+ ϵᵖ = max (strategy. atol^ strategy. p, strategy. rtol^ strategy. p * Nᵖ)
196+ truncerrᵖ = zero (real (scalartype (valtype (values))))
197+ next = _findnexttruncvalue (values, truncdim)
198+ while ! isnothing (next)
199+ σmin, cmin = next
200+ truncerrᵖ += by (cmin, σmin)
201+ truncerrᵖ >= ϵᵖ && break
202+ (truncdim[cmin] -= 1 ) == 0 && delete! (truncdim, cmin)
203+ next = _findnexttruncvalue (values, truncdim)
204+ end
205+ return SectorDict {I,Base.OneTo{Int}} (c => Base. OneTo (d) for (c, d) in truncdim)
218206end
219207
220- function findtruncated_sorted (Sd:: SectorDict , strategy:: TruncationKeepFiltered )
221- return SectorDict (c => findtruncated_sorted (d, strategy) for (c, d) in Sd)
208+ function findtruncated (values:: SectorDict , strategy:: TruncationSpace )
209+ blockstrategy (c) = truncrank (dim (strategy. space, c); strategy. by, strategy. rev)
210+ return SectorDict (c => findtruncated (d, blockstrategy (c)) for (c, d) in values)
222211end
223- function findtruncated (Sd:: SectorDict , strategy:: TruncationKeepFiltered )
224- return SectorDict (c => findtruncated (d, strategy) for (c, d) in Sd)
212+ function findtruncated_svd (values:: SectorDict , strategy:: TruncationSpace )
213+ blockstrategy (c) = truncrank (dim (strategy. space, c); strategy. by, strategy. rev)
214+ return SectorDict (c => findtruncated_svd (d, blockstrategy (c)) for (c, d) in values)
225215end
226216
227- function findtruncated_sorted (Sd:: SectorDict , strategy:: TruncationIntersection )
228- inds = map (Base. Fix1 (findtruncated_sorted, Sd), strategy)
229- return SectorDict (c => intersect (map (Base. Fix2 (getindex, c), inds)... )
217+ function findtruncated (values:: SectorDict , strategy:: TruncationIntersection )
218+ inds = map (Base. Fix1 (findtruncated, values), strategy)
219+ return SectorDict (c => mapreduce (Base. Fix2 (getindex, c), _ind_intersect, inds;
220+ init= trues (length (values[c])))
230221 for c in intersect (map (keys, inds)... ))
231222end
232- function findtruncated (Sd:: SectorDict , strategy:: TruncationIntersection )
233- inds = map (Base. Fix1 (findtruncated, Sd), strategy)
234- return SectorDict (c => intersect (map (Base. Fix2 (getindex, c), inds)... )
223+ function findtruncated_svd (Sd:: SectorDict , strategy:: TruncationIntersection )
224+ inds = map (Base. Fix1 (findtruncated_svd, Sd), strategy)
225+ return SectorDict (c => mapreduce (Base. Fix2 (getindex, c), _ind_intersect, inds;
226+ init= trues (length (values[c])))
235227 for c in intersect (map (keys, inds)... ))
236228end
0 commit comments