Skip to content

Commit 57ea439

Browse files
committed
Implement truncated eigenvalues
1 parent 648a34a commit 57ea439

File tree

4 files changed

+131
-66
lines changed

4 files changed

+131
-66
lines changed

ext/TensorKitChainRulesCoreExt/factorizations.jl

Lines changed: 30 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,38 @@
11
# Factorizations rules
22
# --------------------
3-
function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
4-
trunc::TruncationStrategy=TensorKit.notrunc(),
5-
kwargs...)
6-
# TODO: I think we can use tsvd! here without issues because we don't actually require
7-
# the data of `t` anymore.
8-
USVᴴ = tsvd(t; trunc=TensorKit.notrunc(), kwargs...)
9-
10-
if trunc != TensorKit.notrunc() && !isempty(blocksectors(t))
11-
USVᴴ′ = MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ, trunc)
12-
else
13-
USVᴴ′ = USVᴴ
14-
end
3+
for f in (:tsvd, :eig, :eigh)
4+
f! = Symbol(f, :!)
5+
f_trunc! = f == :tsvd ? :svd_trunc! : Symbol(f, :_trunc!)
6+
f_pullback = Symbol(f, :_pullback)
7+
f_pullback! = f == :tsvd ? :svd_compact_pullback! : Symbol(f, :_full_pullback!)
8+
@eval function ChainRulesCore.rrule(::typeof(TensorKit.$f!), t::AbstractTensorMap;
9+
trunc::TruncationStrategy=TensorKit.notrunc(),
10+
kwargs...)
11+
# TODO: I think we can use f! here without issues because we don't actually require
12+
# the data of `t` anymore.
13+
F = $f(t; trunc=TensorKit.notrunc(), kwargs...)
14+
15+
if trunc != TensorKit.notrunc() && !isempty(blocksectors(t))
16+
F′ = MatrixAlgebraKit.truncate!($f_trunc!, F, trunc)
17+
else
18+
F′ = F
19+
end
1520

16-
function tsvd!_pullback(ΔUSVᴴ′)
17-
ΔUSVᴴ = unthunk.(ΔUSVᴴ′)
18-
Δt = zerovector(t)
19-
foreachblock(Δt) do c, (b,)
20-
USVᴴc = block.(USVᴴ, Ref(c))
21-
ΔUSVᴴc = block.(ΔUSVᴴ, Ref(c))
22-
svd_compact_pullback!(b, USVᴴc, ΔUSVᴴc)
23-
return nothing
21+
function $f_pullback(ΔF′)
22+
ΔF = unthunk.(ΔF′)
23+
Δt = zerovector(t)
24+
foreachblock(Δt) do c, (b,)
25+
Fc = block.(F, Ref(c))
26+
ΔFc = block.(ΔF, Ref(c))
27+
$f_pullback!(b, Fc, ΔFc)
28+
return nothing
29+
end
30+
return NoTangent(), Δt
2431
end
25-
return NoTangent(), Δt
26-
end
27-
tsvd!_pullback(::NTuple{3,ZeroTangent}) = NoTangent(), ZeroTangent()
32+
$f_pullback(::Tuple{ZeroTangent,Vararg{ZeroTangent}}) = NoTangent(), ZeroTangent()
2833

29-
return USVᴴ′, tsvd!_pullback
34+
return F′, $f_pullback
35+
end
3036
end
3137

