Skip to content

Commit 17ccf85

Browse files
committed
Centre around scaling function
1 parent 2df7e92 commit 17ccf85

File tree

4 files changed

+64
-45
lines changed

4 files changed

+64
-45
lines changed

src/abstractitensornetwork.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,26 @@ function add(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork)
935935
return tn12
936936
end
937937

938+
""" Scale each tensor of the network by a scale factor on each vertex"""
939+
function scale!(tn::AbstractITensorNetwork, vertices_weights::Dictionary)
940+
for v in keys(vertices_weights)
941+
setindex_preserve_graph!(tn, vertices_weights[v] * tn[v], v)
942+
end
943+
return tn
944+
end
945+
946+
""" Scale each tensor of the network via a function (vertex, ITensor) -> Number"""
947+
function scale!(tn::AbstractITensorNetwork, weight_function::Function)
948+
vs = collect(vertices(tn))
949+
vertices_weights = Dictionary(vs, [weight_function(v, tn[v]) for v in vs])
950+
return scale!(tn, vertices_weights)
951+
end
952+
953+
function scale(tn, args...)
954+
tn = copy(tn)
955+
return scale!(tn, args...)
956+
end
957+
938958
Base.:+(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork) = add(tn1, tn2)
939959

940960
ITensors.hasqns(tn::AbstractITensorNetwork) = any(v -> hasqns(tn[v]), vertices(tn))

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Graphs: IsDirected
1+
using Graphs: Graphs, IsDirected
22
using SplitApplyCombine: group
33
using LinearAlgebra: diag, dot
44
using ITensors: dir
@@ -88,6 +88,10 @@ function tensornetwork(bpc::AbstractBeliefPropagationCache)
8888
return unpartitioned_graph(partitioned_tensornetwork(bpc))
8989
end
9090

