Skip to content

Commit b3c6bc8

Browse files
committed
Preserve message type on normalization
1 parent 88c51ca commit b3c6bc8

File tree

4 files changed

+7
-5
lines changed

4 files changed

+7
-5
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
3838
GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889"
3939
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
4040
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
41-
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
4241

4342
[extensions]
4443
ITensorNetworksEinExprsExt = "EinExprs"

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,8 @@ function rescale_partitions(
382382
isempty(pv_vs) && continue
383383

384384
vn = region_scalar(bpc, pv)
385-
s = isreal(vn) ? sign(vn) : 1.0
386-
vn = s * inv(vn^(1 / length(pv_vs)))
385+
s = isreal(vn) ? sign(vn) : 1
386+
vn = s * inv(vn^(typeof(vn)((1 / length(pv_vs)))))
387387
set!(vertices_weights, first(pv_vs), s*vn)
388388
for v in pv_vs[2:length(pv_vs)]
389389
set!(vertices_weights, v, vn)

src/caches/beliefpropagationcache.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ function rescale_messages(bp_cache::BeliefPropagationCache, pes::Vector{<:Partit
141141
n *= sign(n)
142142
end
143143

144-
sf = (1 / sqrt(n)) ^ (1 / length(me))
144+
sf = (1 / sqrt(n)) ^ (typeof(n)((1 / length(me))))
145145
set!(mts, pe, sf .* me)
146146
set!(mts, reverse(pe), sf .* mer)
147147
end

test/test_normalize.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ using ITensorNetworks:
44
QuadraticFormNetwork,
55
edge_scalars,
66
norm_sqr_network,
7+
messages,
78
random_tensornetwork,
9+
scalartype,
810
siteinds,
911
vertex_scalars,
1012
rescale
@@ -39,7 +41,7 @@ using Test: @test, @testset
3941

4042
g = named_grid((Lx, Ly))
4143
s = siteinds("S=1/2", g)
42-
x = random_tensornetwork(rng, s; link_space=χ)
44+
x = random_tensornetwork(rng, ComplexF32, s; link_space=χ)
4345

4446
ψ = normalize(x; alg="exact")
4547
@test scalar(norm_sqr_network(ψ); alg="exact") 1.0
@@ -49,6 +51,7 @@ using Test: @test, @testset
4951
x; alg="bp", (cache!)=ψIψ_bpc, update_cache=true, cache_update_kwargs=(; maxiter=20)
5052
)
5153
ψIψ_bpc = ψIψ_bpc[]
54+
@test all(m -> scalartype(only(m)) == ComplexF32, messages(ψIψ_bpc))
5255
@test all(x -> x 1.0, edge_scalars(ψIψ_bpc))
5356
@test all(x -> x 1.0, vertex_scalars(ψIψ_bpc))
5457
@test scalar(QuadraticFormNetwork(ψ); alg="bp", cache_update_kwargs=(; maxiter=20)) 1.0

0 commit comments

Comments
 (0)