Skip to content

Commit 22e7dbb

Browse files
committed
Improvements
1 parent bc7c238 commit 22e7dbb

File tree

5 files changed

+98
-63
lines changed

5 files changed

+98
-63
lines changed

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,3 +283,69 @@ function update(
283283
)
284284
return update(Algorithm(alg), bpc; kwargs...)
285285
end
286+
287+
function rescale_message(bp_cache::AbstractBeliefPropagationCache, partitionpair)
288+
return rescale_messages(bp_cache, typeof(partitionpair)[partitionpair])
289+
end
290+
291+
function rescale_messages(bp_cache::AbstractBeliefPropagationCache)
292+
return rescale_messages(bp_cache, partitionpairs(bp_cache))
293+
end
294+
295+
function rescale_partitions(
296+
bpc::AbstractBeliefPropagationCache,
297+
partitions::Vector;
298+
vs_to_rescale::Vector=collect(vertices(tensornetwork(bpc))),
299+
)
300+
tn = tensornetwork(bpc)
301+
for pv in partitions
302+
pv_vs = filter(v -> v vs_to_rescale, vertices(bpc, pv))
303+
304+
isempty(pv_vs) && continue
305+
306+
vn = region_scalar(bpc, pv)
307+
if isreal(vn)
308+
tn[first(pv_vs)] *= sign(vn)
309+
vn *= sign(vn)
310+
end
311+
312+
vn = vn^(1 / length(pv_vs))
313+
for v in pv_vs
314+
tn[v] /= vn
315+
end
316+
end
317+
318+
return bpc
319+
end
320+
321+
function rescale_partitions(bpc::AbstractBeliefPropagationCache; kwargs...)
322+
return rescale_partitions(bpc, collect(partitions(bpc)); kwargs...)
323+
end
324+
325+
function rescale_partition(bpc::AbstractBeliefPropagationCache, partition; kwargs...)
326+
return rescale_partitions(bpc, typeof(partition)[partition]; kwargs...)
327+
end
328+
329+
function rescale(bpc::AbstractBeliefPropagationCache; kwargs...)
330+
bpc = rescale_messages(bpc)
331+
bpc = rescale_partitions(bpc; kwargs...)
332+
return bpc
333+
end
334+
335+
function logscalar(bpc::AbstractBeliefPropagationCache)
336+
numerator_terms, denominator_terms = scalar_factors_quotient(bpc)
337+
numerator_terms =
338+
any(t -> real(t) < 0, numerator_terms) ? complex.(numerator_terms) : numerator_terms
339+
denominator_terms = if any(t -> real(t) < 0, denominator_terms)
340+
complex.(denominator_terms)
341+
else
342+
denominator_terms
343+
end
344+
345+
any(iszero, denominator_terms) && return -Inf
346+
return sum(log.(numerator_terms)) - sum(log.((denominator_terms)))
347+
end
348+
349+
function ITensors.scalar(bpc::AbstractBeliefPropagationCache)
350+
return exp(logscalar(bpc))
351+
end

