Skip to content

Commit aa082fa

Browse files
committed
Merge branch 'upITensors' of github.com:JoeyT1994/ITensorNetworks.jl into upITensors
2 parents 7b1adf0 + aa96298 commit aa082fa

File tree

5 files changed

+470
-0
lines changed

5 files changed

+470
-0
lines changed
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
using NamedGraphs.GraphsExtensions:
2+
add_edge, add_vertex, src, dst, vertices, subgraph, induced_subgraph, reverse, edges
3+
using NamedGraphs.NamedGraphGenerators: named_grid
4+
using NamedGraphs: NamedEdge
5+
using NamedGraphs.PartitionedGraphs:
6+
partitionvertex, PartitionEdge, partitionedges, partitionvertices
7+
using ITensors: siteinds
8+
using ITensorNetworks:
9+
random_tensornetwork,
10+
BeliefPropagationCache,
11+
QuadraticFormNetwork,
12+
update,
13+
environment,
14+
partitioned_tensornetwork,
15+
tensornetwork,
16+
eachtensor,
17+
ITensorNetwork,
18+
contraction_sequence,
19+
linkinds,
20+
message,
21+
factor,
22+
update_factor,
23+
messages,
24+
region_scalar,
25+
vertex_scalars,
26+
edge_scalars,
27+
norm_sqr_network,
28+
update_factors
29+
using ITensors:
30+
ITensor,
31+
contract,
32+
sim,
33+
replaceinds,
34+
combiner,
35+
combinedind,
36+
delta,
37+
Index,
38+
inds,
39+
replaceind,
40+
noprime,
41+
norm
42+
using OMEinsumContractionOrders: OMEinsumContractionOrders
43+
using Dictionaries: Dictionary, set!
44+
using LinearAlgebra: norm, LinearAlgebra, normalize, dot
45+
using SplitApplyCombine: group
46+
47+
using Random
48+
49+
function normalize_messages(bp_cache::BeliefPropagationCache, pes::Vector{<:PartitionEdge})
50+
bp_cache = copy(bp_cache)
51+
mts = messages(bp_cache)
52+
for pe in pes
53+
me, mer = only(mts[pe]), only(mts[reverse(pe)])
54+
set!(mts, pe, ITensor[me / norm(me)])
55+
set!(mts, reverse(pe), ITensor[mer / norm(mer)])
56+
n = region_scalar(bp_cache, pe)
57+
set!(mts, pe, ITensor[(1 / sqrt(n)) * me])
58+
set!(mts, reverse(pe), ITensor[(1 / sqrt(n)) * mer])
59+
end
60+
return bp_cache
61+
end
62+
63+
function normalize_message(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
64+
return normalize_messages(bp_cache, PartitionEdge[pe])
65+
end
66+
67+
function normalize_messages(bp_cache::BeliefPropagationCache)
68+
return normalize_messages(bp_cache, partitionedges(partitioned_tensornetwork(bp_cache)))
69+
end
70+
71+
function true_delta(row_inds::Vector{<:Index}, col_inds::Vector{<:Index})
72+
row_c, col_c = combiner(row_inds), combiner(col_inds)
73+
td = delta(combinedind(row_c), combinedind(col_c))
74+
return td * col_c * row_c
75+
end
76+
77+
function anti_project_edges(bpc::BeliefPropagationCache, pes::Vector{<:PartitionEdge})
78+
bpc = copy(bpc)
79+
antiprojectors = ITensor[]
80+
for pe in pes
81+
indices = linkinds(bpc, pe)
82+
me, mer = only(message(bpc, pe)), only(message(bpc, reverse(pe)))
83+
dual_indices = [sim(noprime(ind)) for ind in indices]
84+
dual_inds_dict = Dictionary(indices, dual_indices)
85+
me = replaceinds(me, indices, dual_indices)
86+
anti_proj = true_delta(indices, dual_indices) - me * mer
87+
push!(antiprojectors, anti_proj)
88+
@assert inds(anti_proj) == vcat(indices, dual_indices)
89+
for v in vertices(bpc, dst(pe))
90+
ψv = only(factors(bpc, [v]))
91+
c_inds = intersect(inds(ψv), indices)
92+
for c in c_inds
93+
ψv = replaceind(ψv, c, dual_inds_dict[c])
94+
end
95+
bpc = update_factor(bpc, v, ψv)
96+
end
97+
end
98+
return bpc, antiprojectors
99+
end
100+
101+
function LinearAlgebra.normalize(
102+
ψ::ITensorNetwork; cache_update_kwargs=(; maxiter=30, tol=1e-12)
103+
)
104+
ψψ = norm_sqr_network(ψ)
105+
ψψ_bpc = BeliefPropagationCache(ψψ, group(v -> first(v), vertices(ψψ)))
106+
ψ, ψψ_bpc = normalize(ψ, ψψ_bpc; cache_update_kwargs)
107+
return ψ, ψψ_bpc
108+
end
109+
110+
function LinearAlgebra.normalize(
111+
ψ::ITensorNetwork,
112+
ψAψ_bpc::BeliefPropagationCache;
113+
cache_update_kwargs=default_cache_update_kwargs(ψAψ_bpc),
114+
update_cache=true,
115+
sf::Float64=1.0,
116+
)
117+
ψ = copy(ψ)
118+
if update_cache
119+
ψAψ_bpc = update(ψAψ_bpc; cache_update_kwargs...)
120+
end
121+
ψAψ_bpc = normalize_messages(ψAψ_bpc)
122+
ψψ = tensornetwork(ψAψ_bpc)
123+
124+
for v in vertices(ψ)
125+
v_ket, v_bra = (v, "ket"), (v, "bra")
126+
pv = only(partitionvertices(ψAψ_bpc, [v_ket]))
127+
vn = region_scalar(ψAψ_bpc, pv)
128+
state = copy(ψψ[v_ket]) / sqrt(sf * vn)
129+
state_dag = copy(ψψ[v_bra]) / sqrt(sf * vn)
130+
vertices_states = Dictionary([v_ket, v_bra], [state, state_dag])
131+
ψAψ_bpc = update_factors(ψAψ_bpc, vertices_states)
132+
ψ[v] = state
133+
end
134+
135+
return ψ, ψAψ_bpc
136+
end
137+
138+
Random.seed!(1234)
139+
140+
g = named_grid((2, 2))
141+
g = add_vertex(g, (2, 3))
142+
g = add_edge(g, NamedEdge((2, 3) => (2, 2)))
143+
144+
s = siteinds("S=1/2", g)
145+
ψ = random_tensornetwork(s; link_space=2)
146+
ψ, _ = normalize(ψ)
147+
ψIψ_bpc = BeliefPropagationCache(QuadraticFormNetwork(ψ))
148+
ψIψ_bpc = update(ψIψ_bpc; maxiter=20)
149+
ψIψ_bpc = normalize_messages(ψIψ_bpc)
150+
bp_norm = prod(vertex_scalars(ψIψ_bpc))
151+
pg = partitioned_tensornetwork(ψIψ_bpc)
152+
153+
loop =
154+
PartitionEdge.([
155+
NamedEdge((1, 1) => (1, 2)),
156+
NamedEdge((1, 2) => (2, 2)),
157+
NamedEdge((2, 2) => (2, 1)),
158+
NamedEdge((2, 1) => (1, 1)),
159+
])
160+
partition_vertices_in_loop = unique(vcat(src.(loop), dst.(loop)))
161+
162+
incoming_messages = environment(ψIψ_bpc, partition_vertices_in_loop)
163+
bpc, antiprojectors = anti_project_edges(ψIψ_bpc, loop)
164+
tn = factors(bpc, vertices(bpc, partition_vertices_in_loop))
165+
166+
all_tensors = vcat(vcat(tn, antiprojectors), incoming_messages)
167+
seq = contraction_sequence(all_tensors; alg="sa_bipartite")
168+
loop_correction = contract(all_tensors; sequence=seq)[]
169+
170+
true_contraction = bp_norm + loop_correction
171+
@show true_contraction
172+
173+
ψIψ = QuadraticFormNetwork(ψ)
174+
all_tensors = [ψIψ[v] for v in vertices(ψIψ)]
175+
seq = contraction_sequence(all_tensors; alg="sa_bipartite")
176+
actual_contraction = contract(all_tensors; sequence=seq)[]
177+
@show actual_contraction

examples/test_beliefpropagation.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
using ITensorNetworks:
2+
BoundaryMPSCache,
3+
BeliefPropagationCache,
4+
QuadraticFormNetwork,
5+
IndsNetwork,
6+
siteinds,
7+
ttn,
8+
random_tensornetwork,
9+
partitionedges,
10+
messages,
11+
update,
12+
partition_update,
13+
set_messages,
14+
message,
15+
planargraph_partitionedges,
16+
switch_messages,
17+
environment,
18+
VidalITensorNetwork,
19+
ITensorNetwork,
20+
expect,
21+
default_message_update,
22+
contraction_sequence,
23+
insert_linkinds,
24+
partitioned_tensornetwork,
25+
default_message,
26+
biorthogonalize,
27+
pe_above
28+
using OMEinsumContractionOrders
29+
using ITensorNetworks.ITensorsExtensions: map_eigvals
30+
using ITensorNetworks.ModelHamiltonians: ising
31+
using ITensors:
32+
ITensor,
33+
ITensors,
34+
Index,
35+
OpSum,
36+
terms,
37+
sites,
38+
contract,
39+
commonind,
40+
replaceind,
41+
replaceinds,
42+
prime,
43+
dag,
44+
noncommonind,
45+
noncommoninds,
46+
inds
47+
using NamedGraphs: NamedGraphs, AbstractGraph, NamedEdge, NamedGraph, vertices, neighbors
48+
using NamedGraphs.NamedGraphGenerators: named_grid, named_hexagonal_lattice_graph
49+
using NamedGraphs.GraphsExtensions: rem_vertex, add_edges, add_edge
50+
using NamedGraphs.PartitionedGraphs:
51+
PartitionedGraph,
52+
partitioned_graph,
53+
PartitionVertex,
54+
PartitionEdge,
55+
unpartitioned_graph,
56+
partitioned_vertices,
57+
which_partition
58+
using LinearAlgebra: normalize
59+
using Graphs: center
60+
61+
using Random
62+
63+
Random.seed!(1834)
64+
ITensors.disable_warn_order()
65+
66+
function main()
67+
L = 4
68+
#g = lieb_lattice_grid(L, L)
69+
#g = named_hexagonal_lattice_graph(L, L)
70+
#g = named_grid_periodic_x((L,2))
71+
g = named_grid((L, 3))
72+
vc = first(center(g))
73+
s = siteinds("S=1/2", g)
74+
ψ = random_tensornetwork(s; link_space=3)
75+
bp_update_kwargs = (; maxiter=50, tol=1e-14, verbose=true)
76+
77+
ψIψ = BeliefPropagationCache(QuadraticFormNetwork(ψ))
78+
ψIψ = update(ψIψ; bp_update_kwargs...)
79+
return ψIψ = update(ψIψ; bp_update_kwargs...)
80+
end
81+
82+
main()

0 commit comments

Comments
 (0)