Skip to content

Commit 620da37

Browse files
committed
Allow rescaling flat networks with bp
1 parent 4f4e2e5 commit 620da37

File tree

2 files changed

+71
-34
lines changed

2 files changed

+71
-34
lines changed

src/normalize.jl

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,31 @@
11
using LinearAlgebra
22

3-
function rescale(tn::AbstractITensorNetwork, c::Number, vs=collect(vertices(tn)))
3+
function rescale(tn::AbstractITensorNetwork; alg="exact", kwargs...)
4+
return rescale(Algorithm(alg), tn; kwargs...)
5+
end
6+
7+
function rescale(
8+
alg::Algorithm"exact", tn::AbstractITensorNetwork, vs=collect(vertices(tn)); kwargs...
9+
)
10+
logn = logscalar(alg, tn; kwargs...)
11+
c = 1.0 / (exp(logn / length(vs)))
412
tn = copy(tn)
513
for v in vs
614
tn[v] *= c
715
end
816
return tn
917
end
1018

11-
function LinearAlgebra.normalize(tn::AbstractITensorNetwork; alg="exact", kwargs...)
12-
return normalize(Algorithm(alg), tn; kwargs...)
13-
end
14-
15-
function LinearAlgebra.normalize(alg::Algorithm"exact", tn::AbstractITensorNetwork)
16-
norm_tn = QuadraticFormNetwork(tn)
17-
c = exp(logscalar(alg, norm_tn) / (2 * length(vertices(tn))))
18-
return rescale(tn, 1 / c)
19-
end
20-
21-
function LinearAlgebra.normalize(
19+
function rescale(
2220
alg::Algorithm"bp",
23-
tn::AbstractITensorNetwork;
21+
tn::AbstractITensorNetwork,
22+
vs=collect(vertices(tn));
2423
(cache!)=nothing,
2524
update_cache=isnothing(cache!),
2625
cache_update_kwargs=default_cache_update_kwargs(cache!),
2726
)
2827
if isnothing(cache!)
29-
cache! = Ref(BeliefPropagationCache(QuadraticFormNetwork(tn)))
28+
cache! = Ref(BeliefPropagationCache(tn, group(v -> v, vertices(tn))))
3029
end
3130

3231
if update_cache
@@ -35,19 +34,57 @@ function LinearAlgebra.normalize(
3534

3635
tn = copy(tn)
3736
cache![] = normalize_messages(cache![])
38-
norm_tn = tensornetwork(cache![])
39-
4037
vertices_states = Dictionary()
41-
for v in vertices(tn)
42-
v_ket, v_bra = ket_vertex(norm_tn, v), bra_vertex(norm_tn, v)
43-
pv = only(partitionvertices(cache![], [v_ket]))
38+
for pv in partitionvertices(cache![])
39+
pv_vs = filter(v -> v vs, vertices(cache![], pv))
40+
41+
isempty(pv_vs) && continue
42+
4443
vn = region_scalar(cache![], pv)
45-
norm_tn = rescale(norm_tn, 1 / sqrt(vn), [v_ket, v_bra])
46-
set!(vertices_states, v_ket, norm_tn[v_ket])
47-
set!(vertices_states, v_bra, norm_tn[v_bra])
44+
if isreal(vn) && vn < 0
45+
tn[first(pv_vs)] *= -1
46+
vn = abs(vn)
47+
end
48+
49+
vn = vn^(1 / length(pv_vs))
50+
for v in pv_vs
51+
tn[v] /= vn
52+
set!(vertices_states, v, tn[v])
53+
end
4854
end
4955

5056
cache![] = update_factors(cache![], vertices_states)
57+
return tn
58+
end
59+
60+
function LinearAlgebra.normalize(tn::AbstractITensorNetwork; alg="exact", kwargs...)
61+
return normalize(Algorithm(alg), tn; kwargs...)
62+
end
63+
64+
function LinearAlgebra.normalize(
65+
alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...
66+
)
67+
norm_tn = QuadraticFormNetwork(tn)
68+
vs = filter(v -> v operator_vertices(norm_tn), collect(vertices(norm_tn)))
69+
return ket_network(rescale(alg, norm_tn, vs; kwargs...))
70+
end
71+
72+
function LinearAlgebra.normalize(
73+
alg::Algorithm,
74+
tn::AbstractITensorNetwork;
75+
(cache!)=nothing,
76+
cache_construction_function=tn ->
77+
cache(alg, tn; default_cache_construction_kwargs(alg, tn)...),
78+
update_cache=isnothing(cache!),
79+
cache_update_kwargs=default_cache_update_kwargs(cache!),
80+
)
81+
norm_tn = QuadraticFormNetwork(tn)
82+
if isnothing(cache!)
83+
cache! = Ref(cache_construction_function(norm_tn))
84+
end
85+
86+
vs = filter(v -> v operator_vertices(norm_tn), collect(vertices(norm_tn)))
87+
norm_tn = rescale(alg, norm_tn, vs; cache!, update_cache, cache_update_kwargs)
5188

5289
return ket_network(norm_tn)
5390
end

test/test_normalize.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,32 @@ using ITensorNetworks:
55
edge_scalars,
66
norm_sqr_network,
77
random_tensornetwork,
8-
vertex_scalars
8+
vertex_scalars,
9+
rescale
910
using ITensors: dag, inner, siteinds, scalar
1011
using Graphs: SimpleGraph, uniform_tree
1112
using LinearAlgebra: normalize
1213
using NamedGraphs: NamedGraph
13-
using NamedGraphs.NamedGraphGenerators: named_grid
14+
using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree
1415
using StableRNGs: StableRNG
1516
using Test: @test, @testset
1617
@testset "Normalize" begin
1718

18-
#First lets do a tree
19-
L = 6
19+
#First lets do a flat tree
20+
nx, ny = 2, 3
2021
χ = 2
2122
rng = StableRNG(1234)
2223

23-
g = NamedGraph(SimpleGraph(uniform_tree(L)))
24-
s = siteinds("S=1/2", g)
25-
x = random_tensornetwork(rng, s; link_space=χ)
24+
g = named_comb_tree((nx, ny))
25+
tn = random_tensornetwork(rng, g; link_space=χ)
2626

27-
ψ = normalize(x; alg="exact")
28-
@test scalar(norm_sqr_network(ψ); alg="exact") 1.0
27+
tn_r = rescale(tn; alg="exact")
28+
@test scalar(tn_r; alg="exact") 1.0
2929

30-
ψ = normalize(x; alg="bp")
31-
@test scalar(norm_sqr_network(ψ); alg="exact") 1.0
30+
tn_r = rescale(tn; alg="bp")
31+
@test scalar(tn_r; alg="exact") 1.0
3232

33-
#Now a loopy graph
33+
#Now a state on a loopy graph
3434
Lx, Ly = 3, 2
3535
χ = 2
3636
rng = StableRNG(1234)

0 commit comments

Comments
 (0)