src/caches/beliefpropagationcache.jl

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -107,28 +107,20 @@ function region_scalar(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
107107
return contract(ts; sequence)[]
108108
end
109109

110-
function normalize_messages(bp_cache::BeliefPropagationCache, pes::Vector{<:PartitionEdge})
110+
function rescale_messages(bp_cache::BeliefPropagationCache, pes::Vector{<:PartitionEdge})
111111
bp_cache = copy(bp_cache)
112112
mts = messages(bp_cache)
113113
for pe in pes
114114
me, mer = only(mts[pe]), only(mts[reverse(pe)])
115115
me, mer = normalize(me), normalize(mer)
116116
n = dot(me, mer)
117-
if isreal(n) && n < 0
118-
set!(mts, pe, ITensor[(sgn(n) / sqrt(abs(n))) * me])
119-
set!(mts, reverse(pe), ITensor[(1 / sqrt(abs(n))) * mer])
120-
else
121-
set!(mts, pe, ITensor[(1 / sqrt(n)) * me])
122-
set!(mts, reverse(pe), ITensor[(1 / sqrt(n)) * mer])
117+
if isreal(n)
118+
me *= sign(n)
119+
n *= sign(n)
123120
end
121+
122+
set!(mts, pe, ITensor[(1 / sqrt(n)) * me])
123+
set!(mts, reverse(pe), ITensor[(1 / sqrt(n)) * mer])
124124
end
125125
return bp_cache
126126
end
127-
128-
function normalize_message(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
129-
return normalize_messages(bp_cache, PartitionEdge[pe])
130-
end
131-
132-
function normalize_messages(bp_cache::BeliefPropagationCache)
133-
return normalize_messages(bp_cache, partitionedges(partitioned_tensornetwork(bp_cache)))
134-
end

src/contract.jl

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,7 @@ function logscalar(
6363
cache![] = update(cache![]; cache_update_kwargs...)
6464
end
6565

66-
numerator_terms, denominator_terms = scalar_factors_quotient(cache![])
67-
numerator_terms =
68-
any(t -> real(t) < 0, numerator_terms) ? complex.(numerator_terms) : numerator_terms
69-
denominator_terms = if any(t -> real(t) < 0, denominator_terms)
70-
complex.(denominator_terms)
71-
else
72-
denominator_terms
73-
end
74-
75-
any(iszero, denominator_terms) && return -Inf
76-
return sum(log.(numerator_terms)) - sum(log.((denominator_terms)))
66+
return logscalar(cache![])
7767
end
7868

7969
function ITensors.scalar(alg::Algorithm, tn::AbstractITensorNetwork; kwargs...)

src/normalize.jl

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,28 @@ function rescale(tn::AbstractITensorNetwork; alg="exact", kwargs...)
55
end
66

77
function rescale(
8-
alg::Algorithm"exact", tn::AbstractITensorNetwork, vs=collect(vertices(tn)); kwargs...
8+
alg::Algorithm"exact",
9+
tn::AbstractITensorNetwork;
10+
vs_to_rescale=collect(vertices(tn)),
11+
kwargs...,
912
)
1013
logn = logscalar(alg, tn; kwargs...)
11-
c = 1.0 / (exp(logn / length(vs)))
14+
c = 1.0 / (exp(logn / length(vs_to_rescale)))
1215
tn = copy(tn)
13-
for v in vs
16+
for v in vs_to_rescale
1417
tn[v] *= c
1518
end
1619
return tn
1720
end
1821

1922
function rescale(
2023
alg::Algorithm,
21-
tn::AbstractITensorNetwork,
22-
vs=collect(vertices(tn));
24+
tn::AbstractITensorNetwork;
25+
vs_to_rescale=collect(vertices(tn)),
2326
(cache!)=nothing,
2427
cache_construction_kwargs=default_cache_construction_kwargs(alg, tn),
2528
update_cache=isnothing(cache!),
26-
cache_update_kwargs=default_cache_update_kwargs(cache!),
29+
cache_update_kwargs=default_cache_update_kwargs(alg),
2730
)
2831
if isnothing(cache!)
2932
cache! = Ref(cache(alg, tn; cache_construction_kwargs...))
@@ -33,29 +36,9 @@ function rescale(
3336
cache![] = update(cache![]; cache_update_kwargs...)
3437
end
3538

36-
tn = copy(tn)
37-
cache![] = normalize_messages(cache![])
38-
vertices_states = Dictionary()
39-
for pv in partitionvertices(cache![])
40-
pv_vs = filter(v -> v vs, vertices(cache![], pv))
41-
42-
isempty(pv_vs) && continue
39+
cache![] = rescale(cache![]; vs_to_rescale)
4340

44-
vn = region_scalar(cache![], pv)
45-
if isreal(vn) && vn < 0
46-
tn[first(pv_vs)] *= -1
47-
vn = abs(vn)
48-
end
49-
50-
vn = vn^(1 / length(pv_vs))
51-
for v in pv_vs
52-
tn[v] /= vn
53-
set!(vertices_states, v, tn[v])
54-
end
55-
end
56-
57-
cache![] = update_factors(cache![], vertices_states)
58-
return tn
41+
return tensornetwork(cache![])
5942
end
6043

6144
function LinearAlgebra.normalize(tn::AbstractITensorNetwork; alg="exact", kwargs...)
@@ -65,9 +48,8 @@ end
6548
function LinearAlgebra.normalize(
6649
alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...
6750
)
68-
norm_tn = QuadraticFormNetwork(tn)
69-
vs = filter(v -> v operator_vertices(norm_tn), collect(vertices(norm_tn)))
70-
return ket_network(rescale(alg, norm_tn, vs; kwargs...))
51+
norm_tn = inner_network(tn, tn)
52+
return ket_network(rescale(alg, norm_tn; kwargs...))
7153
end
7254

7355
function LinearAlgebra.normalize(
@@ -77,15 +59,19 @@ function LinearAlgebra.normalize(
7759
cache_construction_function=tn ->
7860
cache(alg, tn; default_cache_construction_kwargs(alg, tn)...),
7961
update_cache=isnothing(cache!),
80-
cache_update_kwargs=default_cache_update_kwargs(cache!),
62+
cache_update_kwargs=default_cache_update_kwargs(alg),
63+
cache_construction_kwargs=(;),
8164
)
82-
norm_tn = QuadraticFormNetwork(tn)
65+
norm_tn = inner_network(tn, tn)
8366
if isnothing(cache!)
84-
cache! = Ref(cache_construction_function(norm_tn))
67+
cache! = Ref(cache(alg, norm_tn; cache_construction_kwargs...))
8568
end
8669

87-
vs = filter(v -> v operator_vertices(norm_tn), collect(vertices(norm_tn)))
88-
norm_tn = rescale(alg, norm_tn, vs; cache!, update_cache, cache_update_kwargs)
70+
vs = collect(vertices(tn))
71+
vs_to_rescale = vcat(
72+
[ket_vertex(norm_tn, v) for v in vs], [bra_vertex(norm_tn, v) for v in vs]
73+
)
74+
norm_tn = rescale(alg, norm_tn; vs_to_rescale, cache!, update_cache, cache_update_kwargs)
8975

9076
return ket_network(norm_tn)
9177
end

test/test_normalize.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ using ITensorNetworks:
55
edge_scalars,
66
norm_sqr_network,
77
random_tensornetwork,
8+
siteinds,
89
vertex_scalars,
910
rescale
10-
using ITensors: dag, inner, siteinds, scalar
11+
using ITensors: dag, inner, scalar
1112
using Graphs: SimpleGraph, uniform_tree
1213
using LinearAlgebra: normalize
1314
using NamedGraphs: NamedGraph

0 commit comments

Comments
 (0)