1- using Graphs: IsDirected
1+ using Graphs: Graphs, IsDirected
22using SplitApplyCombine: group
33using LinearAlgebra: diag, dot
44using ITensors: dir
@@ -66,7 +66,7 @@ function region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; k
6666 return not_implemented ()
6767end
6868partitions (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
69- partitionpairs (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
69+ PartitionedGraphs . partitionedges (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
7070
7171function default_edge_sequence (
7272 bpc:: AbstractBeliefPropagationCache ; alg= default_message_update_alg (bpc)
@@ -88,6 +88,10 @@ function tensornetwork(bpc::AbstractBeliefPropagationCache)
8888 return unpartitioned_graph (partitioned_tensornetwork (bpc))
8989end
9090
91+ function setindex_preserve_graph! (bpc:: AbstractBeliefPropagationCache , args... )
92+ return setindex_preserve_graph! (tensornetwork (bpc), args... )
93+ end
94+
9195function factors (bpc:: AbstractBeliefPropagationCache , verts:: Vector )
9296 return ITensor[tensornetwork (bpc)[v] for v in verts]
9397end
@@ -107,7 +111,7 @@ function vertex_scalars(bpc::AbstractBeliefPropagationCache, pvs=partitions(bpc)
107111end
108112
109113function edge_scalars (
110- bpc:: AbstractBeliefPropagationCache , pes= partitionpairs (bpc); kwargs...
114+ bpc:: AbstractBeliefPropagationCache , pes= partitionedges (bpc); kwargs...
111115)
112116 return map (pe -> region_scalar (bpc, pe; kwargs... ), pes)
113117end
@@ -283,3 +287,79 @@ function update(
283287)
284288 return update (Algorithm (alg), bpc; kwargs... )
285289end
290+
291+ function scale! (bp_cache:: AbstractBeliefPropagationCache , args... )
292+ return scale! (tensornetwork (bp_cache), args... )
293+ end
294+
295+ function rescale_messages (
296+ bp_cache:: AbstractBeliefPropagationCache , partitionedge:: PartitionEdge
297+ )
298+ return rescale_messages (bp_cache, [partitionedge])
299+ end
300+
301+ function rescale_messages (bp_cache:: AbstractBeliefPropagationCache )
302+ return rescale_messages (bp_cache, partitionedges (bp_cache))
303+ end
304+
305+ function rescale_partitions (
306+ bpc:: AbstractBeliefPropagationCache ,
307+ partitions:: Vector ;
308+ verts:: Vector = vertices (bpc, partitions),
309+ )
310+ bpc = copy (bpc)
311+ tn = tensornetwork (bpc)
312+ norms = map (v -> inv (norm (tn[v])), verts)
313+ scale! (bpc, Dictionary (verts, norms))
314+
315+ vertices_weights = Dictionary ()
316+ for pv in partitions
317+ pv_vs = filter (v -> v ∈ verts, vertices (bpc, pv))
318+ isempty (pv_vs) && continue
319+
320+ vn = region_scalar (bpc, pv)
321+ s = isreal (vn) ? sign (vn) : 1.0
322+ vn = s * inv (vn^ (1 / length (pv_vs)))
323+ set! (vertices_weights, first (pv_vs), s* vn)
324+ for v in pv_vs[2 : length (pv_vs)]
325+ set! (vertices_weights, v, vn)
326+ end
327+ end
328+
329+ scale! (bpc, vertices_weights)
330+
331+ return bpc
332+ end
333+
334+ function rescale_partitions (bpc:: AbstractBeliefPropagationCache , args... ; kwargs... )
335+ return rescale_partitions (bpc, collect (partitions (bpc)), args... ; kwargs... )
336+ end
337+
338+ function rescale_partition (
339+ bpc:: AbstractBeliefPropagationCache , partition, args... ; kwargs...
340+ )
341+ return rescale_partitions (bpc, [partition], args... ; kwargs... )
342+ end
343+
344+ function rescale (bpc:: AbstractBeliefPropagationCache , args... ; kwargs... )
345+ bpc = rescale_messages (bpc)
346+ bpc = rescale_partitions (bpc, args... ; kwargs... )
347+ return bpc
348+ end
349+
350+ function logscalar (bpc:: AbstractBeliefPropagationCache )
351+ numerator_terms, denominator_terms = scalar_factors_quotient (bpc)
352+ if any (t -> real (t) < 0 , numerator_terms)
353+ numerator_terms = complex .(numerator_terms)
354+ end
355+ if any (t -> real (t) < 0 , denominator_terms)
356+ denominator_terms = complex .(denominator_terms)
357+ end
358+
359+ any (iszero, denominator_terms) && return - Inf
360+ return sum (log .(numerator_terms)) - sum (log .((denominator_terms)))
361+ end
362+
363+ function ITensors. scalar (bpc:: AbstractBeliefPropagationCache )
364+ return exp (logscalar (bpc))
365+ end
0 commit comments