Skip to content

Commit 2e36def

Browse files
authored
Report Approximate Truncation Error from Apply Function (#232)
1 parent 60abbaf commit 2e36def

File tree

3 files changed

+27
-13
lines changed

3 files changed

+27
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
4-
version = "0.13.4"
4+
version = "0.13.5"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/apply.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function full_update_bp(
3434
nfullupdatesweeps=10,
3535
print_fidelity_loss=false,
3636
envisposdef=false,
37-
(singular_values!)=nothing,
37+
callback=Returns(nothing),
3838
symmetrize=false,
3939
apply_kwargs...,
4040
)
@@ -65,21 +65,23 @@ function full_update_bp(
6565
apply_kwargs...,
6666
)
6767
if symmetrize
68-
Rᵥ₁, Rᵥ₂ = factorize_svd(
68+
singular_values! = Ref(ITensor())
69+
Rᵥ₁, Rᵥ₂, spec = factorize_svd(
6970
Rᵥ₁ * Rᵥ₂,
7071
inds(Rᵥ₁);
7172
ortho="none",
7273
tags=edge_tag(v⃗[1] => v⃗[2]),
7374
singular_values!,
7475
apply_kwargs...,
7576
)
77+
callback(; singular_values=singular_values![], truncation_error=spec.truncerr)
7678
end
7779
ψᵥ₁ = Qᵥ₁ * Rᵥ₁
7880
ψᵥ₂ = Qᵥ₂ * Rᵥ₂
7981
return ψᵥ₁, ψᵥ₂
8082
end
8183

82-
function simple_update_bp_full(o, ψ, v⃗; envs, (singular_values!)=nothing, apply_kwargs...)
84+
function simple_update_bp_full(o, ψ, v⃗; envs, callback=Returns(nothing), apply_kwargs...)
8385
cutoff = 10 * eps(real(scalartype(ψ)))
8486
envs_v1 = filter(env -> hascommoninds(env, ψ[v⃗[1]]), envs)
8587
envs_v2 = filter(env -> hascommoninds(env, ψ[v⃗[2]]), envs)
@@ -116,9 +118,11 @@ function simple_update_bp_full(o, ψ, v⃗; envs, (singular_values!)=nothing, ap
116118
v1_inds = [v1_inds; siteinds(ψ, v⃗[1])]
117119
v2_inds = [v2_inds; siteinds(ψ, v⃗[2])]
118120
e = v⃗[1] => v⃗[2]
119-
ψᵥ₁, ψᵥ₂ = factorize_svd(
121+
singular_values! = Ref(ITensor())
122+
ψᵥ₁, ψᵥ₂, spec = factorize_svd(
120123
oψ, v1_inds; ortho="none", tags=edge_tag(e), singular_values!, apply_kwargs...
121124
)
125+
callback(; singular_values=singular_values![], truncation_error=spec.truncerr)
122126
for inv_sqrt_env_v1 in inv_sqrt_envs_v1
123127
ψᵥ₁ *= dag(inv_sqrt_env_v1)
124128
end
@@ -129,7 +133,7 @@ function simple_update_bp_full(o, ψ, v⃗; envs, (singular_values!)=nothing, ap
129133
end
130134

131135
# Reduced version
132-
function simple_update_bp(o, ψ, v⃗; envs, (singular_values!)=nothing, apply_kwargs...)
136+
function simple_update_bp(o, ψ, v⃗; envs, callback=Returns(nothing), apply_kwargs...)
133137
cutoff = 10 * eps(real(scalartype(ψ)))
134138
envs_v1 = filter(env -> hascommoninds(env, ψ[v⃗[1]]), envs)
135139
envs_v2 = filter(env -> hascommoninds(env, ψ[v⃗[2]]), envs)
@@ -164,14 +168,16 @@ function simple_update_bp(o, ψ, v⃗; envs, (singular_values!)=nothing, apply_k
164168
rᵥ₂ = commoninds(Qᵥ₂, Rᵥ₂)
165169
oR = apply(o, Rᵥ₁ * Rᵥ₂)
166170
e = v⃗[1] => v⃗[2]
167-
Rᵥ₁, Rᵥ₂ = factorize_svd(
171+
singular_values! = Ref(ITensor())
172+
Rᵥ₁, Rᵥ₂, spec = factorize_svd(
168173
oR,
169174
unioninds(rᵥ₁, sᵥ₁);
170175
ortho="none",
171176
tags=edge_tag(e),
172177
singular_values!,
173178
apply_kwargs...,
174179
)
180+
callback(; singular_values=singular_values![], truncation_error=spec.truncerr)
175181
Qᵥ₁ = contract([Qᵥ₁; dag.(inv_sqrt_envs_v1)])
176182
Qᵥ₂ = contract([Qᵥ₂; dag.(inv_sqrt_envs_v2)])
177183
ψᵥ₁ = Qᵥ₁ * Rᵥ₁
@@ -188,7 +194,7 @@ function ITensors.apply(
188194
nfullupdatesweeps=10,
189195
print_fidelity_loss=false,
190196
envisposdef=false,
191-
(singular_values!)=nothing,
197+
callback=Returns(nothing),
192198
variational_optimization_only=false,
193199
symmetrize=false,
194200
reduced=true,
@@ -224,15 +230,15 @@ function ITensors.apply(
224230
nfullupdatesweeps,
225231
print_fidelity_loss,
226232
envisposdef,
227-
singular_values!,
233+
callback,
228234
symmetrize,
229235
apply_kwargs...,
230236
)
231237
else
232238
if reduced
233-
ψᵥ₁, ψᵥ₂ = simple_update_bp(o, ψ, v⃗; envs, singular_values!, apply_kwargs...)
239+
ψᵥ₁, ψᵥ₂ = simple_update_bp(o, ψ, v⃗; envs, callback, apply_kwargs...)
234240
else
235-
ψᵥ₁, ψᵥ₂ = simple_update_bp_full(o, ψ, v⃗; envs, singular_values!, apply_kwargs...)
241+
ψᵥ₁, ψᵥ₂ = simple_update_bp_full(o, ψ, v⃗; envs, callback, apply_kwargs...)
236242
end
237243
end
238244
if normalize

test/test_apply.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using ITensorNetworks:
1111
random_tensornetwork,
1212
siteinds,
1313
update
14-
using ITensors: ITensors, inner, op
14+
using ITensors: ITensors, ITensor, inner, op
1515
using NamedGraphs.NamedGraphGenerators: named_grid
1616
using NamedGraphs.PartitionedGraphs: PartitionVertex
1717
using SplitApplyCombine: group
@@ -39,9 +39,15 @@ using Test: @test, @testset
3939
envsGBP = environment(bp_cache, [(v1, "bra"), (v1, "ket"), (v2, "bra"), (v2, "ket")])
4040
inner_alg = "exact"
4141
ngates = 5
42+
truncerr = 0.0
43+
singular_values = ITensor()
44+
function callback(; singular_values, truncation_error)
45+
truncerr = truncation_error
46+
singular_values = singular_values
47+
end
4248
for i in 1:ngates
4349
o = op("RandomUnitary", s[v1]..., s[v2]...)
44-
ψOexact = apply(o, ψ; cutoff=1e-16)
50+
ψOexact = apply(o, ψ; cutoff=nothing)
4551
ψOSBP = apply(
4652
o,
4753
ψ;
@@ -50,6 +56,7 @@ using Test: @test, @testset
5056
normalize=true,
5157
print_fidelity_loss=true,
5258
envisposdef=true,
59+
callback,
5360
)
5461
ψOv = apply(o, ψv; maxdim=χ, normalize=true)
5562
ψOVidal_symm = ITensorNetwork(ψOv)
@@ -73,6 +80,7 @@ using Test: @test, @testset
7380
fGBP =
7481
inner(ψOGBP, ψOexact; alg=inner_alg) /
7582
sqrt(inner(ψOexact, ψOexact; alg=inner_alg) * inner(ψOGBP, ψOGBP; alg=inner_alg))
83+
@test !iszero(truncerr)
7684
@test real(fGBP * conj(fGBP)) >= real(fSBP * conj(fSBP))
7785
@test isapprox(real(fSBP * conj(fSBP)), real(fVidal * conj(fVidal)); atol=1e-3)
7886
end

0 commit comments

Comments
 (0)