Skip to content

Commit 3907a9f

Browse files
committed
update approximate
utility mps copy remove superfluous `contractcheck`
1 parent 0a0c3f9 commit 3907a9f

File tree

9 files changed

+171
-142
lines changed

9 files changed

+171
-142
lines changed

src/algorithms/approximate/approximate.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,22 @@ function approximate(ψ::InfiniteMPS,
3838
ψ = convert(InfiniteMPS, multi)
3939
return ψ, envs
4040
end
41+
42+
# dispatch to in-place method
43+
function approximate(ψ, toapprox, alg::Union{DMRG,DMRG2,IDMRG1,IDMRG2},
44+
envs=environments(ψ, toapprox))
45+
return approximate!(copy(ψ), toapprox, alg, envs)
46+
end
47+
48+
# disambiguate
49+
function approximate::InfiniteMPS,
50+
toapprox::Tuple{<:InfiniteMPO,<:InfiniteMPS},
51+
algorithm::Union{IDMRG1,IDMRG2},
52+
envs=environments(ψ, toapprox))
53+
envs′ = Multiline([envs])
54+
multi, envs = approximate(convert(MultilineMPS, ψ),
55+
(convert(MultilineMPO, toapprox[1]),
56+
convert(MultilineMPS, toapprox[2])), algorithm, envs′)
57+
ψ = convert(InfiniteMPS, multi)
58+
return ψ, envs
59+
end

src/algorithms/approximate/fvomps.jl

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,5 @@
1-
# dispatch to in-place method
2-
function approximate(ψ, toapprox, alg::Union{DMRG,DMRG2}, envs...)
3-
return approximate!(copy(ψ), toapprox, alg, envs...)
4-
end
5-
6-
function approximate!::AbstractFiniteMPS, sq, alg, envs=environments(ψ, sq))
7-
tor = approximate!(ψ, [sq], alg, [envs])
8-
return (tor[1], tor[2][1], tor[3])
9-
end
10-
11-
function approximate!::AbstractFiniteMPS, squash::Vector, alg::DMRG2,
12-
envs=[environments(ψ, sq) for sq in squash])
1+
function approximate!::AbstractFiniteMPS, Oϕ, alg::DMRG2,
2+
envs=environments(ψ, Oϕ))
133
ϵ::Float64 = 2 * alg.tol
144
log = IterLog("DMRG2")
155

@@ -18,9 +8,7 @@ function approximate!(ψ::AbstractFiniteMPS, squash::Vector, alg::DMRG2,
188
for iter in 1:(alg.maxiter)
199
ϵ = 0.0
2010
for pos in [1:(length(ψ) - 1); (length(ψ) - 2):-1:1]
21-
AC2′ = sum(zip(squash, envs)) do (sq, pr)
22-
return ac2_proj(pos, ψ, pr)
23-
end
11+
AC2′ = ac2_proj(pos, ψ, Oϕ, envs)
2412
al, c, ar, = tsvd!(AC2′; trunc=alg.trscheme)
2513

2614
AC2 = ψ.AC[pos] * _transpose_tail.AR[pos + 1])
@@ -31,7 +19,7 @@ function approximate!(ψ::AbstractFiniteMPS, squash::Vector, alg::DMRG2,
3119
end
3220

3321
# finalize
34-
ψ, envs = alg.finalize(iter, ψ, squash, envs)::Tuple{typeof(ψ),typeof(envs)}
22+
ψ, envs = alg.finalize(iter, ψ, , envs)::Tuple{typeof(ψ),typeof(envs)}
3523

3624
if ϵ < alg.tol
3725
@infov 2 logfinish!(log, iter, ϵ)
@@ -48,8 +36,7 @@ function approximate!(ψ::AbstractFiniteMPS, squash::Vector, alg::DMRG2,
4836
return ψ, envs, ϵ
4937
end
5038

51-
function approximate!::AbstractFiniteMPS, squash::Vector, alg::DMRG,
52-
envs=[environments(ψ, sq) for sq in squash])
39+
function approximate!::AbstractFiniteMPS, Oϕ, alg::DMRG, envs=environments(ψ, Oϕ))
5340
ϵ::Float64 = 2 * alg.tol
5441
log = IterLog("DMRG")
5542