91+
function setindex_preserve_graph!(bpc::AbstractBeliefPropagationCache, args...)
92+
return setindex_preserve_graph!(tensornetwork(bpc), args...)
93+
end
94+
9195
function factors(bpc::AbstractBeliefPropagationCache, verts::Vector)
9296
return ITensor[tensornetwork(bpc)[v] for v in verts]
9397
end
@@ -284,6 +288,10 @@ function update(
284288
return update(Algorithm(alg), bpc; kwargs...)
285289
end
286290

291+
function scale!(bp_cache::AbstractBeliefPropagationCache, args...)
292+
return scale!(tensornetwork(bp_cache), args...)
293+
end
294+
287295
function rescale_messages(
288296
bp_cache::AbstractBeliefPropagationCache, partitionedge::PartitionEdge
289297
)
@@ -297,48 +305,45 @@ end
297305
function rescale_partitions(
298306
bpc::AbstractBeliefPropagationCache,
299307
partitions::Vector;
300-
verts_to_rescale::Vector=collect(vertices(tensornetwork(bpc))),
308+
verts_to_rescale::Vector=vertices(bpc, partitions),
301309
)
310+
bpc = copy(bpc)
302311
tn = tensornetwork(bpc)
312+
norms = map(v -> inv(norm(tn[v])), verts_to_rescale)
313+
scale!(bpc, Dictionary(verts_to_rescale, norms))
314+
315+
vertices_weights = Dictionary()
303316
for pv in partitions
304317
pv_vs = filter(v -> v verts_to_rescale, vertices(bpc, pv))
305-
306318
isempty(pv_vs) && continue
307319

308-
for v in pv_vs
309-
t = tn[v]
310-
setindex_preserve_graph!(tn, t / norm(t), v)
311-
end
312-
313320
vn = region_scalar(bpc, pv)
314-
if isreal(vn)
315-
v = first(pv_vs)
316-
t = tn[v]
317-
setindex_preserve_graph!(tn, t * sign(vn), v)
318-
vn *= sign(vn)
319-
end
320-
321-
vn = vn^(1 / length(pv_vs))
322-
for v in pv_vs
323-
t = tn[v]
324-
setindex_preserve_graph!(tn, t / vn, v)
321+
s = isreal(vn) ? sign(vn) : 1.0
322+
vn = s * inv(vn^(1 / length(pv_vs)))
323+
set!(vertices_weights, first(pv_vs), s*vn)
324+
for v in pv_vs[2:length(pv_vs)]
325+
set!(vertices_weights, v, vn)
325326
end
326327
end
327328

329+
scale!(bpc, vertices_weights)
330+
328331
return bpc
329332
end
330333

331-
function rescale_partitions(bpc::AbstractBeliefPropagationCache; kwargs...)
332-
return rescale_partitions(bpc, collect(partitions(bpc)); kwargs...)
334+
function rescale_partitions(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
335+
return rescale_partitions(bpc, collect(partitions(bpc)), args...; kwargs...)
333336
end
334337

335-
function rescale_partition(bpc::AbstractBeliefPropagationCache, partition; kwargs...)
336-
return rescale_partitions(bpc, [partition]; kwargs...)
338+
function rescale_partition(
339+
bpc::AbstractBeliefPropagationCache, partition, args...; kwargs...
340+
)
341+
return rescale_partitions(bpc, [partition], args...; kwargs...)
337342
end
338343

339-
function rescale(bpc::AbstractBeliefPropagationCache; kwargs...)
344+
function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
340345
bpc = rescale_messages(bpc)
341-
bpc = rescale_partitions(bpc; kwargs...)
346+
bpc = rescale_partitions(bpc, args...; kwargs...)
342347
return bpc
343348
end
344349

src/caches/beliefpropagationcache.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ using NamedGraphs.PartitionedGraphs:
99
boundary_partitionedges,
1010
partitionvertices,
1111
partitionedges,
12-
unpartitioned_graph
12+
partitioned_vertices,
13+
unpartitioned_graph,
14+
which_partition
1315
using SimpleTraits: SimpleTraits, Not, @traitfn
1416
using NDTensors: NDTensors
1517

src/normalize.jl

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,18 @@ function rescale(tn::AbstractITensorNetwork; alg="exact", kwargs...)
44
return rescale(Algorithm(alg), tn; kwargs...)
55
end
66

7-
function rescale(
8-
alg::Algorithm"exact",
9-
tn::AbstractITensorNetwork;
10-
verts_to_rescale=collect(vertices(tn)),
11-
kwargs...,
12-
)
7+
function rescale(alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...)
138
logn = logscalar(alg, tn; kwargs...)
149
vs = collect(vertices(tn))
15-
c = inv(exp(logn / length(verts_to_rescale)))
16-
tn = copy(tn)
17-
for v in verts_to_rescale
18-
tn[v] *= c
19-
end
20-
return tn
10+
c = inv(exp(logn / length(vs)))
11+
vertices_weights = Dictionary(vs, [c for v in vs])
12+
return scale(tn, vertices_weights)
2113
end
2214

2315
function rescale(
2416
alg::Algorithm,
25-
tn::AbstractITensorNetwork;
17+
tn::AbstractITensorNetwork,
18+
args...;
2619
(cache!)=nothing,
2720
cache_construction_kwargs=default_cache_construction_kwargs(alg, tn),
2821
update_cache=isnothing(cache!),
@@ -37,7 +30,7 @@ function rescale(
3730
cache![] = update(cache![]; cache_update_kwargs...)
3831
end
3932

40-
cache![] = rescale(cache![]; kwargs...)
33+
cache![] = rescale(cache![], args...; kwargs...)
4134

4235
return tensornetwork(cache![])
4336
end
@@ -49,12 +42,11 @@ end
4942
function LinearAlgebra.normalize(
5043
alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...
5144
)
52-
norm_tn = inner_network(tn, tn)
45+
logn = logscalar(alg, inner_network(tn, tn); kwargs...)
5346
vs = collect(vertices(tn))
54-
verts_to_rescale = vcat(
55-
[ket_vertex(norm_tn, v) for v in vs], [bra_vertex(norm_tn, v) for v in vs]
56-
)
57-
return ket_network(rescale(alg, norm_tn; verts_to_rescale, kwargs...))
47+
c = inv(exp(logn / (2*length(vs))))
48+
vertices_weights = Dictionary(vs, [c for v in vs])
49+
return scale(tn, vertices_weights)
5850
end
5951

6052
function LinearAlgebra.normalize(

0 commit comments

Comments
 (0)