3238
function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTensorMap)
@@ -43,44 +49,6 @@ function ChainRulesCore.rrule(::typeof(LinearAlgebra.svdvals!), t::AbstractTenso
4349
return s, svdvals_pullback
4450
end
4551

46-
function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kwargs...)
47-
DV = eig(t; kwargs...)
48-
49-
function eig!_pullback(ΔDV′)
50-
ΔDV = unthunk.(ΔDV′)
51-
Δt = zerovector(t)
52-
foreachblock(Δt) do c, (b,)
53-
DVc = block.(DV, Ref(c))
54-
ΔDVc = block.(ΔDV, Ref(c))
55-
eig_full_pullback!(b, DVc, ΔDVc)
56-
return nothing
57-
end
58-
return NoTangent(), Δt
59-
end
60-
eig!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent()
61-
62-
return DV, eig!_pullback
63-
end
64-
65-
function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; kwargs...)
66-
DV = eigh(t; kwargs...)
67-
68-
function eigh!_pullback(ΔDV′)
69-
ΔDV = unthunk.(ΔDV′)
70-
Δt = zerovector(t)
71-
foreachblock(Δt) do c, (b,)
72-
DVc = block.(DV, Ref(c))
73-
ΔDVc = block.(ΔDV, Ref(c))
74-
eigh_full_pullback!(b, DVc, ΔDVc)
75-
return nothing
76-
end
77-
return NoTangent(), Δt
78-
end
79-
eigh!_pullback(::NTuple{2,ZeroTangent}) = NoTangent(), ZeroTangent()
80-
81-
return DV, eigh!_pullback
82-
end
83-
8452
function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTensorMap;
8553
sortby=nothing, kwargs...)
8654
@assert sortby === nothing "only `sortby=nothing` is supported"

src/tensors/factorizations/implementations.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,27 @@ rightpolar!(t::AbstractTensorMap; kwargs...) = right_polar!(t; kwargs...)
149149

150150
# Eigenvalue decomposition
151151
# ------------------------
152-
eigh!(t::AbstractTensorMap) = eigh_full!(t)
153-
eig!(t::AbstractTensorMap) = eig_full!(t)
154-
eigen!(t::AbstractTensorMap) = ishermitian(t) ? eigh!(t) : eig!(t)
152+
function eigh!(t::AbstractTensorMap; trunc=notrunc(), kwargs...)
153+
InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!)
154+
if trunc == notrunc()
155+
return eigh_full!(t; kwargs...)
156+
else
157+
return eigh_trunc!(t; trunc, kwargs...)
158+
end
159+
end
160+
161+
function eig!(t::AbstractTensorMap; trunc=notrunc(), kwargs...)
162+
InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eig!)
163+
if trunc == notrunc()
164+
return eig_full!(t; kwargs...)
165+
else
166+
return eig_trunc!(t; trunc, kwargs...)
167+
end
168+
end
169+
170+
function eigen!(t::AbstractTensorMap; kwargs...)
171+
return ishermitian(t) ? eigh!(t; kwargs...) : eig!(t; kwargs...)
172+
end
155173

156174
# Singular value decomposition
157175
# ----------------------------

src/tensors/factorizations/matrixalgebrakit.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,26 @@ function initialize_output(::typeof(eig_full!), t::AbstractTensorMap, ::Abstract
193193
return D, V
194194
end
195195

196+
function initialize_output(::typeof(eigh_trunc!), t::AbstractTensorMap,
197+
alg::TruncatedAlgorithm)
198+
return initialize_output(eigh_full!, t, alg.alg)
199+
end
200+
201+
function initialize_output(::typeof(eig_trunc!), t::AbstractTensorMap,
202+
alg::TruncatedAlgorithm)
203+
return initialize_output(eig_full!, t, alg.alg)
204+
end
205+
206+
function eigh_trunc!(t::AbstractTensorMap, DV, alg::TruncatedAlgorithm)
207+
DV′ = eigh_full!(t, DV, alg.alg)
208+
return truncate!(eigh_trunc!, DV′, alg.trunc)
209+
end
210+
211+
function eig_trunc!(t::AbstractTensorMap, DV, alg::TruncatedAlgorithm)
212+
DV′ = eig_full!(t, DV, alg.alg)
213+
return truncate!(eig_trunc!, DV′, alg.trunc)
214+
end
215+
196216
# QR decomposition
197217
# ----------------
198218
const _T_QR = Tuple{<:AbstractTensorMap,<:AbstractTensorMap}

src/tensors/factorizations/truncation.jl

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,47 @@ function truncate!(::typeof(left_null!),
6565
return
6666
end
6767

68+
function truncate!(::typeof(eigh_trunc!), (D, V)::_T_DV, strategy::TruncationStrategy)
69+
ind = findtruncated(diagview(D), strategy)
70+
V_truncated = spacetype(D)(c => length(I) for (c, I) in ind)
71+
72+
= DiagonalTensorMap{scalartype(D)}(undef, V_truncated)
73+
for (c, b) in blocks(D̃)
74+
I = get(ind, c, nothing)
75+
@assert !isnothing(I)
76+
copy!(b.diag, @view(block(D, c).diag[I]))
77+
end
78+
79+
= similar(V, V_truncated domain(V))
80+
for (c, b) in blocks(Ṽ)
81+
I = get(ind, c, nothing)
82+
@assert !isnothing(I)
83+
copy!(b, @view(block(V, c)[I, :]))
84+
end
85+
86+
return D̃, Ṽ
87+
end
88+
function truncate!(::typeof(eig_trunc!), (D, V)::_T_DV, strategy::TruncationStrategy)
89+
ind = findtruncated(diagview(D), strategy)
90+
V_truncated = spacetype(D)(c => length(I) for (c, I) in ind)
91+
92+
= DiagonalTensorMap{scalartype(D)}(undef, V_truncated)
93+
for (c, b) in blocks(D̃)
94+
I = get(ind, c, nothing)
95+
@assert !isnothing(I)
96+
copy!(b.diag, @view(block(D, c).diag[I]))
97+
end
98+
99+
= similar(V, V_truncated domain(V))
100+
for (c, b) in blocks(Ṽ)
101+
I = get(ind, c, nothing)
102+
@assert !isnothing(I)
103+
copy!(b, @view(block(V, c)[I, :]))
104+
end
105+
106+
return D̃, Ṽ
107+
end
108+
68109
# Find truncation
69110
# ---------------
70111
# auxiliary functions
@@ -88,18 +129,28 @@ function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}) where {I<:Sector}
88129
return σmin, keys(truncdim)[imin]
89130
end
90131