@@ -58,18 +45,15 @@ function approximate!(ψ::AbstractFiniteMPS, squash::Vector, alg::DMRG,
5845
for iter in 1:(alg.maxiter)
5946
ϵ = 0.0
6047
for pos in [1:(length(ψ) - 1); length(ψ):-1:2]
61-
AC′ = sum(zip(squash, envs)) do (sq, pr)
62-
return ac_proj(pos, ψ, pr)
63-
end
64-
48+
AC′ = ac_proj(pos, ψ, Oϕ, envs)
6549
AC = ψ.AC[pos]
6650
ϵ = max(ϵ, norm(AC′ - AC) / norm(AC′))
6751

6852
ψ.AC[pos] = AC′
6953
end
7054

7155
# finalize
72-
ψ, envs = alg.finalize(iter, ψ, squash, envs)::Tuple{typeof(ψ),typeof(envs)}
56+
ψ, envs = alg.finalize(iter, ψ, , envs)::Tuple{typeof(ψ),typeof(envs)}
7357

7458
if ϵ < alg.tol
7559
@infov 2 logfinish!(log, iter, ϵ)
Lines changed: 64 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
function approximate(ost::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:MultilineMPS},
2-
alg::IDMRG1, oenvs=environments(ost, toapprox))
3-
ψ = copy(ost)
4-
mpo, above = toapprox
5-
envs = IDMRGEnvironments(ost, oenvs)
1+
function approximate!::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:MultilineMPS},
2+
alg::IDMRG1, envs=environments(ψ, toapprox))
63
log = IterLog("IDMRG")
74
ϵ::Float64 = 2 * alg.tol
85

@@ -12,31 +9,27 @@ function approximate(ost::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:Multili
129
C_current = ψ.C[:, 0]
1310

1411
# left to right sweep
15-
for col in 1:size(ψ, 2), row in 1:size(ψ, 1)
16-
h = MPO_∂∂AC(mpo[row, col], leftenv(envs, row, col),
17-
rightenv(envs, row, col))
18-
ψ.AC[row + 1, col] = h * above.AC[row, col]
19-
normalize!.AC[row + 1, col])
20-
21-
ψ.AL[row + 1, col], ψ.C[row + 1, col] = leftorth.AC[row + 1, col])
22-
23-
tm = TransferMatrix(above.AL[row, col], mpo[row, col], ψ.AL[row + 1, col])
24-
setleftenv!(envs, row, col + 1, normalize(leftenv(envs, row, col) * tm))
12+
for col in 1:size(ψ, 2)
13+
for row in 1:size(ψ, 1)
14+
ψ.AC[row + 1, col] = ac_proj(row, col, ψ, toapprox, envs)
15+
normalize!.AC[row + 1, col])
16+
ψ.AL[row + 1, col], ψ.C[row + 1, col] = leftorth!.AC[row + 1, col])
17+
end
18+
transfer_leftenv!(envs, ψ, toapprox, col + 1)
2519
end
2620

2721
# right to left sweep
28-
for col in size(ψ, 2):-1:1, row in 1:size(ψ, 1)
29-
h = MPO_∂∂AC(mpo[row, col], leftenv(envs, row, col),
30-
rightenv(envs, row, col))
31-
ψ.AC[row + 1, col] = h * above.AC[row, col]
32-
normalize!.AC[row + 1, col])
33-
34-
ψ.C[row + 1, col - 1], temp = rightorth(_transpose_tail.AC[row + 1, col]))
35-
ψ.AR[row + 1, col] = _transpose_front(temp)
36-
37-
tm = TransferMatrix(above.AR[row, col], mpo[row, col], ψ.AR[row + 1, col])
38-
setrightenv!(envs, row, col - 1, normalize(tm * rightenv(envs, row, col)))
22+
for col in size(ψ, 2):-1:1
23+
for row in 1:size(ψ, 1)
24+
ψ.AC[row + 1, col] = ac_proj(row, col, ψ, toapprox, envs)
25+
normalize!.AC[row + 1, col])
26+
ψ.C[row + 1, col - 1], temp = rightorth!(_transpose_tail.AC[row + 1,
27+
col]))
28+
ψ.AR[row + 1, col] = _transpose_front(temp)
29+
end
30+
transfer_rightenv!(envs, ψ, toapprox, col - 1)
3931
end
32+
normalize!(envs, ψ, toapprox)
4033

4134
ϵ = norm(C_current - ψ.C[:, 0])
4235

