7878function MAK. truncate (
7979 :: typeof (left_null!), (U, S):: NTuple{2, AbstractTensorMap} , strategy:: TruncationStrategy
8080 )
81- extended_S = SectorDict (
82- c => vcat ( diagview (b), zeros ( eltype (b), max ( 0 , size (b, 1 ) - size (b, 2 ))) )
83- for (c, b) in blocks (S)
84- )
81+ extended_S = zerovector! ( SectorVector {eltype(S)} (undef, fuse ( codomain (U))))
82+ for (c, b) in blocks (S )
83+ copyto! (extended_S[c], diagview (b)) # copyto! since `b` might be shorter
84+ end
8585 ind = MAK. findtruncated (extended_S, strategy)
8686 V_truncated = truncate_space (space (S, 1 ), ind)
8787 Ũ = similar (U, codomain (U) ← V_truncated)
9191function MAK. truncate (
9292 :: typeof (right_null!), (S, Vᴴ):: NTuple{2, AbstractTensorMap} , strategy:: TruncationStrategy
9393 )
94- extended_S = SectorDict (
95- c => vcat ( diagview (b), zeros ( eltype (b), max ( 0 , size (b, 2 ) - size (b, 1 ))) )
96- for (c, b) in blocks (S)
97- )
94+ extended_S = zerovector! ( SectorVector {eltype(S)} (undef, fuse ( domain (Vᴴ))))
95+ for (c, b) in blocks (S )
96+ copyto! (extended_S[c], diagview (b)) # copyto! since `b` might be shorter
97+ end
9898 ind = MAK. findtruncated (extended_S, strategy)
9999 V_truncated = truncate_space (dual (space (S, 2 )), ind)
100100 Ṽᴴ = similar (Vᴴ, V_truncated ← domain (Vᴴ))
@@ -177,26 +177,40 @@ function _findnexttruncvalue(
177177 end
178178end
179179
180+ function _sort_and_perm (values:: SectorVector ; by = identity, rev:: Bool = false )
181+ values_sorted = similar (values)
182+ perms = SectorDict (
183+ (
184+ begin
185+ p = sortperm (v; by, rev)
186+ vs = values_sorted[c]
187+ vs .= view (v, p)
188+ c => p
189+ end
190+ ) for (c, v) in pairs (values)
191+ )
192+ return values_sorted, perms
193+ end
194+
180195# findtruncated
181196# -------------
182197# Generic fallback
183- function MAK. findtruncated_svd (values:: SectorDict , strategy:: TruncationStrategy )
198+ function MAK. findtruncated_svd (values:: SectorVector , strategy:: TruncationStrategy )
184199 return MAK. findtruncated (values, strategy)
185200end
186201
187- function MAK. findtruncated (values:: SectorDict , :: NoTruncation )
188- return SectorDict (c => Colon () for (c, b) in values)
202+ function MAK. findtruncated (values:: SectorVector , :: NoTruncation )
203+ return SectorDict (c => Colon () for c in keys ( values) )
189204end
190205
191- function MAK. findtruncated (values:: SectorDict , strategy:: TruncationByOrder )
192- perms = SectorDict (c => (sortperm (d; strategy. by, strategy. rev)) for (c, d) in values)
193- values_sorted = SectorDict (c => d[perms[c]] for (c, d) in values)
206+ function MAK. findtruncated (values:: SectorVector , strategy:: TruncationByOrder )
207+ values_sorted, perms = _sort_and_perm (values; strategy. by, strategy. rev)
194208 inds = MAK. findtruncated_svd (values_sorted, truncrank (strategy. howmany))
195209 return SectorDict (c => perms[c][I] for (c, I) in inds)
196210end
197- function MAK. findtruncated_svd (values:: SectorDict , strategy:: TruncationByOrder )
211+ function MAK. findtruncated_svd (values:: SectorVector , strategy:: TruncationByOrder )
198212 I = keytype (values)
199- truncdim = SectorDict {I, Int} (c => length (d) for (c, d) in values)
213+ truncdim = SectorDict {I, Int} (c => length (d) for (c, d) in pairs ( values) )
200214 totaldim = sum (dim (c) * d for (c, d) in truncdim; init = 0 )
201215 while totaldim > strategy. howmany
202216 next = _findnexttruncvalue (values, truncdim; strategy. by, strategy. rev)
@@ -209,32 +223,31 @@ function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByOrder)
209223 return SectorDict (c => Base. OneTo (d) for (c, d) in truncdim)
210224end
211225
212- function MAK. findtruncated (values:: SectorDict , strategy:: TruncationByFilter )
213- return SectorDict (c => findall (strategy. filter, d) for (c, d) in values)
226+ function MAK. findtruncated (values:: SectorVector , strategy:: TruncationByFilter )
227+ return SectorDict (c => findall (strategy. filter, d) for (c, d) in pairs ( values) )
214228end
215229
216- function MAK. findtruncated (values:: SectorDict , strategy:: TruncationByValue )
230+ function MAK. findtruncated (values:: SectorVector , strategy:: TruncationByValue )
217231 atol = rtol_to_atol (values, strategy. p, strategy. atol, strategy. rtol)
218232 strategy′ = trunctol (; atol, strategy. by, strategy. keep_below)
219- return SectorDict (c => MAK. findtruncated (d, strategy′) for (c, d) in values)
233+ return SectorDict (c => MAK. findtruncated (d, strategy′) for (c, d) in pairs ( values) )
220234end
221- function MAK. findtruncated_svd (values:: SectorDict , strategy:: TruncationByValue )
235+ function MAK. findtruncated_svd (values:: SectorVector , strategy:: TruncationByValue )
222236 atol = rtol_to_atol (values, strategy. p, strategy. atol, strategy. rtol)
223237 strategy′ = trunctol (; atol, strategy. by, strategy. keep_below)
224- return SectorDict (c => MAK. findtruncated_svd (d, strategy′) for (c, d) in values)
238+ return SectorDict (c => MAK. findtruncated_svd (d, strategy′) for (c, d) in pairs ( values) )
225239end
226240
227- function MAK. findtruncated (values:: SectorDict , strategy:: TruncationByError )
228- perms = SectorDict (c => sortperm (d; by = abs, rev = true ) for (c, d) in values)
229- values_sorted = SectorDict (c => d[perms[c]] for (c, d) in Sd)
241+ function MAK. findtruncated (values:: SectorVector , strategy:: TruncationByError )
242+ values_sorted, perms = _sort_and_perm (values; strategy. by, strategy. rev)
230243 inds = MAK. findtruncated_svd (values_sorted, truncrank (strategy. howmany))
231244 return SectorDict (c => perms[c][I] for (c, I) in inds)
232245end
233- function MAK. findtruncated_svd (values:: SectorDict , strategy:: TruncationByError )
246+ function MAK. findtruncated_svd (values:: SectorVector , strategy:: TruncationByError )
234247 I = keytype (values)
235- truncdim = SectorDict {I, Int} (c => length (d) for (c, d) in values)
248+ truncdim = SectorDict {I, Int} (c => length (d) for (c, d) in pairs ( values) )
236249 by (c, v) = abs (v)^ strategy. p * dim (c)
237- Nᵖ = sum (((c, v),) -> sum (Base. Fix1 (by, c), v), values)
250+ Nᵖ = sum (((c, v),) -> sum (Base. Fix1 (by, c), v), pairs ( values) )
238251 ϵᵖ = max (strategy. atol^ strategy. p, strategy. rtol^ strategy. p * Nᵖ)
239252 truncerrᵖ = zero (real (scalartype (valtype (values))))
240253 next = _findnexttruncvalue (values, truncdim)
@@ -248,16 +261,16 @@ function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByError)
248261 return SectorDict {I, Base.OneTo{Int}} (c => Base. OneTo (d) for (c, d) in truncdim)
249262end
250263
251- function MAK. findtruncated (values:: SectorDict , strategy:: TruncationSpace )
264+ function MAK. findtruncated (values:: SectorVector , strategy:: TruncationSpace )
252265 blockstrategy (c) = truncrank (dim (strategy. space, c); strategy. by, strategy. rev)
253266 return SectorDict (c => MAK. findtruncated (d, blockstrategy (c)) for (c, d) in values)
254267end
255- function MAK. findtruncated_svd (values:: SectorDict , strategy:: TruncationSpace )
268+ function MAK. findtruncated_svd (values:: SectorVector , strategy:: TruncationSpace )
256269 blockstrategy (c) = truncrank (dim (strategy. space, c); strategy. by, strategy. rev)
257- return SectorDict (c => MAK. findtruncated_svd (d, blockstrategy (c)) for (c, d) in values)
270+ return SectorDict (c => MAK. findtruncated_svd (d, blockstrategy (c)) for (c, d) in pairs ( values) )
258271end
259272
260- function MAK. findtruncated (values:: SectorDict , strategy:: TruncationIntersection )
273+ function MAK. findtruncated (values:: SectorVector , strategy:: TruncationIntersection )
261274 inds = map (Base. Fix1 (MAK. findtruncated, values), strategy. components)
262275 return SectorDict (
263276 c => mapreduce (
@@ -266,7 +279,7 @@ function MAK.findtruncated(values::SectorDict, strategy::TruncationIntersection)
266279 ) for c in intersect (map (keys, inds)... )
267280 )
268281end
269- function MAK. findtruncated_svd (values:: SectorDict , strategy:: TruncationIntersection )
282+ function MAK. findtruncated_svd (values:: SectorVector , strategy:: TruncationIntersection )
270283 inds = map (Base. Fix1 (MAK. findtruncated_svd, values), strategy. components)
271284 return SectorDict (
272285 c => mapreduce (
@@ -278,13 +291,12 @@ end
278291
279292# Truncation error
280293# ----------------
281- MAK. truncation_error (values:: SectorDict , ind) =
282- MAK. truncation_error! (SectorDict (c => copy (v) for (c, v) in values), ind)
294+ MAK. truncation_error (values:: SectorVector , ind) = MAK. truncation_error! (copy (values), ind)
283295
284- function MAK. truncation_error! (values:: SectorDict , ind)
296+ function MAK. truncation_error! (values:: SectorVector , ind)
285297 for (c, ind_c) in ind
286298 v = values[c]
287299 v[ind_c] .= zero (eltype (v))
288300 end
289- return TensorKit . _norm (values, 2 , zero ( real ( eltype ( valtype (values)))) )
301+ return norm (values)
290302end
0 commit comments