Skip to content

Commit 7c5419b

Browse files
committed
rework truncation to be MatrixAlgebraKit 0.4 compliant
1 parent e5ee802 commit 7c5419b

File tree

5 files changed

+104
-110
lines changed

5 files changed

+104
-110
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Combinatorics = "1"
3333
FiniteDifferences = "0.12"
3434
LRUCache = "1.0.2"
3535
LinearAlgebra = "1"
36-
MatrixAlgebraKit = "0.3.2"
36+
MatrixAlgebraKit = "0.4.0"
3737
OhMyThreads = "0.8.0"
3838
PackageExtensionCompat = "1"
3939
Random = "1"

src/TensorKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ export @tensor, @tensoropt, @ncon, ncon, @planar, @plansor
9191
export scalar, add!, contract!
9292

9393
# truncation schemes
94-
export notrunc, truncerr, truncrank, truncspace, trunctol
94+
export notrunc, truncrank, trunctol, truncfilter, truncspace, truncerror
9595

9696
# cache management
9797
export empty_globalcaches!

src/tensors/factorizations/factorizations.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ export qr_full!, qr_compact!, qr_null!
1212
export lq_full, lq_compact, lq_null
1313
export lq_full!, lq_compact!, lq_null!
1414
export copy_oftype, permutedcopy_oftype, factorisation_scalartype, one!
15-
export TruncationScheme, notrunc, truncbelow, truncerr, truncdim, truncspace, PolarViaSVD
15+
export TruncationScheme, notrunc, trunctol, truncerror, truncrank, truncspace, truncfilter,
16+
PolarViaSVD
1617

1718
using ..TensorKit
1819
using ..TensorKit: AdjointTensorMap, SectorDict, blocktype, foreachblock, one!
@@ -23,12 +24,13 @@ import LinearAlgebra: eigen, eigen!, isposdef, isposdef!, ishermitian
2324
using TensorOperations: Index2Tuple
2425

2526
using MatrixAlgebraKit
26-
using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, TruncationStrategy,
27-
NoTruncation, TruncationKeepAbove, TruncationKeepBelow,
28-
TruncationIntersection, TruncationKeepFiltered, PolarViaSVD,
29-
LAPACK_SVDAlgorithm, LAPACK_QRIteration, LAPACK_HouseholderQR,
30-
LAPACK_HouseholderLQ, LAPACK_HouseholderQL, LAPACK_HouseholderRQ,
31-
DiagonalAlgorithm
27+
using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, DiagonalAlgorithm
28+
using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue,
29+
TruncationByError, TruncationIntersection, TruncationByFilter,
30+
TruncationByOrder
31+
using MatrixAlgebraKit: PolarViaSVD
32+
using MatrixAlgebraKit: LAPACK_SVDAlgorithm, LAPACK_QRIteration, LAPACK_HouseholderQR,
33+
LAPACK_HouseholderLQ, LAPACK_HouseholderQL, LAPACK_HouseholderRQ
3234
import MatrixAlgebraKit: default_algorithm,
3335
copy_input, check_input, initialize_output,
3436
qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!,
@@ -38,7 +40,7 @@ import MatrixAlgebraKit: default_algorithm,
3840
left_polar!, left_orth_polar!, right_polar!, right_orth_polar!,
3941
left_null_svd!, right_null_svd!, left_orth_svd!, right_orth_svd!,
4042
left_orth!, right_orth!, left_null!, right_null!,
41-
truncate!, findtruncated, findtruncated_sorted,
43+
truncate!, findtruncated, findtruncated_svd,
4244
diagview, isisometry
4345

4446
include("utility.jl")
Lines changed: 90 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,40 @@
11
# Strategies
22
# ----------
3-
"""
4-
notrunc()
5-
"""
6-
notrunc() = NoTruncation()
73

8-
# deprecate
4+
# TODO: deprecate
95
const TruncationScheme = TruncationStrategy
106

