Skip to content

Commit eeac23b

Browse files
committed
Working BP Commit
1 parent 8269686 commit eeac23b

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

src/ITensorNetworksNext.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,7 @@ module ITensorNetworksNext
33
include("abstracttensornetwork.jl")
44
include("tensornetwork.jl")
55

6+
include("beliefpropagation/abstractbeliefpropagationcache.jl")
7+
include("beliefpropagation/beliefpropagationcache.jl")
8+
69
end

src/abstracttensornetwork.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,4 +276,4 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork)
276276
return nothing
277277
end
278278

279-
Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph)
279+
Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph)

test/test_beliefpropagation.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using Dictionaries: Dictionary
2+
using ITensorBase: Index
3+
using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy,
4+
partitionfunction
5+
using Graphs: edges, vertices
6+
using NamedGraphs.NamedGraphGenerators: named_grid
7+
using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges
8+
using Test: @test, @testset
9+
10+
@testset "BeliefPropagation" begin
11+
dims = (4, 1)
12+
g = named_grid(dims)
13+
l = Dict(e => Index(2) for e in edges(g))
14+
l = merge(l, Dict(reverse(e) => l[e] for e in edges(g)))
15+
tn = TensorNetwork(g) do v
16+
is = map(e -> l[e], incident_edges(g, v))
17+
return randn(Tuple(is))
18+
end
19+
20+
bpc = BeliefPropagationCache(tn)
21+
bpc = ITensorNetworksNext.update(bpc; maxiter = 10)
22+
z_bp = partitionfunction(bpc)
23+
z_exact = reduce(*, [tn[v] for v in vertices(g)])[]
24+
@test abs(z_bp - z_exact) <= 1e-14
25+
end

0 commit comments

Comments
 (0)