Skip to content

Commit b56a5fb

Browse files
pbrehmerleburgel
andauthored
Use MatrixAlgebraKit's SVD pullbacks (#335)
* Add `FullSVDPullback` and `TruncSVDPullback` wrapping MAK's pullbacks * Add broadening to `eigh` pullbacks * Update defaults and adjust rest of code to name changes * Add links in docstrings * Fix eigh tests and add eigh broadening tests * Make `eigh` default consistent with MAK * Fix SVD derivative unthunks * Fix gauge fixing index ordering * Format * Fix unit cell indices in SVDAdjoint gauge fixing * Fix IterSVD `truncation_indices` return * Increase test coverage and improve broadening tests * Add regression test for CTMRG gradient accuracy * Add `:trunc` description and revert `:sdd` and `:svd` name changes * Apply suggestions --------- Co-authored-by: leburgel <lander.burgelman@gmail.com>
1 parent f90160b commit b56a5fb

File tree

10 files changed

+305
-320
lines changed

10 files changed

+305
-320
lines changed

src/Defaults.jl

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ Module containing default algorithm parameter values and arguments.
99
* `ctmrg_maxiter=$(Defaults.ctmrg_maxiter)` : Maximal number of CTMRG iterations per run.
1010
* `ctmrg_miniter=$(Defaults.ctmrg_miniter)` : Minimal number of CTMRG carried out.
1111
* `ctmrg_alg=:$(Defaults.ctmrg_alg)` : Default CTMRG algorithm variant.
12-
- `:simultaneous`: Simultaneous expansion and renormalization of all sides.
13-
- `:sequential`: Sequential application of left moves and rotations.
12+
- `:simultaneous` : Simultaneous expansion and renormalization of all sides.
13+
- `:sequential` : Sequential application of left moves and rotations.
1414
* `ctmrg_verbosity=$(Defaults.ctmrg_verbosity)` : CTMRG output information verbosity
1515
1616
## SVD forward & reverse
@@ -22,25 +22,40 @@ Module containing default algorithm parameter values and arguments.
2222
- `:truncrank` : Additionally supply truncation dimension `η`; truncate such that the 2-norm of the truncated values is smaller than `η`
2323
- `:truncspace` : Additionally supply truncation space `η`; truncate according to the supplied vector space
2424
- `:trunctol` : Additionally supply singular value cutoff `η`; truncate such that every retained singular value is larger than `η`
25+
* `rrule_degeneracy_atol=$(Defaults.rrule_degeneracy_atol)` : Broadening amplitude which smoothens the divergent term in the retained contributions of an SVD or eigh pullback, in case of (pseudo) degenerate singular values
2526
* `svd_fwd_alg=:$(Defaults.svd_fwd_alg)` : SVD algorithm that is used in the forward pass.
26-
- `:sdd`: MatrixAlgebraKit's `LAPACK_DivideAndConquer`
27-
- `:svd`: MatrixAlgebraKit's `LAPACK_QRIteration`
28-
- `:iterative`: Iterative SVD only computing the specifed number of singular values and vectors, see [`IterSVD`](@ref PEPSKit.IterSVD)
27+
- `:sdd` : MatrixAlgebraKit's `LAPACK_DivideAndConquer`
28+
- `:svd` : MatrixAlgebraKit's `LAPACK_QRIteration`
29+
- `:iterative` : Iterative SVD only computing the specifed number of singular values and vectors, see [`IterSVD`](@ref PEPSKit.IterSVD)
2930
* `svd_rrule_tol=$(Defaults.svd_rrule_tol)` : Accuracy of SVD reverse-rule.
3031
* `svd_rrule_min_krylovdim=$(Defaults.svd_rrule_min_krylovdim)` : Minimal Krylov dimension of the reverse-rule algorithm (if it is a Krylov algorithm).
3132
* `svd_rrule_verbosity=$(Defaults.svd_rrule_verbosity)` : SVD gradient output verbosity.
3233
* `svd_rrule_alg=:$(Defaults.svd_rrule_alg)` : Reverse-rule algorithm for the SVD gradient.
33-
- `:full`: Uses a modified version of MatrixAlgebraKit's reverse-rule for `svd_compact` which doesn't solve any linear problem and instead requires access to the full SVD, see [`PEPSKit.FullSVDReverseRule`](@ref).
34-
- `:gmres`: GMRES iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.GMRES) for details
35-
- `:bicgstab`: BiCGStab iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.BiCGStab) for details
36-
- `:arnoldi`: Arnoldi Krylov algorithm, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.Arnoldi) for details
37-
* `svd_rrule_broadening=$(Defaults.svd_rrule_broadening)` : Lorentzian broadening amplitude which smoothens the divergent term in the SVD adjoint in case of (pseudo) degenerate singular values
34+
- `:full` : Uses a modified version of MatrixAlgebraKit's reverse-rule for `svd_compact` which doesn't solve any linear problem and instead requires access to the full SVD, see [`PEPSKit.FullSVDPullback`](@ref).
35+
- `:trunc` : MatrixAlgebraKit's `svd_trunc_pullback!` solving a Sylvester equation on the truncated subspace and therefore only requires access to the truncated SVD.
36+
- `:gmres` : GMRES iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.GMRES) for details
37+
- `:bicgstab` : BiCGStab iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.BiCGStab) for details
38+
- `:arnoldi` : Arnoldi Krylov algorithm, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.Arnoldi) for details
39+
40+
## `eigh` forward & reverse
41+
42+
* `eigh_fwd_alg=:$(Defaults.eigh_fwd_alg)` : `eigh` algorithm that is used in the forward pass.
43+
- `:qriteration` : MatrixAlgebraKit's `LAPACK_QRIteration`.
44+
- `:bisection` : MatrixAlgebraKit's `LAPACK_Bisection`.
45+
- `:divideandconquer` : MatrixAlgebraKit's `LAPACK_DivideAndConquer`.
46+
- `:multiple` : MatrixAlgebraKit's `LAPACK_MultipleRelativelyRobustRepresentations`.
47+
- `:lanczos` : Lanczos algorithm, see [`KrylovKit.Lanczos`](@extref) for details.
48+
- `:blocklanczos` : Block Lanczos algorithm, see [`KrylovKit.BlockLanczos`](@extref) for details.
49+
* `eigh_rrule_alg=:$(Defaults.eigh_rrule_alg)` : Reverse-rule algorithm for the `eigh` gradient.
50+
- `:full` : Full pullback algorithm for eigendecompositions, see [`PEPSKit.FullEighPullback`](@ref).
51+
- `:trunc` : Truncated reverse-mode algorithm for eigendecompositions, see [`PEPSKit.TruncEighPullback`](@ref).
52+
* `eigh_rrule_verbosity=$(Defaults.eigh_rrule_verbosity)` : eigh gradient output verbosity.
3853
3954
## Projectors
4055
4156
* `projector_alg=:$(Defaults.projector_alg)` : Default variant of the CTMRG projector algorithm.
42-
- `:halfinfinite`: Projection via SVDs of half-infinite (two enlarged corners) CTMRG environments.
43-
- `:fullinfinite`: Projection via SVDs of full-infinite (all four enlarged corners) CTMRG environments.
57+
- `:halfinfinite` : Projection via SVDs of half-infinite (two enlarged corners) CTMRG environments.
58+
- `:fullinfinite` : Projection via SVDs of full-infinite (all four enlarged corners) CTMRG environments.
4459
* `projector_verbosity=$(Defaults.projector_verbosity)` : Projector output information verbosity.
4560
4661
## Fixed-point gradient
@@ -93,17 +108,17 @@ const sparse = false # TODO: implement sparse CTMRG
93108

