1- using DiagonalArrays: delta
21using Dictionaries: Dictionary, set!, delete!
32using Graphs: AbstractGraph, is_tree, connected_components
43using NamedGraphs. GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges
54using ITensorBase: ITensor, dim
6- using TypeParameterAccessors: unwrap_array_type, unwrap_array, parenttype
75
86struct BeliefPropagationCache{V, N <: AbstractDataGraph{V} } < :
97 AbstractBeliefPropagationCache{V}
1311
1412messages (bp_cache:: BeliefPropagationCache ) = bp_cache. messages
1513network (bp_cache:: BeliefPropagationCache ) = bp_cache. network
16- default_messages () = Dictionary ()
1714
18- BeliefPropagationCache (network) = BeliefPropagationCache (network, default_messages ())
15+ BeliefPropagationCache (network) = BeliefPropagationCache (network, Dictionary ())
1916
2017function Base. copy (bp_cache:: BeliefPropagationCache )
2118 return BeliefPropagationCache (copy (network (bp_cache)), copy (messages (bp_cache)))
@@ -33,16 +30,15 @@ function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message)
3330 return bp_cache
3431end
3532
36- function message (bp_cache:: AbstractBeliefPropagationCache , edge:: AbstractEdge ; kwargs... )
33+ function message (bp_cache:: BeliefPropagationCache , edge:: AbstractEdge ; kwargs... )
3734 ms = messages (bp_cache)
3835 return get (() -> default_message (bp_cache, edge; kwargs... ), ms, edge)
3936end
4037
41- function messages (bp_cache:: AbstractBeliefPropagationCache , edges:: Vector{<:AbstractEdge} )
38+ function messages (bp_cache:: BeliefPropagationCache , edges:: Vector{<:AbstractEdge} )
4239 return [message (bp_cache, e) for e in edges]
4340end
4441
45- default_bp_maxiter (g:: AbstractGraph ) = is_tree (g) ? 1 : nothing
4642# Forward onto the network
4743for f in [
4844 :(Graphs. vertices),
@@ -62,11 +58,6 @@ for f in [
6258 end
6359end
6460
65- # TODO : Get subgraph working on an ITensorNetwork to overload this directly
66- function default_bp_edge_sequence (bp_cache:: BeliefPropagationCache )
67- return forest_cover_edge_sequence (underlying_graph (bp_cache))
68- end
69-
7061function factors (tn:: AbstractTensorNetwork , vertex)
7162 return [tn[vertex]]
7263end
@@ -91,33 +82,6 @@ function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge)
9182 return t
9283end
9384
94- # Algorithmic defaults
95- default_update_alg (bp_cache:: BeliefPropagationCache ) = " bp"
96- default_message_update_alg (bp_cache:: BeliefPropagationCache ) = " contract"
97- default_normalize (:: Algorithm"contract" ) = true
98- default_sequence_alg (:: Algorithm"contract" ) = " optimal"
99- function set_default_kwargs (alg:: Algorithm"contract" )
100- normalize = get (alg, :normalize , default_normalize (alg))
101- sequence_alg = get (alg, :sequence_alg , default_sequence_alg (alg))
102- return Algorithm (" contract" ; normalize, sequence_alg)
103- end
104- function set_default_kwargs (alg:: Algorithm"adapt_update" )
105- _alg = set_default_kwargs (get (alg, :alg , Algorithm (" contract" )))
106- return Algorithm (" adapt_update" ; adapt = alg. adapt, alg = _alg)
107- end
108- default_verbose (:: Algorithm"bp" ) = false
109- default_tol (:: Algorithm"bp" ) = nothing
110- function set_default_kwargs (alg:: Algorithm"bp" , bp_cache:: BeliefPropagationCache )
111- verbose = get (alg, :verbose , default_verbose (alg))
112- maxiter = get (alg, :maxiter , default_bp_maxiter (bp_cache))
113- edge_sequence = get (alg, :edge_sequence , default_bp_edge_sequence (bp_cache))
114- tol = get (alg, :tol , default_tol (alg))
115- message_update_alg = set_default_kwargs (
116- get (alg, :message_update_alg , Algorithm (default_message_update_alg (bp_cache)))
117- )
118- return Algorithm (" bp" ; verbose, maxiter, edge_sequence, tol, message_update_alg)
119- end
120-
12185# TODO : Update message etc should go here...
12286function updated_message (
12387 alg:: Algorithm"contract" , bp_cache:: BeliefPropagationCache , edge:: AbstractEdge
@@ -141,85 +105,21 @@ function updated_message(
141105 return updated_message
142106end
143107
144- function updated_message (
145- bp_cache:: BeliefPropagationCache ,
146- edge:: AbstractEdge ;
147- alg = default_message_update_alg (bpc),
148- kwargs... ,
108+ function default_algorithm (
109+ :: Type{<:Algorithm"contract"} ; normalize = true , sequence_alg = " optimal"
149110 )
150- return updated_message ( set_default_kwargs ( Algorithm (alg; kwargs ... )), bp_cache, edge )
111+ return Algorithm (" contract " ; normalize, sequence_alg )
151112end
152-
153- function update_message! (
154- message_update_alg:: Algorithm , bp_cache:: BeliefPropagationCache , edge:: AbstractEdge
113+ function default_algorithm (
114+ :: Type{<:Algorithm"adapt_update"} ; adapt, alg = default_algorithm (Algorithm " contract" )
155115 )
156- return setmessage! (bp_cache, edge, updated_message (message_update_alg, bp_cache, edge) )
116+ return Algorithm ( " adapt_update " ; adapt, alg )
157117end
158118
159- """
160- Do a sequential update of the message tensors on `edges`
161- """
162- function update_iteration (
163- alg:: Algorithm"bp" ,
164- bpc:: AbstractBeliefPropagationCache ,
165- edges:: Vector ;
166- (update_diff!) = nothing ,
167- )
168- bpc = copy (bpc)
169- for e in edges
170- prev_message = ! isnothing (update_diff!) ? message (bpc, e) : nothing
171- update_message! (alg. message_update_alg, bpc, e)
172- if ! isnothing (update_diff!)
173- update_diff![] += message_diff (message (bpc, e), prev_message)
174- end
175- end
176- return bpc
177- end
178-
179- """
180- Do parallel updates between groups of edges of all message tensors
181- Currently we send the full message tensor data struct to update for each edge_group. But really we only need the
182- mts relevant to that group.
183- """
184- function update_iteration (
185- alg:: Algorithm"bp" ,
186- bpc:: AbstractBeliefPropagationCache ,
187- edge_groups:: Vector{<:Vector{<:AbstractEdge}} ;
188- (update_diff!) = nothing ,
119+ function update_message! (
120+ message_update_alg:: Algorithm , bpc:: BeliefPropagationCache , edge:: AbstractEdge
189121 )
190- new_mts = empty (messages (bpc))
191- for edges in edge_groups
192- bpc_t = update_iteration (alg. kwargs. message_update_alg, bpc, edges; (update_diff!))
193- for e in edges
194- set! (new_mts, e, message (bpc_t, e))
195- end
196- end
197- return set_messages (bpc, new_mts)
198- end
199-
200- """
201- More generic interface for update, with default params
202- """
203- function update (alg:: Algorithm"bp" , bpc:: AbstractBeliefPropagationCache )
204- compute_error = ! isnothing (alg. tol)
205- if isnothing (alg. maxiter)
206- error (" You need to specify a number of iterations for BP!" )
207- end
208- for i in 1 : alg. maxiter
209- diff = compute_error ? Ref (0.0 ) : nothing
210- bpc = update_iteration (alg, bpc, alg. edge_sequence; (update_diff!) = diff)
211- if compute_error && (diff. x / length (alg. edge_sequence)) <= alg. tol
212- if alg. verbose
213- println (" BP converged to desired precision after $i iterations." )
214- end
215- break
216- end
217- end
218- return bpc
219- end
220-
221- function update (bpc:: AbstractBeliefPropagationCache ; alg = default_update_alg (bpc), kwargs... )
222- return update (set_default_kwargs (Algorithm (alg; kwargs... ), bpc), bpc)
122+ return setmessage! (bpc, edge, updated_message (message_update_alg, bpc, edge))
223123end
224124
225125# Edge sequence stuff
@@ -234,4 +134,4 @@ function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root
234134 end
235135 end
236136 return edges
237- end
137+ end
0 commit comments