@@ -52,72 +45,62 @@ function approximate(ost::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:Multili
5245
end
5346
end
5447

48+
# TODO: immediately compute in-place
5549
nst = MultilineMPS(map(x -> x, ψ.AR); tol=alg.tol_gauge)
56-
nenvs = environments(nst, toapprox)
57-
return nst, nenvs, ϵ
50+
copy!(ψ, nst) # ensure output destination is unchanged
51+
52+
recalculate!(envs, ψ, toapprox)
53+
return ψ, envs, ϵ
5854
end
5955

60-
function approximate(ost::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:MultilineMPS},
61-
alg::IDMRG2, oenvs=environments(ost, toapprox))
62-
length(ost) < 2 && throw(ArgumentError("unit cell should be >= 2"))
63-
mpo, above = toapprox
64-
ψ = copy(ost)
65-
envs = IDMRGEnvironments(ost, oenvs)
56+
function approximate!::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:MultilineMPS},
57+
alg::IDMRG2, envs=environments(ψ, toapprox))
58+
size(ψ, 2) < 2 && throw(ArgumentError("unit cell should be >= 2"))
6659
ϵ::Float64 = 2 * alg.tol
6760
log = IterLog("IDMRG2")
61+
O, ϕ = toapprox
6862

6963
LoggingExtras.withlevel(; alg.verbosity) do
7064
@infov 2 loginit!(log, ϵ)
7165
for iter in 1:(alg.maxiter)
7266
C_current = ψ.C[:, 0]
7367

7468
# sweep from left to right
75-
for col in 1:size(ψ, 2), row in 1:size(ψ, 1)
76-
ac2 = above.AC[row, col] * _transpose_tail(above.AR[row, col + 1])
77-
h = MPO_∂∂AC2(mpo[row, col], mpo[row, col + 1], leftenv(envs, row, col),
78-
rightenv(envs, row, col + 1))
79-
80-
al, c, ar, = tsvd!(h * ac2; trunc=alg.trscheme, alg=TensorKit.SVD())
81-
normalize!(c)
82-
83-
ψ.AL[row + 1, col] = al
84-
ψ.C[row + 1, col] = complex(c)
85-
ψ.AR[row + 1, col + 1] = _transpose_front(ar)
86-
87-
setleftenv!(envs, row, col + 1,
88-
normalize(leftenv(envs, row, col) *
89-
TransferMatrix(above.AL[row, col], mpo[row, col],
90-
ψ.AL[row + 1, col])))
91-
setrightenv!(envs, row, col,
92-
normalize(TransferMatrix(above.AR[row, col + 1],
93-
mpo[row, col + 1],
94-
ψ.AR[row + 1, col + 1]) *
95-
rightenv(envs, row, col + 1)))
69+
for col in 1:size(ψ, 2)
70+
for row in 1:size(ψ, 1)
71+
AC2′ = ac2_proj(row, col, ψ, toapprox, envs)
72+
al, c, ar, = tsvd!(AC2′; trunc=alg.trscheme, alg=TensorKit.SVD())
73+
normalize!(c)
74+
75+
ψ.AL[row + 1, col] = al
76+
ψ.C[row + 1, col] = complex(c)
77+
ψ.AR[row + 1, col + 1] = _transpose_front(ar)
78+
ψ.AC[row + 1, col + 1] = _transpose_front(c * ar)
79+
end
80+
transfer_leftenv!(envs, ψ, toapprox, col + 1)
81+
transfer_rightenv!(envs, ψ, toapprox, col)
9682
end
83+
normalize!(envs, ψ, toapprox)
9784

