Skip to content

Commit f88b21c

Browse files
committed
Merge remote-tracking branch 'upstream/main' into normalize!
2 parents 1c87d22 + 806d897 commit f88b21c

File tree

4 files changed

+61
-54
lines changed

4 files changed

+61
-54
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
4-
version = "0.11.14"
4+
version = "0.11.15"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/caches/beliefpropagationcache.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Graphs: IsDirected
22
using SplitApplyCombine: group
3-
using LinearAlgebra: diag
3+
using LinearAlgebra: diag, dot
44
using ITensors: dir
55
using ITensorMPS: ITensorMPS
66
using NamedGraphs.PartitionedGraphs:
@@ -12,8 +12,9 @@ using NamedGraphs.PartitionedGraphs:
1212
partitionedges,
1313
unpartitioned_graph
1414
using SimpleTraits: SimpleTraits, Not, @traitfn
15+
using NDTensors: NDTensors
1516

16-
default_message(inds_e) = ITensor[denseblocks(delta(i)) for i in inds_e]
17+
default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e]
1718
default_messages(ptn::PartitionedGraph) = Dictionary()
1819
default_message_norm(m::ITensor) = norm(m)
1920
function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...)
@@ -33,17 +34,16 @@ default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertice
3334
function default_partitioned_vertices(f::AbstractFormNetwork)
3435
return group(v -> original_state_vertex(f, v), vertices(f))
3536
end
36-
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)
37+
default_cache_update_kwargs(cache) = (; maxiter=25, tol=1e-8)
3738
function default_cache_construction_kwargs(alg::Algorithm"bp", ψ::AbstractITensorNetwork)
3839
return (; partitioned_vertices=default_partitioned_vertices(ψ))
3940
end
4041

41-
function message_diff(
42-
message_a::Vector{ITensor}, message_b::Vector{ITensor}; message_norm=default_message_norm
43-
)
42+
#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages
43+
function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
4444
lhs, rhs = contract(message_a), contract(message_b)
45-
norm_lhs, norm_rhs = message_norm(lhs), message_norm(rhs)
46-
return 0.5 * norm((denseblocks(lhs) / norm_lhs) - (denseblocks(rhs) / norm_rhs))
45+
f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs)))
46+
return 1 - f
4747
end
4848

