1+ using DiagonalArrays: delta
2+ using Dictionaries: Dictionary, set!, delete!
3+ using Graphs: AbstractGraph, is_tree, connected_components
4+ using NamedGraphs. GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges
5+ using ITensorBase: ITensor, dim
6+ using TypeParameterAccessors: unwrap_array_type, unwrap_array, parenttype
7+
8+ struct BeliefPropagationCache{V, N <: AbstractDataGraph{V} } < :
9+ AbstractBeliefPropagationCache{V}
10+ network:: N
11+ messages:: Dictionary
12+ end
13+
14+ messages (bp_cache:: BeliefPropagationCache ) = bp_cache. messages
15+ network (bp_cache:: BeliefPropagationCache ) = bp_cache. network
16+ default_messages () = Dictionary ()
17+
18+ BeliefPropagationCache (network) = BeliefPropagationCache (network, default_messages ())
19+
20+ function Base. copy (bp_cache:: BeliefPropagationCache )
21+ return BeliefPropagationCache (copy (network (bp_cache)), copy (messages (bp_cache)))
22+ end
23+
24+ function deletemessage! (bp_cache:: BeliefPropagationCache , e:: AbstractEdge )
25+ ms = messages (bp_cache)
26+ delete! (ms, e)
27+ return bp_cache
28+ end
29+
30+ function setmessage! (bp_cache:: BeliefPropagationCache , e:: AbstractEdge , message)
31+ ms = messages (bp_cache)
32+ set! (ms, e, message)
33+ return bp_cache
34+ end
35+
36+ function message (bp_cache:: AbstractBeliefPropagationCache , edge:: AbstractEdge ; kwargs... )
37+ ms = messages (bp_cache)
38+ return get (() -> default_message (bp_cache, edge; kwargs... ), ms, edge)
39+ end
40+
41+ function messages (bp_cache:: AbstractBeliefPropagationCache , edges:: Vector{<:AbstractEdge} )
42+ return [message (bp_cache, e) for e in edges]
43+ end
44+
45+ default_bp_maxiter (g:: AbstractGraph ) = is_tree (g) ? 1 : nothing
46+ # Forward onto the network
47+ for f in [
48+ :(Graphs. vertices),
49+ :(Graphs. edges),
50+ :(Graphs. is_tree),
51+ :(NamedGraphs. GraphsExtensions. boundary_edges),
52+ :(factors),
53+ :(default_bp_maxiter),
54+ :(ITensorNetworksNext. setfactor!),
55+ :(ITensorNetworksNext. linkinds),
56+ :(ITensorNetworksNext. underlying_graph),
57+ ]
58+ @eval begin
59+ function $f (bp_cache:: BeliefPropagationCache , args... ; kwargs... )
60+ return $ f (network (bp_cache), args... ; kwargs... )
61+ end
62+ end
63+ end
64+
65+ # TODO : Get subgraph working on an ITensorNetwork to overload this directly
66+ function default_bp_edge_sequence (bp_cache:: BeliefPropagationCache )
67+ return forest_cover_edge_sequence (underlying_graph (bp_cache))
68+ end
69+
70+ function factors (tn:: AbstractTensorNetwork , vertex)
71+ return [tn[vertex]]
72+ end
73+
74+ function region_scalar (bp_cache:: BeliefPropagationCache , edge:: AbstractEdge )
75+ return (message (bp_cache, edge) * message (bp_cache, reverse (edge)))[]
76+ end
77+
78+ function region_scalar (bp_cache:: BeliefPropagationCache , vertex)
79+ incoming_ms = incoming_messages (bp_cache, vertex)
80+ state = factors (bp_cache, vertex)
81+ return (reduce (* , incoming_ms) * reduce (* , state))[]
82+ end
83+
84+ function default_message (bp_cache:: BeliefPropagationCache , edge:: AbstractEdge )
85+ return default_message (network (bp_cache), edge:: AbstractEdge )
86+ end
87+
88+ function default_message (tn:: AbstractTensorNetwork , edge:: AbstractEdge )
89+ t = ITensor (ones (dim .(linkinds (tn, edge))... ), linkinds (tn, edge)... )
90+ # TODO : Get datatype working on tensornetworks so we can support GPU, etc...
91+ return t
92+ end
93+
94+ # Algorithmic defaults
95+ default_update_alg (bp_cache:: BeliefPropagationCache ) = " bp"
96+ default_message_update_alg (bp_cache:: BeliefPropagationCache ) = " contract"
97+ default_normalize (:: Algorithm"contract" ) = true
98+ default_sequence_alg (:: Algorithm"contract" ) = " optimal"
99+ function set_default_kwargs (alg:: Algorithm"contract" )
100+ normalize = get (alg, :normalize , default_normalize (alg))
101+ sequence_alg = get (alg, :sequence_alg , default_sequence_alg (alg))
102+ return Algorithm (" contract" ; normalize, sequence_alg)
103+ end
104+ function set_default_kwargs (alg:: Algorithm"adapt_update" )
105+ _alg = set_default_kwargs (get (alg, :alg , Algorithm (" contract" )))
106+ return Algorithm (" adapt_update" ; adapt = alg. adapt, alg = _alg)
107+ end
108+ default_verbose (:: Algorithm"bp" ) = false
109+ default_tol (:: Algorithm"bp" ) = nothing
110+ function set_default_kwargs (alg:: Algorithm"bp" , bp_cache:: BeliefPropagationCache )
111+ verbose = get (alg, :verbose , default_verbose (alg))
112+ maxiter = get (alg, :maxiter , default_bp_maxiter (bp_cache))
113+ edge_sequence = get (alg, :edge_sequence , default_bp_edge_sequence (bp_cache))
114+ tol = get (alg, :tol , default_tol (alg))
115+ message_update_alg = set_default_kwargs (
116+ get (alg, :message_update_alg , Algorithm (default_message_update_alg (bp_cache)))
117+ )
118+ return Algorithm (" bp" ; verbose, maxiter, edge_sequence, tol, message_update_alg)
119+ end
120+
121+ # TODO : Update message etc should go here...
122+ function updated_message (
123+ alg:: Algorithm"contract" , bp_cache:: BeliefPropagationCache , edge:: AbstractEdge
124+ )
125+ vertex = src (edge)
126+ incoming_ms = incoming_messages (
127+ bp_cache, vertex; ignore_edges = typeof (edge)[reverse (edge)]
128+ )
129+ state = factors (bp_cache, vertex)
130+ # contract_list = ITensor[incoming_ms; state]
131+ # sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg)
132+ # updated_messages = contract(contract_list; sequence)
133+ updated_message =
134+ ! isempty (incoming_ms) ? reduce (* , state) * reduce (* , incoming_ms) : reduce (* , state)
135+ if alg. normalize
136+ message_norm = LinearAlgebra. norm (updated_message)
137+ if ! iszero (message_norm)
138+ updated_message /= message_norm
139+ end
140+ end
141+ return updated_message
142+ end
143+
144+ function updated_message (
145+ bp_cache:: BeliefPropagationCache ,
146+ edge:: AbstractEdge ;
147+ alg = default_message_update_alg (bpc),
148+ kwargs... ,
149+ )
150+ return updated_message (set_default_kwargs (Algorithm (alg; kwargs... )), bp_cache, edge)
151+ end
152+
153+ function update_message! (
154+ message_update_alg:: Algorithm , bp_cache:: BeliefPropagationCache , edge:: AbstractEdge
155+ )
156+ return setmessage! (bp_cache, edge, updated_message (message_update_alg, bp_cache, edge))
157+ end
158+
159+ """
160+ Do a sequential update of the message tensors on `edges`
161+ """
162+ function update_iteration (
163+ alg:: Algorithm"bp" ,
164+ bpc:: AbstractBeliefPropagationCache ,
165+ edges:: Vector ;
166+ (update_diff!) = nothing ,
167+ )
168+ bpc = copy (bpc)
169+ for e in edges
170+ prev_message = ! isnothing (update_diff!) ? message (bpc, e) : nothing
171+ update_message! (alg. message_update_alg, bpc, e)
172+ if ! isnothing (update_diff!)
173+ update_diff![] += message_diff (message (bpc, e), prev_message)
174+ end
175+ end
176+ return bpc
177+ end
178+
179+ """
180+ Do parallel updates between groups of edges of all message tensors
181+ Currently we send the full message tensor data struct to update for each edge_group. But really we only need the
182+ mts relevant to that group.
183+ """
184+ function update_iteration (
185+ alg:: Algorithm"bp" ,
186+ bpc:: AbstractBeliefPropagationCache ,
187+ edge_groups:: Vector{<:Vector{<:AbstractEdge}} ;
188+ (update_diff!) = nothing ,
189+ )
190+ new_mts = empty (messages (bpc))
191+ for edges in edge_groups
192+ bpc_t = update_iteration (alg. kwargs. message_update_alg, bpc, edges; (update_diff!))
193+ for e in edges
194+ set! (new_mts, e, message (bpc_t, e))
195+ end
196+ end
197+ return set_messages (bpc, new_mts)
198+ end
199+
200+ """
201+ More generic interface for update, with default params
202+ """
203+ function update (alg:: Algorithm"bp" , bpc:: AbstractBeliefPropagationCache )
204+ compute_error = ! isnothing (alg. tol)
205+ if isnothing (alg. maxiter)
206+ error (" You need to specify a number of iterations for BP!" )
207+ end
208+ for i in 1 : alg. maxiter
209+ diff = compute_error ? Ref (0.0 ) : nothing
210+ bpc = update_iteration (alg, bpc, alg. edge_sequence; (update_diff!) = diff)
211+ if compute_error && (diff. x / length (alg. edge_sequence)) <= alg. tol
212+ if alg. verbose
213+ println (" BP converged to desired precision after $i iterations." )
214+ end
215+ break
216+ end
217+ end
218+ return bpc
219+ end
220+
221+ function update (bpc:: AbstractBeliefPropagationCache ; alg = default_update_alg (bpc), kwargs... )
222+ return update (set_default_kwargs (Algorithm (alg; kwargs... ), bpc), bpc)
223+ end
224+
225+ # Edge sequence stuff
226+ function forest_cover_edge_sequence (g:: AbstractGraph ; root_vertex = default_root_vertex)
227+ forests = forest_cover (g)
228+ edges = edgetype (g)[]
229+ for forest in forests
230+ trees = [forest[vs] for vs in connected_components (forest)]
231+ for tree in trees
232+ tree_edges = post_order_dfs_edges (tree, root_vertex (tree))
233+ push! (edges, vcat (tree_edges, reverse (reverse .(tree_edges)))... )
234+ end
235+ end
236+ return edges
237+ end
0 commit comments