91-
# sorted implementations
132+
# implementations
92133
function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepAbove)
93134
atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol)
94135
findtrunc = Base.Fix2(findtruncated_sorted, truncbelow(atol))
95136
return SectorDict(c => findtrunc(d) for (c, d) in Sd)
96137
end
138+
function findtruncated(S::SectorDict, strategy::TruncationKeepAbove)
139+
atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol)
140+
findtrunc = Base.Fix2(findtruncated, truncbelow(atol))
141+
return SectorDict(c => findtrunc(d) for (c, d) in Sd)
142+
end
97143

98144
function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepBelow)
99145
atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol)
100146
findtrunc = Base.Fix2(findtruncated_sorted, truncabove(atol))
101147
return SectorDict(c => findtrunc(d) for (c, d) in Sd)
102148
end
149+
function findtruncated(S::SectorDict, strategy::TruncationKeepBelow)
150+
atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol)
151+
findtrunc = Base.Fix2(findtruncated, truncabove(atol))
152+
return SectorDict(c => findtrunc(d) for (c, d) in Sd)
153+
end
103154

104155
function findtruncated_sorted(Sd::SectorDict, strategy::TruncationError)
105156
I = keytype(Sd)
@@ -153,9 +204,17 @@ end
153204
function findtruncated_sorted(Sd::SectorDict, strategy::TruncationKeepFiltered)
154205
return SectorDict(c => findtruncated_sorted(d, strategy) for (c, d) in Sd)
155206
end
207+
function findtruncated(Sd::SectorDict, strategy::TruncationKeepFiltered)
208+
return SectorDict(c => findtruncated(d, strategy) for (c, d) in Sd)
209+
end
156210

157211
function findtruncated_sorted(Sd::SectorDict, strategy::TruncationIntersection)
158212
inds = map(Base.Fix1(findtruncated_sorted, Sd), strategy)
159213
return SectorDict(c => intersect(map(Base.Fix2(getindex, c), inds)...)
160214
for c in intersect(map(keys, inds)...))
161215
end
216+
function findtruncated(Sd::SectorDict, strategy::TruncationIntersection)
217+
inds = map(Base.Fix1(findtruncated, Sd), strategy)
218+
return SectorDict(c => intersect(map(Base.Fix2(getindex, c), inds)...)
219+
for c in intersect(map(keys, inds)...))
220+
end

0 commit comments

Comments
 (0)