Skip to content

Commit da56c33

Browse files
committed
Rescale local tensors first
1 parent f7c8733 commit da56c33

File tree

3 files changed

+29
-10
lines changed

3 files changed

+29
-10
lines changed

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ end
6565
function region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; kwargs...)
6666
return not_implemented()
6767
end
68+
function message_overlap(bpc::AbstractBeliefPropagationCache, partitionedge; kwargs...)
69+
return not_implemented()
70+
end
6871
partitions(bpc::AbstractBeliefPropagationCache) = not_implemented()
6972
PartitionedGraphs.partitionedges(bpc::AbstractBeliefPropagationCache) = not_implemented()
7073

@@ -285,7 +288,7 @@ function update(
285288
end
286289

287290
function rescale_message(bp_cache::AbstractBeliefPropagationCache, partitionedge)
288-
return rescale_messages(bp_cache, typeof(partitionedge)[partitionedge])
291+
return rescale_messages(bp_cache, [partitionedge])
289292
end
290293

291294
function rescale_messages(bp_cache::AbstractBeliefPropagationCache)
@@ -303,15 +306,23 @@ function rescale_partitions(
303306

304307
isempty(pv_vs) && continue
305308

309+
for v in pv_vs
310+
t = tn[v]
311+
setindex_preserve_graph!(tn, t / norm(t), v)
312+
end
313+
306314
vn = region_scalar(bpc, pv)
307315
if isreal(vn)
308-
tn[first(pv_vs)] *= sign(vn)
316+
v = first(pv_vs)
317+
t = tn[v]
318+
setindex_preserve_graph!(tn, t * sign(vn), v)
309319
vn *= sign(vn)
310320
end
311321

312322
vn = vn^(1 / length(pv_vs))
313323
for v in pv_vs
314-
tn[v] /= vn
324+
t = tn[v]
325+
setindex_preserve_graph!(tn, t / vn, v)
315326
end
316327
end
317328

@@ -323,7 +334,7 @@ function rescale_partitions(bpc::AbstractBeliefPropagationCache; kwargs...)
323334
end
324335

325336
function rescale_partition(bpc::AbstractBeliefPropagationCache, partition; kwargs...)
326-
return rescale_partitions(bpc, typeof(partition)[partition]; kwargs...)
337+
return rescale_partitions(bpc, [partition]; kwargs...)
327338
end
328339

329340
function rescale(bpc::AbstractBeliefPropagationCache; kwargs...)

src/caches/beliefpropagationcache.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,20 +109,27 @@ function region_scalar(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
109109
return contract(ts; sequence)[]
110110
end
111111

112+
function message_overlap(bpc::BeliefPropagationCache, partitionedge; kwargs...)
113+
me, mer = only(message(bpc, partitionedge)), only(message(bpc, reverse(partitionedge)))
114+
return dot(me, mer)
115+
end
116+
112117
function rescale_messages(bp_cache::BeliefPropagationCache, pes::Vector{<:PartitionEdge})
113118
bp_cache = copy(bp_cache)
114119
mts = messages(bp_cache)
115120
for pe in pes
116-
me, mer = only(mts[pe]), only(mts[reverse(pe)])
117-
me, mer = normalize(me), normalize(mer)
118-
n = dot(me, mer)
121+
me, mer = normalize.(mts[pe]), normalize.(mts[reverse(pe)])
122+
set!(mts, pe, me)
123+
set!(mts, reverse(pe), mer)
124+
n = message_overlap(bp_cache, pe)
119125
if isreal(n)
120-
me *= sign(n)
126+
me[1] *= sign(n)
121127
n *= sign(n)
122128
end
123129

124-
set!(mts, pe, ITensor[(1 / sqrt(n)) * me])
125-
set!(mts, reverse(pe), ITensor[(1 / sqrt(n)) * mer])
130+
sf = (1 / sqrt(n)) ^ (1 / length(me))
131+
set!(mts, pe, sf .* me)
132+
set!(mts, reverse(pe), sf .* mer)
126133
end
127134
return bp_cache
128135
end

test/test_normalize.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ using LinearAlgebra: normalize
1414
using NamedGraphs: NamedGraph
1515
using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree
1616
using StableRNGs: StableRNG
17+
using TensorOperations: TensorOperations
1718
using Test: @test, @testset
1819
@testset "Normalize" begin
1920

0 commit comments

Comments
 (0)