94109
# SVD forward & reverse
95110
const trunc = :fixedspace # ∈ {:fixedspace, :notrunc, :truncerror, :truncspace, :trunctol}
96-
const svd_fwd_alg = :sdd # ∈ {:sdd, :svd, :iterative}
111+
const rrule_degeneracy_atol = 1.0e-13
112+
const svd_fwd_alg = :sdd # ∈ {:sdd, :svd, :bisection, :jacobi, :iterative}
97113
const svd_rrule_tol = ctmrg_tol
98114
const svd_rrule_min_krylovdim = 48
99115
const svd_rrule_verbosity = -1
100-
const svd_rrule_alg = :full # ∈ {:full, :gmres, :bicgstab, :arnoldi}
101-
const svd_rrule_broadening = 1.0e-13
116+
const svd_rrule_alg = :full # ∈ {:full, :trunc, :gmres, :bicgstab, :arnoldi}
102117
const krylovdim_factor = 1.4
103118

104119
# eigh forward & reverse
105120
const eigh_fwd_alg = :qriteration # ∈ {:qriteration, :bisection, :divideandconquer, :multiple, :lanczos, :blocklanczos}
106-
const eigh_rrule_alg = :trunc # ∈ {:trunc, :full}
121+
const eigh_rrule_alg = :full # ∈ {:full, :trunc}
107122
const eigh_rrule_verbosity = 0
108123