4949
struct BeliefPropagationCache{PTN,MTS,DM}
@@ -99,8 +99,10 @@ for f in [
9999
end
100100
end
101101

102+
NDTensors.scalartype(bp_cache) = scalartype(tensornetwork(bp_cache))
103+
102104
function default_message(bp_cache::BeliefPropagationCache, edge::PartitionEdge)
103-
return default_message(bp_cache)(linkinds(bp_cache, edge))
105+
return default_message(bp_cache)(scalartype(bp_cache), linkinds(bp_cache, edge))
104106
end
105107

106108
function message(bp_cache::BeliefPropagationCache, edge::PartitionEdge)

test/test_belief_propagation.jl

Lines changed: 47 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ using ITensorNetworks:
2323
tensornetwork,
2424
update,
2525
update_factor,
26-
update_message
26+
update_message,
27+
message_diff
2728
using ITensors: ITensors, ITensor, combiner, dag, inds, inner, op, prime, random_itensor
2829
using ITensorNetworks.ModelNetworks: ModelNetworks
2930
using ITensors.NDTensors: array
@@ -34,50 +35,56 @@ using NamedGraphs.PartitionedGraphs: PartitionVertex, partitionedges
3435
using SplitApplyCombine: group
3536
using StableRNGs: StableRNG
3637
using Test: @test, @testset
37-
@testset "belief_propagation" begin
38-
ITensors.disable_warn_order()
39-
g = named_grid((3, 3))
40-
s = siteinds("S=1/2", g)
41-
χ = 2
42-
rng = StableRNG(1234)
43-
ψ = random_tensornetwork(rng, s; link_space=χ)
44-
ψψ = ψ prime(dag(ψ); sites=[])
45-
bpc = BeliefPropagationCache(ψψ)
46-
bpc = update(bpc; maxiter=50, tol=1e-10)
47-
#Test messages are converged
48-
for pe in partitionedges(partitioned_tensornetwork(bpc))
49-
@test update_message(bpc, pe) message(bpc, pe) atol = 1e-8
50-
end
51-
#Test updating the underlying tensornetwork in the cache
52-
v = first(vertices(ψψ))
53-
rng = StableRNG(1234)
54-
new_tensor = random_itensor(rng, inds(ψψ[v]))
55-
bpc_updated = update_factor(bpc, v, new_tensor)
56-
ψψ_updated = tensornetwork(bpc_updated)
57-
@test ψψ_updated[v] == new_tensor
5838

59-
#Test forming a two-site RDM. Check it has the correct size, trace 1 and is PSD
60-
vs = [(2, 2), (2, 3)]
39+
@testset "belief_propagation (eltype=$elt)" for elt in (
40+
Float32, Float64, Complex{Float32}, Complex{Float64}
41+
)
42+
begin
43+
ITensors.disable_warn_order()
44+
g = named_grid((3, 3))
45+
s = siteinds("S=1/2", g)
46+
χ = 2
47+
rng = StableRNG(1234)
48+
ψ = random_tensornetwork(rng, elt, s; link_space=χ)
49+
ψψ = ψ prime(dag(ψ); sites=[])
50+
bpc = BeliefPropagationCache(ψψ, group(v -> first(v), vertices(ψψ)))
51+
bpc = update(bpc; maxiter=25, tol=eps(real(elt)))
52+
#Test messages are converged
53+
for pe in partitionedges(partitioned_tensornetwork(bpc))
54+
@test message_diff(update_message(bpc, pe), message(bpc, pe)) < 10 * eps(real(elt))
55+
@test eltype(only(message(bpc, pe))) == elt
56+
end
57+
#Test updating the underlying tensornetwork in the cache
58+
v = first(vertices(ψψ))
59+
rng = StableRNG(1234)
60+
new_tensor = random_itensor(rng, inds(ψψ[v]))
61+
bpc_updated = update_factor(bpc, v, new_tensor)
62+
ψψ_updated = tensornetwork(bpc_updated)
63+
@test ψψ_updated[v] == new_tensor
64+
65+
#Test forming a two-site RDM. Check it has the correct size, trace 1 and is PSD
66+
vs = [(2, 2), (2, 3)]
6167

62-
ψψsplit = split_index(ψψ, NamedEdge.([(v, 1) => (v, 2) for v in vs]))
63-
env_tensors = environment(bpc, [(v, 2) for v in vs])
64-
rdm = contract(vcat(env_tensors, ITensor[ψψsplit[vp] for vp in [(v, 2) for v in vs]]))
68+
ψψsplit = split_index(ψψ, NamedEdge.([(v, 1) => (v, 2) for v in vs]))
69+
env_tensors = environment(bpc, [(v, 2) for v in vs])
70+
rdm = contract(vcat(env_tensors, ITensor[ψψsplit[vp] for vp in [(v, 2) for v in vs]]))
6571

66-
rdm = array((rdm * combiner(inds(rdm; plev=0)...)) * combiner(inds(rdm; plev=1)...))
67-
rdm /= tr(rdm)
72+
rdm = array((rdm * combiner(inds(rdm; plev=0)...)) * combiner(inds(rdm; plev=1)...))
73+
rdm /= tr(rdm)
6874

69-
eigs = eigvals(rdm)
70-
@test size(rdm) == (2^length(vs), 2^length(vs))
75+
eigs = eigvals(rdm)
76+
@test size(rdm) == (2^length(vs), 2^length(vs))
7177

72-
@test all(eig -> imag(eig) 0, eigs)
73-
@test all(eig -> real(eig) >= -eps(eltype(eig)), eigs)
78+
@test all(eig -> abs(imag(eig)) <= eps(real(elt)), eigs)
79+
@test all(eig -> real(eig) >= -eps(real(elt)), eigs)
7480

75-
#Test edge case of network which evalutes to 0
76-
χ = 2
77-
g = named_grid((3, 1))
78-
rng = StableRNG(1234)
79-
ψ = random_tensornetwork(rng, ComplexF64, g; link_space=χ)
80-
ψ[(1, 1)] = 0.0 * ψ[(1, 1)]
81-
@test iszero(scalar(ψ; alg="bp"))
81+
#Test edge case of network which evalutes to 0
82+
χ = 2
83+
g = named_grid((3, 1))
84+
rng = StableRNG(1234)
85+
ψ = random_tensornetwork(rng, elt, g; link_space=χ)
86+
ψ[(1, 1)] = 0 * ψ[(1, 1)]
87+
@test iszero(scalar(ψ; alg="bp"))
88+
end
8289
end
8390
end

test/test_gauging.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@ using Test: @test, @testset
2727
ψ = random_tensornetwork(rng, s; link_space=χ)
2828

2929
# Move directly to vidal gauge
30-
ψ_vidal = VidalITensorNetwork(
31-
ψ; cache_update_kwargs=(; maxiter=20, tol=1e-12, verbose=true)
32-
)
30+
ψ_vidal = VidalITensorNetwork(ψ; cache_update_kwargs=(; maxiter=30, verbose=true))
3331
@test gauge_error(ψ_vidal) < 1e-8
3432

3533
# Move to symmetric gauge
@@ -38,7 +36,7 @@ using Test: @test, @testset
3836
bp_cache = cache_ref[]
3937

4038
# Test we just did a gauge transform and didn't change the overall network
41-
@test inner(ψ_symm, ψ) / sqrt(inner(ψ_symm, ψ_symm) * inner(ψ, ψ)) 1.0
39+
@test inner(ψ_symm, ψ) / sqrt(inner(ψ_symm, ψ_symm) * inner(ψ, ψ)) 1.0 atol = 1e-8
4240

4341
#Test all message tensors are approximately diagonal even when we keep running BP
4442
bp_cache = update(bp_cache; maxiter=10)

0 commit comments

Comments
 (0)