Skip to content

Commit f387d4b

Browse files
committed
working randomized SVD
1 parent 999e37a commit f387d4b

File tree

5 files changed

+190
-6
lines changed

5 files changed

+190
-6
lines changed

src/PEPSKit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Compat
55
using Accessors: @set, @reset
66
using VectorInterface
77
import VectorInterface as VI
8+
import Random
89

910
using TensorKit
1011
using TensorKit: TruncationScheme

src/algorithms/contractions/ctmrg_contractions.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,17 @@ function contract_projectors(U, S, V, Q, Q_next)
251251
return P_left, P_right
252252
end
253253

254+
function contract_projectors(U, S, V, fq::FourQuadrants)
255+
isqS = sdiag_pow(S, -0.5)
256+
# use * to respect fermionic case
257+
p1 = (codomainind(fq), domainind(fq))
258+
p2 = (codomainind(fq), (numout(fq) + 1,))
259+
P_left = tensorcontract(fq.Q3, p1, false, fq.Q4 * V' * isqS, p2, false, p2)
260+
p3 = ((1,), codomainind(fq) .+ 1)
261+
P_right = tensorcontract(isqS * U' * fq.Q1, p3, false, fq.Q2, p1, false, p3)
262+
return P_left, P_right
263+
end
264+
254265
"""
255266
half_infinite_environment(quadrant1, quadrant2)
256267
half_infinite_environment(C_1, C_2, E_1, E_2, E_3, E_4, A_1, A_2)

src/algorithms/ctmrg/projectors.jl

Lines changed: 122 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ Construct the half-infinite projector algorithm based on the following keyword a
8787
- `:notrunc` : No singular values are truncated and the performed SVDs are exact
8888
- `:truncerr` : Additionally supply error threshold `η`; truncate to the maximal virtual dimension of `η`
8989
- `:truncdim` : Additionally supply truncation dimension `η`; truncate such that the 2-norm of the truncated values is smaller than `η`
90-
- `:truncspace` : Additionally supply truncation space `η`; truncate according to the supplied vector space
90+
- `:truncspace` : Additionally supply truncation space `η`; truncate according to the supplied vector space
9191
- `:truncbelow` : Additionally supply singular value cutoff `η`; truncate such that every retained singular value is larger than `η`
9292
* `verbosity::Int=$(Defaults.projector_verbosity)` : Projector output verbosity which can be:
9393
0. Suppress output information
@@ -125,7 +125,7 @@ Construct the full-infinite projector algorithm based on the following keyword a
125125
- `:notrunc` : No singular values are truncated and the performed SVDs are exact
126126
- `:truncerr` : Additionally supply error threshold `η`; truncate to the maximal virtual dimension of `η`
127127
- `:truncdim` : Additionally supply truncation dimension `η`; truncate such that the 2-norm of the truncated values is smaller than `η`
128-
- `:truncspace` : Additionally supply truncation space `η`; truncate according to the supplied vector space
128+
- `:truncspace` : Additionally supply truncation space `η`; truncate according to the supplied vector space
129129
- `:truncbelow` : Additionally supply singular value cutoff `η`; truncate such that every retained singular value is larger than `η`
130130
* `verbosity::Int=$(Defaults.projector_verbosity)` : Projector output verbosity which can be:
131131
0. Suppress output information
@@ -148,7 +148,7 @@ PROJECTOR_SYMBOLS[:fullinfinite] = FullInfiniteProjector
148148
Determine left and right projectors at the bond given determined by the enlarged corners
149149
and the given coordinate using the specified `alg`.
150150
"""
151-
function compute_projector(enlarged_corners, coordinate, alg::HalfInfiniteProjector)
151+
function compute_projector(enlarged_corners, coordinate, last_space, alg::HalfInfiniteProjector)
152152
# SVD half-infinite environment
153153
halfinf = half_infinite_environment(enlarged_corners...)
154154
svd_alg = svd_algorithm(alg, coordinate)
@@ -166,7 +166,7 @@ function compute_projector(enlarged_corners, coordinate, alg::HalfInfiniteProjec
166166
P_left, P_right = contract_projectors(U, S, V, enlarged_corners...)
167167
return (P_left, P_right), (; U, S, V, info...)
168168
end
169-
function compute_projector(enlarged_corners, coordinate, alg::FullInfiniteProjector)
169+
function compute_projector(enlarged_corners, coordinate, last_space, alg::FullInfiniteProjector)
170170
halfinf_left = half_infinite_environment(enlarged_corners[1], enlarged_corners[2])
171171
halfinf_right = half_infinite_environment(enlarged_corners[3], enlarged_corners[4])
172172

@@ -187,3 +187,121 @@ function compute_projector(enlarged_corners, coordinate, alg::FullInfiniteProjec
187187
P_left, P_right = contract_projectors(U, S, V, halfinf_left, halfinf_right)
188188
return (P_left, P_right), (; U, S, V, info...)
189189
end
190+
191+
# ==========================================================================================
192+
# TBD: compute ncv vectors proj.ncv_ratio * old_nvectors?
193+
struct RandomizedProjector{S, T, R} <: ProjectorAlgorithm
194+
svd_alg::S
195+
trscheme::T
196+
rng::R
197+
oversampling::Int64
198+
max_full::Int64
199+
verbosity::Int
200+
end
201+
202+
PROJECTOR_SYMBOLS[:randomized] = RandomizedProjector
203+
204+
svd_algorithm(alg::RandomizedProjector) = alg.svd_alg
205+
206+
#=
207+
# TBD is this needed?
208+
function RandomizedProjector(;
209+
rng = Random.default_rng(),
210+
svd_alg = (;),
211+
trscheme = (;),
212+
oversampling = 10,
213+
max_full = 100,
214+
verbosity = Defaults.projector_verbosity,
215+
)
216+
217+
# parse SVD forward & rrule algorithm
218+
svd_algorithm = _alg_or_nt(SVDAdjoint, svd_alg)
219+
220+
# parse truncation scheme
221+
truncation_scheme = if trscheme isa TruncationScheme
222+
trscheme
223+
elseif trscheme isa NamedTuple
224+
_TruncationScheme(; trscheme...)
225+
else
226+
throw(ArgumentError("unknown trscheme $trscheme"))
227+
end
228+
229+
return RandomizedProjector(
230+
svd_algorithm, truncation_scheme, rng, oversampling, max_full, verbosity
231+
)
232+
end
233+
=#
234+
235+
236+
# needed as default interface in PEPSKit.ProjectorAlgorithm
237+
function RandomizedProjector(svd_algorithm, trscheme, verbosity)
238+
@show "Hi RandomizedProjector"
239+
@show which(
240+
RandomizedProjector, typeof.(
241+
(
242+
svd_algorithm, trscheme, Random.default_rng(), 10, 100, verbosity,
243+
)
244+
)
245+
)
246+
return RandomizedProjector(
247+
svd_algorithm, trscheme, Random.default_rng(), 10, 100, verbosity
248+
)
249+
end
250+
251+
function random_domain(alg::RandomizedProjector, full_space, last_space)
252+
sector_dims = map(sectors(full_space)) do s
253+
if dim(full_space, s) <= alg.max_full
254+
n = dim(full_space, s)
255+
else
256+
n = dim(last_space, s) + alg.oversampling
257+
end
258+
return s => n
259+
end
260+
return Vect[sectortype(last_space)](sector_dims)
261+
end
262+
263+
264+
function randomized_range_finder(A::AbstractTensorMap, alg::RandomizedProjector, randomized_space)
265+
Q = TensorMap{eltype(A)}(undef, domain(A) randomized_space)
266+
foreach(blocks(Q)) do (s, b)
267+
m, n = size(b)
268+
if m <= alg.max_full
269+
b .= LinearAlgebra.I(m)
270+
else
271+
Ω = randn(alg.rng, eltype(A), domain(A) Vect[sectortype(A)](s => n))
272+
Y = A * Ω
273+
Qs, _ = leftorth!(Y)
274+
b .= block(Qs, s)
275+
end
276+
end
277+
return Q
278+
end
279+
280+
# impose full env, could also be defined for half_infinite_environment, little gain
281+
function compute_projector(fq, coordinate, last_space, alg::RandomizedProjector)
282+
full_space = fuse(domain(fq))
283+
randomized_space = random_domain(alg, full_space, last_space)
284+
Q = randomized_range_finder(fq, alg, randomized_space)
285+
B = Q' * fq
286+
287+
svd_alg = svd_algorithm(alg, coordinate)
288+
U′, S, V, info = tsvd!(B, svd_alg; trunc = alg.trscheme)
289+
U = Q * U′
290+
foreach(blocks(S)) do (s, b)
291+
if size(b, 1) == dim(randomized_space, s) && size(b, 1) < dim(full_space, s)
292+
@warn("Sector is too small, kept all computed values: ", s)
293+
end
294+
end
295+
296+
# Check for degenerate singular values; still needed for exact blocks
297+
Zygote.isderiving() && ignore_derivatives() do
298+
if alg.verbosity > 0 && is_degenerate_spectrum(S)
299+
svals = TensorKit.SectorDict(c => diag(b) for (c, b) in blocks(S))
300+
@warn("degenerate singular values detected: ", svals)
301+
end
302+
end
303+
304+
@reset info.truncation_error = info.truncation_error / norm(S) # normalize truncation error
305+
P_left, P_right = contract_projectors(U, S, V, fq)
306+
return (P_left, P_right), (; U, S, V, info...)
307+
end

src/algorithms/ctmrg/simultaneous.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ function simultaneous_projectors(
9797
trscheme = truncation_scheme(alg, env.edges[coordinate[1], coordinate′[2:3]...])
9898
alg′ = @set alg.trscheme = trscheme
9999
ec = (enlarged_corners[coordinate...], enlarged_corners[coordinate′...])
100-
return compute_projector(ec, coordinate, alg′)
100+
last_space = space(env.edges[coordinate[1], coordinate′[2:3]...], 1)
101+
return compute_projector(ec, coordinate, last_space, alg′)
101102
end
102103
function simultaneous_projectors(
103104
coordinate, enlarged_corners::Array{E, 3}, env, alg::FullInfiniteProjector
@@ -109,13 +110,35 @@ function simultaneous_projectors(
109110
coordinate2 = _next_coordinate(coordinate, rowsize, colsize)
110111
coordinate3 = _next_coordinate(coordinate2, rowsize, colsize)
111112
coordinate4 = _next_coordinate(coordinate3, rowsize, colsize)
113+
last_space = space(env.edges[coordinate[1], coordinate′[2:3]...], 1)
112114
ec = (
113115
enlarged_corners[coordinate4...],
114116
enlarged_corners[coordinate...],
115117
enlarged_corners[coordinate2...],
116118
enlarged_corners[coordinate3...],
117119
)
118-
return compute_projector(ec, coordinate, alg′)
120+
return compute_projector(ec, coordinate, last_space, alg′)
121+
end
122+
123+
# TBD share code with FullInfiniteProjector?
124+
function simultaneous_projectors(
125+
coordinate, enlarged_corners::Array{E, 3}, env, alg::RandomizedProjector
126+
) where {E}
127+
coordinate′ = _next_coordinate(coordinate, size(env)[2:3]...)
128+
trscheme = truncation_scheme(alg, env.edges[coordinate[1], coordinate′[2:3]...])
129+
alg′ = @set alg.trscheme = trscheme
130+
rowsize, colsize = size(enlarged_corners)[2:3]
131+
coordinate2 = _next_coordinate(coordinate, rowsize, colsize)
132+
coordinate3 = _next_coordinate(coordinate2, rowsize, colsize)
133+
coordinate4 = _next_coordinate(coordinate3, rowsize, colsize)
134+
fq = FourQuadrants(
135+
enlarged_corners[coordinate4...],
136+
enlarged_corners[coordinate...],
137+
enlarged_corners[coordinate2...],
138+
enlarged_corners[coordinate3...],
139+
)
140+
last_space = space(env.edges[coordinate[1], coordinate′[2:3]...], 1)
141+
return compute_projector(fq, coordinate, last_space, alg′)
119142
end
120143

121144
"""

src/environments/ctmrg_environments.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,3 +477,34 @@ function VI.inner(env₁::CTMRGEnv, env₂::CTMRGEnv)
477477
return inner((env₁.corners, env₁.edges), (env₂.corners, env₂.edges))
478478
end
479479
VI.norm(env::CTMRGEnv) = norm((env.corners, env.edges))
480+
481+
# ==========================================================================================
482+
483+
struct FourQuadrants{S, T, N, TM} <: AbstractTensorMap{S, T, N, N}
484+
Q1::TM
485+
Q2::TM
486+
Q3::TM
487+
Q4::TM
488+
489+
function FourQuadrants{T, S, N, TM}(
490+
Q1, Q2, Q3, Q4
491+
) where {S, T, N, TM <: AbstractTensorMap{T, S, N, N}}
492+
return new{T, S, N, TM}(Q1, Q2, Q3, Q4)
493+
end
494+
end
495+
496+
function FourQuadrants(Q1, Q2, Q3, Q4)
497+
return FourQuadrants{eltype(Q1), spacetype(Q1), numin(Q1), typeof(Q1)}(Q1, Q2, Q3, Q4)
498+
end
499+
500+
TensorKit.TensorMap(fq::FourQuadrants) = fq.Q1 * fq.Q2 * fq.Q3 * fq.Q4
501+
502+
TensorKit.space(fq::FourQuadrants) = codomain(fq.Q1) domain(fq.Q4)
503+
504+
# TBD use tensorcontract to handle fermions?
505+
(fq::FourQuadrants)(m) = fq.Q1 * (fq.Q2 * (fq.Q3 * (fq.Q4 * m)))
506+
507+
Base.:*(m::AbstractTensorMap, fq::FourQuadrants) = m * fq.Q1 * fq.Q2 * fq.Q3 * fq.Q4
508+
Base.:*(fq::FourQuadrants, m::AbstractTensorMap) = fq.Q1 * (fq.Q2 * (fq.Q3 * (fq.Q4 * m)))
509+
510+
Base.adjoint(fq::FourQuadrants) = FourQuadrants(fq.Q4', fq.Q3', fq.Q2', fq.Q1')

0 commit comments

Comments
 (0)