9885
# sweep from right to left
99-
for col in (size(ψ, 2) - 1):-1:0, row in 1:size(ψ, 1)
100-
ac2 = above.AL[row, col] * _transpose_tail(above.AC[row, col + 1])
101-
h = MPO_∂∂AC2(mpo[row, col], mpo[row, col + 1], leftenv(envs, row, col),
102-
rightenv(envs, row, col + 1))
103-
104-
al, c, ar, = tsvd!(h * ac2; trunc=alg.trscheme, alg=TensorKit.SVD())
105-
normalize!(c)
106-
107-
ψ.AL[row + 1, col] = al
108-
ψ.C[row + 1, col] = complex(c)
109-
ψ.AR[row + 1, col + 1] = _transpose_front(ar)
110-
111-
setleftenv!(envs, row, col + 1,
112-
normalize(leftenv(envs, row, col) *
113-
TransferMatrix(above.AL[row, col], mpo[row, col],
114-
ψ.AL[row + 1, col])))
115-
setrightenv!(envs, row, col,
116-
normalize(TransferMatrix(above.AR[row, col + 1],
117-
mpo[row, col + 1],
118-
ψ.AR[row + 1, col + 1]) *
119-
rightenv(envs, row, col + 1)))
86+
for col in (size(ψ, 2) - 1):-1:0
87+
for row in 1:size(ψ, 1)
88+
# TODO: also write this as ac2_proj?
89+
AC2 = ϕ.AL[row, col] * _transpose_tail.AC[row, col + 1])
90+
AC2′ = ∂AC2(AC2, O[row, col], O[row, col + 1],
91+
leftenv(envs[row], col, ψ[row]),
92+
rightenv(envs[row], col, ψ[row]))
93+
al, c, ar, = tsvd!(AC2′; trunc=alg.trscheme, alg=TensorKit.SVD())
94+
normalize!(c)
95+
96+
ψ.AL[row + 1, col] = al
97+
ψ.C[row + 1, col] = complex(c)
98+
ψ.AR[row + 1, col + 1] = _transpose_front(ar)
99+
end
100+
transfer_leftenv!(envs, ψ, toapprox, col + 1)
101+
transfer_rightenv!(envs, ψ, toapprox, col)
120102
end
103+
normalize!(envs, ψ, toapprox)
121104

122105
# update error
123106
ϵ = sum(zip(C_current, ψ.C[:, 0])) do (c1, c2)
@@ -139,7 +122,10 @@ function approximate(ost::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:Multili
139122
end
140123
end
141124

125+
# TODO: immediately compute in-place
142126
nst = MultilineMPS(map(x -> x, ψ.AR); tol=alg.tol_gauge)
143-
nenvs = environments(nst, toapprox)
144-
return nst, nenvs, ϵ
127+
copy!(ψ, nst) # ensure output destination is unchanged
128+
recalculate!(envs, ψ, toapprox)
129+
130+
return ψ, envs, ϵ
145131
end

src/algorithms/approximate/vomps.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ function approximate(ψ::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:Multilin
1111
temp_ACs = similar.(ψ.AC)
1212
scheduler = Defaults.scheduler[]
1313
log = IterLog("VOMPS")
14+
alg_environments = updatetol(alg.alg_environments, 0, ϵ)
15+
recalculate!(envs, ψ, toapprox...; alg_environments.tol)
1416

1517
LoggingExtras.withlevel(; alg.verbosity) do
1618
@infov 2 loginit!(log, ϵ)
@@ -23,7 +25,7 @@ function approximate(ψ::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:Multilin
2325
ψ = MultilineMPS(temp_ACs, ψ.C[:, end]; alg_gauge.tol, alg_gauge.maxiter)
2426

2527
alg_environments = updatetol(alg.alg_environments, iter, ϵ)
26-
recalculate!(envs, ψ; alg_environments.tol)
28+
recalculate!(envs, ψ, toapprox...; alg_environments.tol)
2729

2830
ψ, envs = alg.finalize(iter, ψ, toapprox, envs)::Tuple{typeof(ψ),typeof(envs)}
2931

@@ -44,18 +46,20 @@ function approximate(ψ::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:Multilin
4446
return ψ, envs, ϵ
4547
end
4648

47-
function _vomps_localupdate(loc, ψ, (O, ψ₀), envs, factalg=QRpos())
49+
function _vomps_localupdate(loc, ψ, , envs, factalg=QRpos())
4850
local tmp_AC, tmp_C
4951
if Defaults.scheduler[] isa SerialScheduler
50-
tmp_AC = circshift([ac_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
51-
tmp_C = circshift([c_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
52+
tmp_AC = circshift([ac_proj(row, loc, ψ, Oϕ, envs) for row in 1:size(ψ, 1)], 1)
53+
tmp_C = circshift([c_proj(row, loc, ψ, Oϕ, envs) for row in 1:size(ψ, 1)], 1)
5254
else
5355
@sync begin
5456
Threads.@spawn begin
55-
tmp_AC = circshift([ac_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
57+
tmp_AC = circshift([ac_proj(row, loc, ψ, Oϕ, envs)
58+
for row in 1:size(ψ, 1)], 1)
5659
end
5760
Threads.@spawn begin
58-
tmp_C = circshift([c_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
61+
tmp_C = circshift([c_proj(row, loc, ψ, Oϕ, envs) for row in 1:size(ψ, 1)],
62+
1)
5963
end
6064
end
6165
end

0 commit comments

Comments
 (0)