1- struct BPState{T, VT <: AbstractVector{T} }
1+ struct BeliefPropgation{T }
22 t2v:: Vector{Vector{Int}} # a mapping from tensors to variables
33 v2t:: Vector{Vector{Int}} # a mapping from variables to tensors
44 tensors:: Vector{AbstractArray{T}} # the tensors
5+ end
6+ num_tensors (bp:: BeliefPropgation ) = length (bp. t2v)
7+ ProblemReductions. num_variables (bp:: BeliefPropgation ) = length (bp. v2t)
8+
9+ function BeliefPropgation (nvars:: Int , t2v:: AbstractVector{Vector{Int}} , tensors:: AbstractVector{AbstractArray{T}} ) where T
10+ # initialize the inverse mapping
11+ v2t = [Int[] for _ in 1 : nvars]
12+ for (i, edge) in enumerate (t2v)
13+ for v in edge
14+ push! (v2t[v], i)
15+ end
16+ end
17+ return BeliefPropgation (t2v, v2t, tensors)
18+ end
19+ function BeliefPropgation (uai:: UAIModel{T} ) where T
20+ return BeliefPropgation (uai. nvars, [collect (Int, f. vars) for f in uai. factors], AbstractArray{T}[f. vals for f in uai. factors])
21+ end
22+
23+ struct BPState{T, VT<: AbstractVector{T} }
524 message_in:: Vector{Vector{VT}} # for each variable, we store the incoming messages
625 message_out:: Vector{Vector{VT}} # the outgoing messages
726end
@@ -22,20 +41,20 @@ function _process_message!(ov::Vector, iv::Vector)
2241 end
2342end
2443
25- function collect_message! (bp:: BPState )
26- for (it, t) in enumerate (bp. t2v )
27- _collect_message! (vectors_on_tensor (bp . message_out, bp, it), t , vectors_on_tensor (bp . message_in, bp, it))
44+ function collect_message! (bp:: BeliefPropgation , state :: BPState )
45+ for it in 1 : num_tensors (bp)
46+ _collect_message! (vectors_on_tensor (state . message_out, bp, it), bp . tensors[it] , vectors_on_tensor (state . message_in, bp, it))
2847 end
2948end
3049# collect the vectors associated with the target tensor
31- function vectors_on_tensor (messages, bp:: BPState , it:: Int )
50+ function vectors_on_tensor (messages, bp:: BeliefPropgation , it:: Int )
3251 return map (bp. t2v[it]) do v
3352 # the message goes to the idx-th tensor from variable v
3453 messages[v][findfirst (== (it), bp. v2t[v])]
3554 end
3655end
3756function _collect_message! (vectors_out:: Vector , t:: AbstractArray , vectors_in:: Vector )
38- @assert length (vectors_out) == length (vectors_in) == ndims (t)
57+ @assert length (vectors_out) == length (vectors_in) == ndims (t) " dimensions mismatch: $( length (vectors_out)) , $( length (vectors_in)) , $( ndims (t)) "
3958 # TODO : speed up if needed!
4059 code = star_code (length (vectors_in))
4160 cost, gradient = cost_and_gradient (code, [t, vectors_in... ])
@@ -44,6 +63,8 @@ function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Ve
4463 end
4564 return cost
4665end
66+
67+ # star code: contract a tensor with multiple vectors, one for each dimension
4768function star_code (n:: Int )
4869 ix1, ixrest = collect (1 : n), [[i] for i in 1 : n]
4970 ne = DynamicNestedEinsum ([DynamicNestedEinsum {Int} (1 ), DynamicNestedEinsum {Int} (2 )], DynamicEinCode ([ix1, ixrest[1 ]], collect (2 : n)))
@@ -53,41 +74,38 @@ function star_code(n::Int)
5374 return ne
5475end
5576
56- function BPState (:: Type{T} , n:: Int , t2v:: Vector{Vector{Int}} , size_dict:: Dict{Int, Int} ) where T
57- v2t = [Int[] for _ in 1 : n]
58- edges_vectors = [Vector{VT}[] for _ in 1 : n]
59- for (i, edge) in enumerate (t2v)
60- for v in edge
61- push! (v2t[v], i)
62- push! (edges_vectors[i], ones (T, size_dict[v]))
63- end
77+ function initial_state (bp:: BeliefPropgation{T} ) where T
78+ size_dict = OMEinsum. get_size_dict (bp. t2v, bp. tensors)
79+ edges_vectors = Vector{Vector{T}}[]
80+ for (i, tids) in enumerate (bp. v2t)
81+ push! (edges_vectors, [ones (T, size_dict[i]) for _ in 1 : length (tids)])
6482 end
65- return BPState (t2v, v2t , edges_vectors)
83+ return BPState (deepcopy (edges_vectors) , edges_vectors)
6684end
6785
6886# belief propagation, update the tensors on the edges of the tensor network
69- function belief_propagation (tn:: TensorNetworkModel{T} , bpstate:: BPState{T} ; max_iter:: Int = 100 , tol:: Float64 = 1e-6 ) where T
70- # collect the messages from the neighbors
71- messages = [similar (bpstate. edges_vectors[it]) for it in 1 : length (bpstate. t2v)]
72- for (it, vs) in enumerate (bpstate. t2v)
73- for (iv, v) in enumerate (vs)
74- messages[it][iv] = tn. tensors[v]
75- end
76- end
77- # update the tensors on the edges of the tensor network
78- for (it, vs) in enumerate (bpstate. t2v)
79- # update the tensor
80- for (iv, v) in enumerate (vs)
81- bpstate. edges_vectors[it][iv] = zeros (T, size_dict[v])
82- for (j, w) in enumerate (vs)
83- if j != iv
84- bpstate. edges_vectors[it][iv] += messages[j][iv] * messages[j][iv]
85- end
86- end
87+ function belief_propagate (bp:: BeliefPropgation ; max_iter:: Int = 100 , tol:: Float64 = 1e-6 )
88+ state = initial_state (bp)
89+ info = belief_propagate! (bp, state; max_iter= max_iter, tol= tol)
90+ return state, info
91+ end
92+ struct BPInfo
93+ converged:: Bool
94+ iterations:: Int
95+ end
96+ function belief_propagate! (bp:: BeliefPropgation , state:: BPState{T} ; max_iter:: Int = 100 , tol:: Float64 = 1e-6 ) where T
97+ for i in 1 : max_iter
98+ process_message! (state)
99+ collect_message! (bp, state)
100+ # check convergence
101+ if all (iv -> all (it -> isapprox (state. message_out[iv][it], state. message_in[iv][it], atol= tol), 1 : length (bp. v2t[iv])), 1 : num_variables (bp))
102+ return BPInfo (true , i)
87103 end
88104 end
105+ return BPInfo (false , max_iter)
89106end
90107
91- function belief_propagation (tn:: TensorNetworkModel{T} ) where T
92- return belief_propagation (tn, BPState (T, OMEinsum. get_ixsv (tn. code), tn. size_dict))
93- end
108+ # if BP is exact and converged (e.g. tree like), the result should be the same as the tensor network contraction
109+ function contraction_results (state:: BPState{T} ) where T
110+ return [sum (reduce ((x, y) -> x .* y, mi)) for mi in state. message_in]
111+ end
0 commit comments