109124
# QR forward & reverse

src/PEPSKit.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using MatrixAlgebraKit: LAPACK_DivideAndConquer, LAPACK_QRIteration
1212
using MatrixAlgebraKit:
1313
TruncationStrategy, NoTruncation, truncate, findtruncated, truncation_error, diagview
1414
using MatrixAlgebraKit: LAPACK_EighAlgorithm, eigh_pullback!, eigh_trunc_pullback!
15+
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!
1516

1617
using TensorKit
1718
using TensorKit: AdjointTensorMap, SectorDict
@@ -123,11 +124,11 @@ include("algorithms/select_algorithm.jl")
123124

124125
using .Defaults: set_scheduler!
125126
export set_scheduler!
126-
export SVDAdjoint, FullSVDReverseRule, IterSVD
127+
export EighAdjoint, IterEigh, SVDAdjoint, IterSVD, QRAdjoint
127128
export CTMRGEnv, SequentialCTMRG, SimultaneousCTMRG
128129
export FixedSpaceTruncation, SiteDependentTruncation
129130
export HalfInfiniteProjector, FullInfiniteProjector
130-
export EighAdjoint, IterEigh, QRAdjoint, C4vCTMRG, C4vEighProjector, C4vQRProjector
131+
export C4vCTMRG, C4vEighProjector, C4vQRProjector
131132
export initialize_random_c4v_env, initialize_singlet_c4v_env
132133
export LocalOperator, physicalspace
133134
export product_peps

