# Return the number of factor tensors tracked by the belief propagation object,
# i.e. the length of the tensor-to-variables adjacency list `t2v`.
function num_tensors(bp::BeliefPropgation)
    return length(bp.t2v)
end
# Return the number of variables in the model, i.e. the length of the
# variable-to-tensors adjacency list `v2t`.
function ProblemReductions.num_variables(bp::BeliefPropgation)
    return length(bp.v2t)
end
1919
20- function BeliefPropgation (nvars:: Int , t2v:: AbstractVector{Vector{Int}} , tensors:: AbstractVector{AbstractArray{T}} ) where T
20+ function BeliefPropgation (nvars:: Int , t2v:: AbstractVector{Vector{Int}} , tensors:: AbstractVector{AbstractArray{T}} ) where {T}
2121 # initialize the inverse mapping
2222 v2t = [Int[] for _ in 1 : nvars]
2323 for (i, edge) in enumerate (t2v)
@@ -33,11 +33,11 @@ $(TYPEDSIGNATURES)
3333
Construct a belief propagation object from a [`UAIModel`](@ref).
"""
function BeliefPropgation(uai::UAIModel{T}) where {T}
    # Variable index list of each factor, converted to plain Int vectors.
    factor_vars = [collect(Int, f.vars) for f in uai.factors]
    # Factor value tables; the `AbstractArray{T}` container eltype matches the
    # primary constructor's expected tensor vector type.
    factor_tensors = AbstractArray{T}[f.vals for f in uai.factors]
    return BeliefPropgation(uai.nvars, factor_vars, factor_tensors)
end
3939
"""
    BPState{T, VT <: AbstractVector{T}}

Per-variable message buffers of a belief propagation run.

# Fields
- `message_in`: for each variable, the list of incoming messages.
- `message_out`: for each variable, the list of outgoing messages.
"""
struct BPState{T, VT <: AbstractVector{T}}
    message_in::Vector{Vector{VT}}
    message_out::Vector{Vector{VT}}
end
@@ -91,12 +91,12 @@ function star_code(n::Int)
9191 ix1, ixrest = collect (1 : n), [[i] for i in 1 : n]
9292 ne = DynamicNestedEinsum ([DynamicNestedEinsum {Int} (1 ), DynamicNestedEinsum {Int} (2 )], DynamicEinCode ([ix1, ixrest[1 ]], collect (2 : n)))
9393 for i in 2 : n
94- ne = DynamicNestedEinsum ([ne, DynamicNestedEinsum {Int} (i + 1 )], DynamicEinCode ([ne. eins. iy, ixrest[i]], collect (i + 1 : n)))
94+ ne = DynamicNestedEinsum ([ne, DynamicNestedEinsum {Int} (i + 1 )], DynamicEinCode ([ne. eins. iy, ixrest[i]], collect ((i + 1 ) : n)))
9595 end
9696 return ne
9797end
9898
99- function initial_state (bp:: BeliefPropgation{T} ) where T
99+ function initial_state (bp:: BeliefPropgation{T} ) where {T}
100100 size_dict = OMEinsum. get_size_dict (bp. t2v, bp. tensors)
101101 edges_vectors = Vector{Vector{T}}[]
102102 for (i, tids) in enumerate (bp. v2t)
@@ -126,13 +126,13 @@ struct BPInfo
126126 converged:: Bool
127127 iterations:: Int
128128end
129- function belief_propagate! (bp:: BeliefPropgation , state:: BPState{T} ; max_iter:: Int = 100 , tol= 1e-6 , damping= 0.2 ) where T
129+ function belief_propagate! (bp:: BeliefPropgation , state:: BPState{T} ; max_iter:: Int = 100 , tol = 1e-6 , damping = 0.2 ) where {T}
130130 pre_message_in = deepcopy (state. message_in)
131131 for i in 1 : max_iter
132- collect_message! (bp, state; normalize= true )
133- process_message! (state; normalize= true , damping= damping)
132+ collect_message! (bp, state; normalize = true )
133+ process_message! (state; normalize = true , damping = damping)
134134 # check convergence
135- if all (iv -> all (it -> isapprox (state. message_in[iv][it], pre_message_in[iv][it], atol= tol), 1 : length (bp. v2t[iv])), 1 : num_variables (bp))
135+ if all (iv -> all (it -> isapprox (state. message_in[iv][it], pre_message_in[iv][it], atol = tol), 1 : length (bp. v2t[iv])), 1 : num_variables (bp))
136136 return BPInfo (true , i)
137137 end
138138 pre_message_in = deepcopy (state. message_in)
@@ -141,13 +141,13 @@ function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::In
141141end
142142
# If BP is exact and has converged (e.g. the factor graph is tree-like), these
# per-variable values coincide with the tensor network contraction result.
function contraction_results(state::BPState{T}) where {T}
    return map(state.message_in) do incoming
        # Elementwise product of all messages arriving at the variable,
        # then summed over the variable's domain.
        sum(reduce((a, b) -> a .* b, incoming))
    end
end
147147
"""
$(TYPEDSIGNATURES)

Return a `Dict` mapping each single-variable key `[v]` to the marginal
distribution of variable `v`, computed as the elementwise product of all
incoming messages at `v`, normalized in place to unit 1-norm.
"""
function marginals(state::BPState{T}) where {T}
    marginal_pairs = (
        [v] => normalize!(reduce((a, b) -> a .* b, incoming), 1)
        for (v, incoming) in enumerate(state.message_in)
    )
    return Dict(marginal_pairs)
end
0 commit comments