Skip to content

Commit cd1a1c3

Browse files
committed
add subspace iterations
1 parent 05d078e commit cd1a1c3

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

src/algorithms/ctmrg/projectors.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,9 @@ struct RandomizedProjector{S, T, R} <: ProjectorAlgorithm
194194
svd_alg::S
195195
trscheme::T
196196
rng::R
197-
oversampling::Int64
198-
max_full::Int64
197+
oversampling::Int
198+
max_full::Int
199+
n_subspace_iter::Int
199200
verbosity::Int
200201
end
201202

@@ -205,11 +206,13 @@ svd_algorithm(alg::RandomizedProjector) = alg.svd_alg
205206

206207
_default_randomized_oversampling = 10
207208
_default_randomized_max_full = 100
209+
_default_n_subspace_iter = 2
208210

209211
# needed as default interface in PEPSKit.ProjectorAlgorithm
210212
function RandomizedProjector(svd_algorithm, trscheme, verbosity)
211213
return RandomizedProjector(
212-
svd_algorithm, trscheme, Random.default_rng(), _default_randomized_oversampling, _default_randomized_max_full, verbosity
214+
svd_algorithm, trscheme, Random.default_rng(), _default_randomized_oversampling,
215+
_default_randomized_max_full, _default_n_subspace_iter, verbosity
213216
)
214217
end
215218

@@ -229,11 +232,17 @@ end
229232
function randomized_range_finder(A::AbstractTensorMap, alg::RandomizedProjector, randomized_space)
230233
Q = TensorMap{eltype(A)}(undef, domain(A) randomized_space)
231234
foreach(blocks(Q)) do (s, b)
235+
Aad = A'
232236
m, n = size(b)
233237
if m <= alg.max_full
234238
b .= LinearAlgebra.I(m)
235239
else
236240
Ω = randn(alg.rng, eltype(A), domain(A) ← Vect[sectortype(A)](s => n))
241+
for _ in 1:alg.n_subspace_iter
242+
Y = A * Ω
243+
Qs, _ = leftorth!(Y)
244+
Ω, _ = leftorth!(Aad * Qs)
245+
end
237246
Y = A * Ω
238247
Qs, _ = leftorth!(Y)
239248
b .= block(Qs, s)
@@ -248,6 +257,7 @@ function compute_projector(fq, coordinate, last_space, alg::RandomizedProjector)
248257
randomized_space = random_domain(alg, full_space, last_space)
249258
Q = randomized_range_finder(fq, alg, randomized_space)
250259
B = Q' * fq
260+
normalize!(B) # TODO better way?
251261

252262
svd_alg = svd_algorithm(alg, coordinate)
253263
U′, S, V, info = tsvd!(B, svd_alg; trunc = alg.trscheme)

0 commit comments

Comments
 (0)