src/algorithms/ctmrg/projectors.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ end
5656
5757
Return the tensor decomposition algorithm of the `alg` projector algorithm.
5858
Additionally, the multi-index `(dir, r, c)` can be supplied which will return the
59-
decomposition performed at that index, e.g. when using `FixedEigh` or `FixedSVD`.
59+
decomposition performed at that index, e.g. when using [`FixedEig`](@ref) or [`FixedSVD`](@ref).
6060
"""
6161
decomposition_algorithm(alg::ProjectorAlgorithm) = alg.decomposition_alg
6262
function decomposition_algorithm(alg::ProjectorAlgorithm, (dir, r, c))
@@ -67,11 +67,12 @@ function decomposition_algorithm(alg::ProjectorAlgorithm, (dir, r, c))
6767
FixedSVD(
6868
fwd_alg.U[dir, r, c], fwd_alg.S[dir, r, c], fwd_alg.V[dir, r, c],
6969
fwd_alg.U_full[dir, r, c], fwd_alg.S_full[dir, r, c], fwd_alg.V_full[dir, r, c],
70+
fwd_alg.truncation_indices[dir, r, c],
7071
)
7172
else
7273
FixedSVD(
7374
fwd_alg.U[dir, r, c], fwd_alg.S[dir, r, c], fwd_alg.V[dir, r, c],
74-
nothing, nothing, nothing,
75+
nothing, nothing, nothing, nothing,
7576
)
7677
end
7778
return SVDAdjoint(; fwd_alg = fix_svd, rrule_alg = decomposition_alg.rrule_alg)

src/algorithms/ctmrg/simultaneous.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ function _split_proj_and_info(proj_and_info)
6464
U_full = map(x -> x[2].U_full, proj_and_info)
6565
S_full = map(x -> x[2].S_full, proj_and_info)
6666
V_full = map(x -> x[2].V_full, proj_and_info)
67-
info = (; truncation_error, condition_number, U, S, V, U_full, S_full, V_full)
67+
truncation_indices = map(x -> x[2].truncation_indices, proj_and_info)
68+
info = (; truncation_error, condition_number, U, S, V, U_full, S_full, V_full, truncation_indices)
6869
return (P_left, P_right), info
6970
end
7071

src/algorithms/optimization/fixed_point_differentiation.jl

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -320,24 +320,25 @@ end
320320
function gauge_fix(alg::SVDAdjoint, signs, info)
321321
# embed gauge signs in larger space to fix gauge of full U and V on truncated subspace
322322
rowsize, colsize = size(signs, 2), size(signs, 3)
323-
signs_full = map(Iterators.product(1:4, 1:rowsize, 1:colsize)) do (dir, r, c)
324-
σ = signs[dir, r, c]
325-
r_sign, c_sign = if dir == NORTH # take unit cell interdependency of signs into account
326-
r, _prev(c, colsize)
323+
inds = info.truncation_indices
324+
signs_full = map(Iterators.product(1:4, 1:rowsize, 1:colsize)) do (dir, row, col)
325+
σ = signs[dir, row, col]
326+
row_sign, col_sign = if dir == NORTH # take unit cell interdependency of signs into account
327+
row, _prev(col, colsize)
327328
elseif dir == EAST
328-
_prev(r, rowsize), c
329+
_prev(row, rowsize), col
329330
elseif dir == SOUTH
330-
r, _next(c, colsize)
331+
row, _next(col, colsize)
331332
elseif dir == WEST
332-
_next(r, rowsize), c
333+
_next(row, rowsize), col
333334
end
334-
extended_space = domain(info.U_full[dir, r_sign, c_sign]) codomain(info.V_full[dir, r_sign, c_sign])
335-
extended_σ = zeros(scalartype(σ), extended_space)
336-
for (c, b) in blocks(extended_σ)
337-
σc = block(σ, c)
338-
kept_dim = size(σc, 1)
339-
b[diagind(b)] .= one(scalartype(σ)) # put ones on the diagonal
340-
b[1:kept_dim, 1:kept_dim] .= σc # set to σ on kept subspace
335+
336+
ind = inds[dir, row_sign, col_sign]
337+
extended_σ = id(scalartype(σ), domain(info.S_full[dir, row_sign, col_sign]))
338+
for (c, b) in blocks)
339+
I = get(ind, c, nothing)
340+
@assert !isnothing(I)
341+
block(extended_σ, c)[I, I] = b
341342
end
342343
return extended_σ
343344
end
@@ -346,15 +347,15 @@ function gauge_fix(alg::SVDAdjoint, signs, info)
346347
U_fixed, V_fixed = fix_relative_phases(info.U, info.V, signs)
347348
U_full_fixed, V_full_fixed = fix_relative_phases(info.U_full, info.V_full, signs_full)
348349
return SVDAdjoint(;
349-
fwd_alg = FixedSVD(U_fixed, info.S, V_fixed, U_full_fixed, info.S_full, V_full_fixed),
350+
fwd_alg = FixedSVD(U_fixed, info.S, V_fixed, U_full_fixed, info.S_full, V_full_fixed, inds),
350351
rrule_alg = alg.rrule_alg,
351352
)
352353
end
353354
function gauge_fix(alg::SVDAdjoint{F}, signs, info) where {F <: IterSVD}
354355
# fix kept U and V only since iterative SVD doesn't have access to full spectrum
355356
U_fixed, V_fixed = fix_relative_phases(info.U, info.V, signs)
356357
return SVDAdjoint(;
357-
fwd_alg = FixedSVD(U_fixed, info.S, V_fixed, nothing, nothing, nothing),
358+
fwd_alg = FixedSVD(U_fixed, info.S, V_fixed, nothing, nothing, nothing, nothing),
358359
rrule_alg = alg.rrule_alg,
359360
)
360361
end
@@ -374,7 +375,7 @@ function gauge_fix(alg::EighAdjoint, signs, info)
374375
V_fixed = info.V * σ'
375376
V_full_fixed = info.V_full * extended_σ'
376377
return EighAdjoint(;
377-
fwd_alg = FixedEig(info.D, V_fixed, info.D_full, V_full_fixed, info.truncation_indices),
378+
fwd_alg = FixedEig(info.D, V_fixed, info.D_full, V_full_fixed, inds),
378379
rrule_alg = alg.rrule_alg,
379380
)
380381
end

src/utility/eigh.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Construct a `FullEighPullback` algorithm struct from the following keyword argum
1616
* `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`.
1717
"""
1818
@kwdef struct FullEighPullback
19+
degeneracy_atol::Real = Defaults.rrule_degeneracy_atol
1920
verbosity::Int = 0
2021
end
2122

