1- using Graphs: IsDirected
1+ using Graphs: Graphs, IsDirected
22using SplitApplyCombine: group
33using LinearAlgebra: diag, dot
44using ITensors: dir
@@ -88,6 +88,10 @@ function tensornetwork(bpc::AbstractBeliefPropagationCache)
8888 return unpartitioned_graph (partitioned_tensornetwork (bpc))
8989end
9090
91+ function setindex_preserve_graph! (bpc:: AbstractBeliefPropagationCache , args... )
92+ return setindex_preserve_graph! (tensornetwork (bpc), args... )
93+ end
94+
9195function factors (bpc:: AbstractBeliefPropagationCache , verts:: Vector )
9296 return ITensor[tensornetwork (bpc)[v] for v in verts]
9397end
@@ -284,6 +288,10 @@ function update(
284288 return update (Algorithm (alg), bpc; kwargs... )
285289end
286290
291+ function scale! (bp_cache:: AbstractBeliefPropagationCache , args... )
292+ return scale! (tensornetwork (bp_cache), args... )
293+ end
294+
287295function rescale_messages (
288296 bp_cache:: AbstractBeliefPropagationCache , partitionedge:: PartitionEdge
289297)
@@ -297,48 +305,45 @@ end
297305function rescale_partitions (
298306 bpc:: AbstractBeliefPropagationCache ,
299307 partitions:: Vector ;
300- verts_to_rescale:: Vector = collect ( vertices (tensornetwork ( bpc)) ),
308+ verts_to_rescale:: Vector = vertices (bpc, partitions ),
301309)
310+ bpc = copy (bpc)
302311 tn = tensornetwork (bpc)
312+ norms = map (v -> inv (norm (tn[v])), verts_to_rescale)
313+ scale! (bpc, Dictionary (verts_to_rescale, norms))
314+
315+ vertices_weights = Dictionary ()
303316 for pv in partitions
304317 pv_vs = filter (v -> v ∈ verts_to_rescale, vertices (bpc, pv))
305-
306318 isempty (pv_vs) && continue
307319
308- for v in pv_vs
309- t = tn[v]
310- setindex_preserve_graph! (tn, t / norm (t), v)
311- end
312-
313320 vn = region_scalar (bpc, pv)
314- if isreal (vn)
315- v = first (pv_vs)
316- t = tn[v]
317- setindex_preserve_graph! (tn, t * sign (vn), v)
318- vn *= sign (vn)
319- end
320-
321- vn = vn^ (1 / length (pv_vs))
322- for v in pv_vs
323- t = tn[v]
324- setindex_preserve_graph! (tn, t / vn, v)
321+ s = isreal (vn) ? sign (vn) : 1.0
322+ vn = s * inv (vn^ (1 / length (pv_vs)))
323+ set! (vertices_weights, first (pv_vs), s* vn)
324+ for v in pv_vs[2 : length (pv_vs)]
325+ set! (vertices_weights, v, vn)
325326 end
326327 end
327328
329+ scale! (bpc, vertices_weights)
330+
328331 return bpc
329332end
330333
331- function rescale_partitions (bpc:: AbstractBeliefPropagationCache ; kwargs... )
332- return rescale_partitions (bpc, collect (partitions (bpc)); kwargs... )
334+ function rescale_partitions (bpc:: AbstractBeliefPropagationCache , args ... ; kwargs... )
335+ return rescale_partitions (bpc, collect (partitions (bpc)), args ... ; kwargs... )
333336end
334337
335- function rescale_partition (bpc:: AbstractBeliefPropagationCache , partition; kwargs... )
336- return rescale_partitions (bpc, [partition]; kwargs... )
338+ function rescale_partition (
339+ bpc:: AbstractBeliefPropagationCache , partition, args... ; kwargs...
340+ )
341+ return rescale_partitions (bpc, [partition], args... ; kwargs... )
337342end
338343
339- function rescale (bpc:: AbstractBeliefPropagationCache ; kwargs... )
344+ function rescale (bpc:: AbstractBeliefPropagationCache , args ... ; kwargs... )
340345 bpc = rescale_messages (bpc)
341- bpc = rescale_partitions (bpc; kwargs... )
346+ bpc = rescale_partitions (bpc, args ... ; kwargs... )
342347 return bpc
343348end
344349
0 commit comments