11-
# TODO: add this to MatrixAlgebraKit
12-
struct TruncationError{T<:Real} <: TruncationStrategy
13-
ϵ::T
14-
p::Real
15-
end
16-
17-
"""
18-
truncerr(epsilon, p)
197
"""
20-
truncerr(epsilon::Real, p::Real=2) = TruncationError(epsilon, p)
8+
TruncationSpace(V::ElementarySpace, by::Function, rev::Bool)
219
22-
struct TruncationSpace{S<:ElementarySpace} <: TruncationStrategy
10+
Truncation strategy to keep the first values for each sector when sorted according to `by` and `rev`,
11+
such that the resulting vector space is no greater than `V`.
12+
13+
See also [`truncspace`](@ref).
14+
"""
15+
struct TruncationSpace{S<:ElementarySpace,F} <: TruncationStrategy
2316
space::S
17+
by::F
18+
rev::Bool
2419
end
2520

2621
"""
27-
truncspace(space::ElementarySpace)
22+
truncspace(space::ElementarySpace; by=abs, rev::Bool=true)
2823
29-
Truncation strategy to keep the first values such that the resulting space is the infimum of
30-
the total space and the provided space.
24+
Truncation strategy to keep the first values for each sector when sorted according to `by` and `rev`,
25+
such that the resulting vector space is no greater than `V`.
3126
"""
32-
truncspace(space::ElementarySpace) = TruncationSpace(space)
27+
function truncspace(space::ElementarySpace; by=abs, rev::Bool=true)
28+
isdual(space) && throw(ArgumentError("resulting vector space is never dual"))
29+
return TruncationSpace(space, by, rev)
30+
end
3331

34-
# Truncation
35-
# ----------
32+
# truncate!
33+
# ---------
3634
function truncate!(::typeof(svd_trunc!),
3735
(U, S, Vᴴ)::Tuple{AbstractTensorMap,AbstractTensorMap,AbstractTensorMap},
3836
strategy::TruncationStrategy)
39-
strategy == notrunc() && return (U, S, Vᴴ)
40-
ind = findtruncated_sorted(diagview(S), strategy)
37+
ind = findtruncated_svd(diagview(S), strategy)
4138
V_truncated = spacetype(S)(c => length(I) for (c, I) in ind)
4239

4340
= similar(U, codomain(U) V_truncated)
@@ -67,7 +64,6 @@ end
6764
function truncate!(::typeof(left_null!),
6865
(U, S)::Tuple{AbstractTensorMap,AbstractTensorMap},
6966
strategy::MatrixAlgebraKit.TruncationStrategy)
70-
strategy == notrunc() && return (U, S)
7167
extended_S = SectorDict(c => vcat(diagview(b),
7268
zeros(eltype(b), max(0, size(b, 2) - size(b, 1))))
7369
for (c, b) in blocks(S))
@@ -84,7 +80,6 @@ for f! in (:eig_trunc!, :eigh_trunc!)
8480
@eval function truncate!(::typeof($f!),
8581
(D, V)::Tuple{AbstractTensorMap,AbstractTensorMap},
8682
strategy::TruncationStrategy)
87-
strategy == notrunc() && return (D, V)
8883
ind = findtruncated(diagview(D), strategy)
8984
V_truncated = spacetype(D)(c => length(I) for (c, I) in ind)
9085

@@ -109,7 +104,7 @@ end
109104
# Find truncation
110105
# ---------------
111106
# auxiliary functions
112-
rtol_to_atol(S, p, atol, rtol) = rtol > 0 ? max(atol, _norm(S, p) * rtol) : atol
107+
rtol_to_atol(S, p, atol, rtol) = rtol > 0 ? max(atol, TensorKit._norm(S, p) * rtol) : atol
113108

114109
function _compute_truncerr(Σdata, truncdim, p=2)
115110
I = keytype(Σdata)
@@ -138,99 +133,96 @@ function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}; by=identity,
138133
end
139134
end
140135

141-
# implementations
142-
function findtruncated_sorted(S::SectorDict, strategy::TruncationStrategy)
143-
return findtruncated(S, strategy)
136+
# findtruncated
137+
# -------------
138+
# Generic fallback
139+
function findtruncated_svd(values::SectorDict, strategy::TruncationStrategy)
140+
return findtruncated(values, strategy)
144141
end
145142

146-
function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepAbove)
147-
atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol)
148-
findtrunc = Base.Fix2(findtruncated_sorted, truncbelow(atol))
149-
return SectorDict(c => findtrunc(d) for (c, d) in S)
150-
end
151-
function findtruncated(S::SectorDict, strategy::TruncationKeepAbove)
152-
atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol)
153-
findtrunc = Base.Fix2(findtruncated, truncbelow(atol))
154-
return SectorDict(c => findtrunc(d) for (c, d) in Sd)
155-
end
156-
157-
function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepBelow)
158-
atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol)
159-
findtrunc = Base.Fix2(findtruncated_sorted, truncabove(atol))
160-
return SectorDict(c => findtrunc(d) for (c, d) in Sd)
161-
end
162-
function findtruncated(S::SectorDict, strategy::TruncationKeepBelow)
163-
atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol)
164-
findtrunc = Base.Fix2(findtruncated, truncabove(atol))
165-
return SectorDict(c => findtrunc(d) for (c, d) in Sd)
166-
end
167-
168-
function findtruncated_sorted(Sd::SectorDict, strategy::TruncationError)
169-
I = keytype(Sd)
170-
truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in Sd)
171-
while true
172-
next = _findnexttruncvalue(Sd, truncdim)
173-
isnothing(next) && break
174-
σmin, cmin = next
175-
truncdim[cmin] -= 1
176-
err = _compute_truncerr(Sd, truncdim, strategy.p)
177-
if err > strategy.ϵ
178-
truncdim[cmin] += 1
179-
break
180-
end
181-
if truncdim[cmin] == 0
182-
delete!(truncdim, cmin)
183-
end
184-
end
185-
return SectorDict{I,Base.OneTo{Int}}(c => Base.OneTo(d) for (c, d) in truncdim)
143+
function findtruncated(values::SectorDict, ::NoTruncation)
144+
return SectorDict(c => Base.OneTo(length(b)) for (c, b) in values)
186145
end
187146

188-
function findtruncated_sorted(Sd::SectorDict, strategy::TruncationKeepSorted)
189-
return findtruncated(Sd, strategy)
147+
function findtruncated(values::SectorDict, strategy::TruncationByOrder)
148+
perms = SectorDict(c => (sortperm(d; strategy.by, strategy.rev)) for (c, d) in values)
149+
values_sorted = SectorDict(c => d[perms[c]] for (c, d) in values)
150+
inds = findtruncated_svd(values_sorted, truncrank(strategy.howmany))
151+
return SectorDict(c => perms[c][I] for (c, I) in inds)
190152
end
191-
function findtruncated(Sd::SectorDict, strategy::TruncationKeepSorted)
192-
permutations = SectorDict(c => (sortperm(d; strategy.by, strategy.rev))
193-
for (c, d) in Sd)
194-
Sd = SectorDict(c => sort(d; strategy.by, strategy.rev) for (c, d) in Sd)
195-
196-
I = keytype(Sd)
197-
truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in Sd)
153+
function findtruncated_svd(values::SectorDict, strategy::TruncationByOrder)
154+
I = keytype(values)
155+
truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in values)
198156
totaldim = sum(dim(c) * d for (c, d) in truncdim; init=0)
199157
while true
200-
next = _findnexttruncvalue(Sd, truncdim; strategy.by, strategy.rev)
158+
next = _findnexttruncvalue(values, truncdim; strategy.by, strategy.rev)
201159
isnothing(next) && break
202160
_, cmin = next
203161
truncdim[cmin] -= 1
204162
totaldim -= dim(cmin)
205-
if truncdim[cmin] == 0
206-
delete!(truncdim, cmin)
207-
end
163+
truncdim[cmin] == 0 && delete!(truncdim, cmin)
208164
totaldim <= strategy.howmany && break
209165
end
210-
return SectorDict(c => permutations[c][Base.OneTo(d)] for (c, d) in truncdim)
166+
return SectorDict(c => Base.OneTo(d) for (c, d) in truncdim)
211167
end
212168

213-
function findtruncated_sorted(Sd::SectorDict, strategy::TruncationSpace)
214-
I = keytype(Sd)
215-
return SectorDict{I,Base.OneTo{Int}}(c => Base.OneTo(min(length(d),
216-
dim(strategy.space, c)))
217-
for (c, d) in Sd)
169+
function findtruncated(values::SectorDict, strategy::TruncationByFilter)
170+
return SectorDict(c => findall(strategy.filter, d) for (c, d) in values)
171+
end
172+
173+
function findtruncated(values::SectorDict, strategy::TruncationByValue)
174+
atol = rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol)
175+
strategy′ = trunctol(; atol, strategy.by, strategy.keep_below)
176+
return SectorDict(c => findtruncated(d, strategy′) for (c, d) in values)
177+
end
178+
function findtruncated_svd(values::SectorDict, strategy::TruncationByValue)
179+
atol = rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol)
180+
strategy′ = trunctol(; atol, strategy.by, strategy.keep_below)
181+
return SectorDict(c => findtruncated_svd(d, strategy′) for (c, d) in values)
182+
end
183+
184+
function findtruncated(values::SectorDict, strategy::TruncationByError)
185+
perms = SectorDict(c => sortperm(d; by=abs, rev=true) for (c, d) in values)
186+
values_sorted = SectorDict(c => d[perms[c]] for (c, d) in Sd)
187+
inds = findtruncated_svd(values_sorted, truncrank(strategy.howmany))
188+
return SectorDict(c => perms[c][I] for (c, I) in inds)
189+
end
190+
function findtruncated_svd(values::SectorDict, strategy::TruncationByError)
191+
I = keytype(values)
192+
truncdim = SectorDict{I,Int}(c => length(d) for (c, d) in values)
193+
by(c, v) = abs(v)^strategy.p * dim(c)
194+
Nᵖ = sum(((c, v),) -> sum(Base.Fix1(by, c), v), values)
195+
ϵᵖ = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * Nᵖ)
196+
truncerrᵖ = zero(real(scalartype(valtype(values))))
197+
next = _findnexttruncvalue(values, truncdim)
198+
while !isnothing(next)
199+
σmin, cmin = next
200+
truncerrᵖ += by(cmin, σmin)
201+
truncerrᵖ >= ϵᵖ && break
202+
(truncdim[cmin] -= 1) == 0 && delete!(truncdim, cmin)
203+
next = _findnexttruncvalue(values, truncdim)
204+
end
205+
return SectorDict{I,Base.OneTo{Int}}(c => Base.OneTo(d) for (c, d) in truncdim)
218206
end
219207

220-
function findtruncated_sorted(Sd::SectorDict, strategy::TruncationKeepFiltered)
221-
return SectorDict(c => findtruncated_sorted(d, strategy) for (c, d) in Sd)
208+
function findtruncated(values::SectorDict, strategy::TruncationSpace)
209+
blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev)
210+
return SectorDict(c => findtruncated(d, blockstrategy(c)) for (c, d) in values)
222211
end
223-
function findtruncated(Sd::SectorDict, strategy::TruncationKeepFiltered)
224-
return SectorDict(c => findtruncated(d, strategy) for (c, d) in Sd)
212+
function findtruncated_svd(values::SectorDict, strategy::TruncationSpace)
213+
blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev)
214+
return SectorDict(c => findtruncated_svd(d, blockstrategy(c)) for (c, d) in values)
225215
end
226216

227-
function findtruncated_sorted(Sd::SectorDict, strategy::TruncationIntersection)
228-
inds = map(Base.Fix1(findtruncated_sorted, Sd), strategy)
229-
return SectorDict(c => intersect(map(Base.Fix2(getindex, c), inds)...)
217+
function findtruncated(values::SectorDict, strategy::TruncationIntersection)
218+
inds = map(Base.Fix1(findtruncated, values), strategy)
219+
return SectorDict(c => mapreduce(Base.Fix2(getindex, c), _ind_intersect, inds;
220+
init=trues(length(values[c])))
230221
for c in intersect(map(keys, inds)...))
231222
end
232-
function findtruncated(Sd::SectorDict, strategy::TruncationIntersection)
233-
inds = map(Base.Fix1(findtruncated, Sd), strategy)
234-
return SectorDict(c => intersect(map(Base.Fix2(getindex, c), inds)...)
223+
function findtruncated_svd(Sd::SectorDict, strategy::TruncationIntersection)
224+
inds = map(Base.Fix1(findtruncated_svd, Sd), strategy)
225+
return SectorDict(c => mapreduce(Base.Fix2(getindex, c), _ind_intersect, inds;
226+
init=trues(length(values[c])))
235227
for c in intersect(map(keys, inds)...))
236228
end

test/factorizations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ for V in spacelist
226226
@test dim(domain(S1)) <= trunc.howmany
227227

228228
λ = minimum(minimum, values(LinearAlgebra.diag(S1)))
229-
trunc = trunctol- 10eps(λ))
229+
trunc = trunctol(; atol=λ - 10eps(λ))
230230
U2, S2, Vᴴ2 = @constinferred svd_trunc(t; trunc)
231231
@test t * Vᴴ2' U2 * S2
232232
@test isisometry(U2)
@@ -243,7 +243,7 @@ for V in spacelist
243243
@test isisometry(Vᴴ3; side=:right)
244244
@test space(S3, 1) space(S2, 1)
245245

246-
trunc = truncerr(0.5)
246+
trunc = truncerror(; atol=0.5)
247247
U4, S4, Vᴴ4 = @constinferred svd_trunc(t; trunc)
248248
@test t * Vᴴ4' U4 * S4
249249
@test isisometry(U4)

0 commit comments

Comments
 (0)