@@ -37,6 +38,7 @@ Construct a `TruncEighPullback` algorithm struct from the following keyword argu
3738
* `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`.
3839
"""
3940
@kwdef struct TruncEighPullback
41+
degeneracy_atol::Real = Defaults.rrule_degeneracy_atol
4042
verbosity::Int = 0
4143
end
4244

@@ -64,8 +66,8 @@ Construct a `EighAdjoint` algorithm struct based on the following keyword argume
6466
- `:lanczos` : Lanczos algorithm for symmetric/Hermitian matrices, see [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.Lanczos)
6567
- `:blocklanczos` : Block version of `:lanczos` for repeated extremal eigenvalues, see [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.BlockLanczos)
6668
* `rrule_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.eigh_rrule_alg))`: Reverse-rule algorithm for differentiating the eigenvalue decomposition. Can be supplied by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the following:
67-
- `:trunc` : MatrixAlgebraKit's `eigh_trunc_pullback` solving a Sylvester equation on the truncated subspace
68-
- `:full` : MatrixAlgebraKit's `eigh_pullback` that requires access to the full spectrum
69+
- `:full` : MatrixAlgebraKit's `eigh_pullback!` that requires access to the full spectrum
70+
- `:trunc` : MatrixAlgebraKit's `eigh_trunc_pullback!` solving a Sylvester equation on the truncated subspace
6971
"""
7072
struct EighAdjoint{F, R}
7173
fwd_alg::F
@@ -77,12 +79,8 @@ const EIGH_FWD_SYMBOLS = IdDict{Symbol, Any}(
7779
:bisection => LAPACK_Bisection,
7880
:divideandconquer => LAPACK_DivideAndConquer,
7981
:multiple => LAPACK_MultipleRelativelyRobustRepresentations,
80-
:lanczos =>
81-
(; tol = 1.0e-14, krylovdim = 30, kwargs...) ->
82-
IterEigh(; alg = Lanczos(; tol, krylovdim), kwargs...),
83-
:blocklanczos =>
84-
(; tol = 1.0e-14, krylovdim = 30, kwargs...) ->
85-
IterEigh(; alg = BlockLanczos(; tol, krylovdim), kwargs...),
82+
:lanczos => (; tol = 1.0e-14, krylovdim = 30, kwargs...) -> IterEigh(; alg = Lanczos(; tol, krylovdim), kwargs...),
83+
:blocklanczos => (; tol = 1.0e-14, krylovdim = 30, kwargs...) -> IterEigh(; alg = BlockLanczos(; tol, krylovdim), kwargs...),
8684
)
8785
const EIGH_RRULE_SYMBOLS = IdDict{Symbol, Type{<:Any}}(
8886
:full => FullEighPullback, :trunc => TruncEighPullback,
@@ -105,6 +103,7 @@ function EighAdjoint(; fwd_alg = (;), rrule_alg = (;))
105103
rrule_algorithm = if rrule_alg isa NamedTuple
106104
rrule_kwargs = (;
107105
alg = Defaults.eigh_rrule_alg,
106+
degeneracy_atol = Defaults.rrule_degeneracy_atol,
108107
verbosity = Defaults.eigh_rrule_verbosity,
109108
rrule_alg...,
110109
) # overwrite with specified kwargs
@@ -113,7 +112,7 @@ function EighAdjoint(; fwd_alg = (;), rrule_alg = (;))
113112
throw(ArgumentError("unknown rrule algorithm: $(rrule_kwargs.alg)"))
114113
rrule_type = EIGH_RRULE_SYMBOLS[rrule_kwargs.alg]
115114
if rrule_type <: Union{FullEighPullback, TruncEighPullback}
116-
rrule_kwargs = (; rrule_kwargs.verbosity)
115+
rrule_kwargs = (; rrule_kwargs.degeneracy_atol, rrule_kwargs.verbosity)
117116
end
118117

119118
rrule_type(; rrule_kwargs...)
@@ -249,7 +248,10 @@ function _eigh_trunc!(f, alg::IterEigh, trunc::TruncationStrategy)
249248
truncation_error =
250249
trunc isa NoTruncation ? abs(zero(scalartype(f))) : norm(V * D * V' - f)
251250
condition_number = cond(D)
252-
info = (; truncation_error, condition_number, D_full = nothing, V_full = nothing)
251+
info = (;
252+
truncation_error, condition_number, D_full = nothing, V_full = nothing,
253+
truncation_indices = nothing,
254+
)
253255

254256
return D, V, info
255257
end
@@ -329,7 +331,7 @@ function _get_pullback_gauge_tol(verbosity::Int)
329331
if verbosity 0 # never print gauge sensitivity
330332
return (_) -> Inf
331333
elseif verbosity == 1 # print gauge sensitivity above default atol
332-
MatrixAlgebraKit.default_pullback_gaugetol
334+
MatrixAlgebraKit.default_pullback_gauge_atol
333335
else # always print gauge sensitivity
334336
return (_) -> 0.0
335337
end
@@ -349,7 +351,7 @@ function ChainRulesCore.rrule(
349351
function eigh_trunc!_full_pullback(ΔDV)
350352
Δt = eigh_pullback!(
351353
zeros(scalartype(t), space(t)), t, (D, V), ΔDV, inds;
352-
gauge_atol = gtol(ΔDV)
354+
gauge_atol = gtol(ΔDV), degeneracy_atol = alg.rrule_alg.degeneracy_atol,
353355
)
354356
return NoTangent(), Δt, NoTangent()
355357
end
@@ -373,7 +375,7 @@ function ChainRulesCore.rrule(
373375
function eigh_trunc!_trunc_pullback(ΔDV)
374376
Δf = eigh_trunc_pullback!(
375377
zeros(scalartype(t), space(t)), t, (D, V), ΔDV;
376-
gauge_atol = gtol(ΔDV)
378+
gauge_atol = gtol(ΔDV), degeneracy_atol = alg.rrule_alg.degeneracy_atol,
377379
)
378380
return NoTangent(), Δf, NoTangent()
379381
end

0 commit comments

Comments
 (0)