@@ -25,23 +25,6 @@ function data_graph_type(bpc::AbstractBeliefPropagationCache)
2525end
2626data_graph (bpc:: AbstractBeliefPropagationCache ) = data_graph (tensornetwork (bpc))
2727
28- function message_update (alg:: Algorithm"contract" , contract_list:: Vector{ITensor} )
29- sequence = contraction_sequence (contract_list; alg= alg. kwargs. sequence_alg)
30- updated_messages = contract (contract_list; sequence)
31- message_norm = norm (updated_messages)
32- if alg. kwargs. normalize && ! iszero (message_norm)
33- updated_messages /= message_norm
34- end
35- return ITensor[updated_messages]
36- end
37-
38- function message_update (alg:: Algorithm"adapt_update" , contract_list:: Vector{ITensor} )
39- adapted_contract_list = alg. kwargs. adapt .(contract_list)
40- updated_messages = message_update (alg. kwargs. alg, adapted_contract_list)
41- dtype = mapreduce (datatype, promote_type, contract_list)
42- return map (adapt (dtype), updated_messages)
43- end
44-
4528# TODO : Take `dot` without precontracting the messages to allow scaling to more complex messages
4629function message_diff (message_a:: Vector{ITensor} , message_b:: Vector{ITensor} )
4730 lhs, rhs = contract (message_a), contract (message_b)
@@ -144,22 +127,28 @@ function incoming_messages(
144127end
145128
146129# Adapt interface for changing device
147- function map_messages (f, bpc:: AbstractBeliefPropagationCache )
130+ function map_messages (
131+ f, bpc:: AbstractBeliefPropagationCache , pes= collect (keys (messages (bpc)))
132+ )
148133 bpc = copy (bpc)
149- for pe in keys ( messages (bpc))
134+ for pe in pes
150135 set_message! (bpc, pe, f .(message (bpc, pe)))
151136 end
152137 return bpc
153138end
154- function map_factors (f, bpc:: AbstractBeliefPropagationCache )
139+ function map_factors (f, bpc:: AbstractBeliefPropagationCache , vs = vertices (bpc) )
155140 bpc = copy (bpc)
156- for v in vertices (bpc)
141+ for v in vs
157142 @preserve_graph bpc[v] = f (bpc[v])
158143 end
159144 return bpc
160145end
161- adapt_messages (to, bpc:: AbstractBeliefPropagationCache ) = map_messages (adapt (to), bpc)
162- adapt_factors (to, bpc:: AbstractBeliefPropagationCache ) = map_factors (adapt (to), bpc)
146+ function adapt_messages (to, bpc:: AbstractBeliefPropagationCache , args... )
147+ map_messages (adapt (to), bpc, args... )
148+ end
149+ function adapt_factors (to, bpc:: AbstractBeliefPropagationCache , args... )
150+ map_factors (adapt (to), bpc, args... )
151+ end
163152
164153function Adapt. adapt_structure (to, bpc:: AbstractBeliefPropagationCache )
165154 bpc = adapt_messages (to, bpc)
@@ -257,33 +246,49 @@ function delete_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
257246 return delete_messages (bpc, [pe])
258247end
259248
260- """
261- Compute message tensor as product of incoming mts and local state
262- """
263249function updated_message (
264- bpc:: AbstractBeliefPropagationCache ,
265- edge:: PartitionEdge ;
266- message_update_alg= default_message_update_alg (bpc),
267- kwargs... ,
250+ alg:: Algorithm"contract" , bpc:: AbstractBeliefPropagationCache , edge:: PartitionEdge
268251)
269252 vertex = src (edge)
270253 incoming_ms = incoming_messages (bpc, vertex; ignore_edges= PartitionEdge[reverse (edge)])
271254 state = factors (bpc, vertex)
255+ contract_list = ITensor[incoming_ms; state]
256+ sequence = contraction_sequence (contract_list; alg= alg. kwargs. sequence_alg)
257+ updated_messages = contract (contract_list; sequence)
258+ message_norm = norm (updated_messages)
259+ if alg. kwargs. normalize && ! iszero (message_norm)
260+ updated_messages /= message_norm
261+ end
262+ return ITensor[updated_messages]
263+ end
272264
273- return message_update (message_update_alg, ITensor[incoming_ms; state]; kwargs... )
265+ function updated_message (
266+ alg:: Algorithm"adapt_update" , bpc:: AbstractBeliefPropagationCache , edge:: PartitionEdge
267+ )
268+ incoming_pes = setdiff (
269+ boundary_partitionedges (bpc, [src (edge)]; dir= :in ), [reverse (edge)]
270+ )
271+ adapted_bpc = adapt_messages (alg. kwargs. adapt, bpc, incoming_pes)
272+ adapted_bpc = adapt_factors (alg. kwargs. adapt, bpc, vertices (bpc, src (edge)))
273+ updated_messages = updated_message (alg. kwargs. alg, adapted_bpc, edge)
274+ dtype = mapreduce (datatype, promote_type, message (bpc, edge))
275+ return map (adapt (dtype), updated_messages)
274276end
275277
276- function update (
277- alg:: Algorithm"bp" , bpc:: AbstractBeliefPropagationCache , edge:: PartitionEdge ; kwargs...
278+ function update_message (
279+ message_update_alg:: Algorithm ,
280+ bpc:: AbstractBeliefPropagationCache ,
281+ edge:: PartitionEdge ;
282+ kwargs... ,
278283)
279- return set_message (bpc, edge, updated_message (bpc, edge; kwargs... ))
284+ return set_message (bpc, edge, updated_message (message_update_alg, bpc, edge; kwargs... ))
280285end
281286
282287"""
283288Do a sequential update of the message tensors on `edges`
284289"""
285- function update (
286- alg:: Algorithm ,
290+ function update_one_iteration (
291+ alg:: Algorithm"bp" ,
287292 bpc:: AbstractBeliefPropagationCache ,
288293 edges:: Vector ;
289294 (update_diff!)= nothing ,
@@ -292,7 +297,7 @@ function update(
292297 bpc = copy (bpc)
293298 for e in edges
294299 prev_message = ! isnothing (update_diff!) ? message (bpc, e) : nothing
295- bpc = update (alg, bpc, e; kwargs... )
300+ bpc = update_message (alg. kwargs . message_update_alg , bpc, e; kwargs... )
296301 if ! isnothing (update_diff!)
297302 update_diff![] += message_diff (message (bpc, e), prev_message)
298303 end
@@ -305,15 +310,15 @@ Do parallel updates between groups of edges of all message tensors
305310Currently we send the full message tensor data struct to update for each edge_group. But really we only need the
306311mts relevant to that group.
307312"""
308- function update (
313+ function update_one_iteration (
309314 alg:: Algorithm ,
310315 bpc:: AbstractBeliefPropagationCache ,
311316 edge_groups:: Vector{<:Vector{<:PartitionEdge}} ;
312317 kwargs... ,
313318)
314319 new_mts = empty (messages (bpc))
315320 for edges in edge_groups
316- bpc_t = update (alg, bpc, edges; kwargs... )
321+ bpc_t = update_one_iteration (alg. kwargs . message_update_alg , bpc, edges; kwargs... )
317322 for e in edges
318323 set! (new_mts, e, message (bpc_t, e))
319324 end
@@ -324,27 +329,17 @@ end
324329"""
325330More generic interface for update, with default params
326331"""
327- function update (
328- alg:: Algorithm ,
329- bpc:: AbstractBeliefPropagationCache ;
330- message_update_alg= default_message_update_alg (bpc),
331- kwargs... ,
332- )
332+ function update (alg:: Algorithm , bpc:: AbstractBeliefPropagationCache ; kwargs... )
333333 compute_error = ! isnothing (alg. kwargs. tol)
334334 if isnothing (alg. kwargs. maxiter)
335335 error (" You need to specify a number of iterations for BP!" )
336336 end
337337 for i in 1 : alg. kwargs. maxiter
338338 diff = compute_error ? Ref (0.0 ) : nothing
339- bpc = update (
340- alg,
341- bpc,
342- alg. kwargs. edge_sequence;
343- (update_diff!)= diff,
344- message_update_alg= set_default_kwargs (message_update_alg),
345- kwargs... ,
339+ bpc = update_one_iteration (
340+ alg, bpc, alg. kwargs. edge_sequence; (update_diff!)= diff, kwargs...
346341 )
347- if compute_error && (diff. x / length (edges )) <= alg. kwargs. tol
342+ if compute_error && (diff. x / length (alg . kwargs . edge_sequence )) <= alg. kwargs. tol
348343 if alg. kwargs. verbose
349344 println (" BP converged to desired precision after $i iterations." )
350345 end
@@ -355,7 +350,7 @@ function update(
355350end
356351
357352function update (bpc:: AbstractBeliefPropagationCache ; alg= default_update_alg (bpc), kwargs... )
358- return update (set_default_kwargs (alg, bpc), bpc; kwargs ... )
353+ return update (set_default_kwargs (Algorithm ( alg; kwargs ... ) , bpc), bpc)
359354end
360355
361356function rescale_messages (
0 commit comments