Skip to content

Commit 6f96133

Browse files
committed
update
1 parent c57bc8a commit 6f96133

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

src/TensorInference.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,6 @@ include("map.jl")
5151
include("mmap.jl")
5252
include("sampling.jl")
5353
include("cspmodels.jl")
54+
include("belief.jl")
5455

5556
end # module

src/belief.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
struct BPState{T, VT<:AbstractVector{T}}
2+
t2v::Vector{Vector{Int}} # a mapping from tensors to variables
3+
v2t::Vector{Vector{Int}} # a mapping from variables to tensors
4+
edges_vectors::Vector{Vector{VT}} # each tensor is associated with a vector of vectors, one for each neighbor
5+
end
6+
7+
function BPState(::Type{T}, n::Int, t2v::Vector{Vector{Int}}, size_dict::Dict{Int, Int}) where T
8+
v2t = [Int[] for _ in 1:n]
9+
edges_vectors = [Vector{VT}[] for _ in 1:n]
10+
for (i, edge) in enumerate(t2v)
11+
for v in edge
12+
push!(v2t[v], i)
13+
push!(edges_vectors[i], ones(T, size_dict[v]))
14+
end
15+
end
16+
return BPState(t2v, v2t, edges_vectors)
17+
end
18+
19+
# belief propagation, update the tensors on the edges of the tensor network
20+
function belief_propagation(tn::TensorNetworkModel{T}, bpstate::BPState{T}; max_iter::Int=100, tol::Float64=1e-6) where T
21+
# collect the messages from the neighbors
22+
messages = [similar(bpstate.edges_vectors[it]) for it in 1:length(bpstate.t2v)]
23+
for (it, vs) in enumerate(bpstate.t2v)
24+
for (iv, v) in enumerate(vs)
25+
messages[it][iv] = tn.tensors[v]
26+
end
27+
end
28+
# update the tensors on the edges of the tensor network
29+
for (it, vs) in enumerate(bpstate.t2v)
30+
# update the tensor
31+
for (iv, v) in enumerate(vs)
32+
bpstate.edges_vectors[it][iv] = zeros(T, size_dict[v])
33+
for (j, w) in enumerate(vs)
34+
if j != iv
35+
bpstate.edges_vectors[it][iv] += messages[j][iv] * messages[j][iv]
36+
end
37+
end
38+
end
39+
end
40+
end
41+
42+
function tensor_product()
43+
end
44+
45+
function belief_propagation(tn::TensorNetworkModel{T}) where T
46+
return belief_propagation(tn, BPState(T, OMEinsum.get_ixsv(tn.code), tn.size_dict))
47+
end

0 commit comments

Comments
 (0)