1+ using Adapt: Adapt, adapt, adapt_structure
12using Graphs: Graphs, IsDirected
23using SplitApplyCombine: group
34using LinearAlgebra: diag, dot
@@ -24,24 +25,20 @@ function data_graph_type(bpc::AbstractBeliefPropagationCache)
2425end
2526data_graph (bpc:: AbstractBeliefPropagationCache ) = data_graph (tensornetwork (bpc))
2627
27- function default_message_update (contract_list:: Vector{ITensor} ; normalize= true , kwargs... )
28- sequence = contraction_sequence (contract_list; alg= " optimal" )
29- updated_messages = contract (contract_list; sequence, kwargs... )
30- message_norm = norm (updated_messages)
31- if normalize && ! iszero (message_norm)
32- updated_messages /= message_norm
33- end
34- return ITensor[updated_messages]
35- end
36-
3728# TODO : Take `dot` without precontracting the messages to allow scaling to more complex messages
3829function message_diff (message_a:: Vector{ITensor} , message_b:: Vector{ITensor} )
3930 lhs, rhs = contract (message_a), contract (message_b)
4031 f = abs2 (dot (lhs / norm (lhs), rhs / norm (rhs)))
4132 return 1 - f
4233end
4334
44- default_message (elt, inds_e) = ITensor[denseblocks (delta (elt, i)) for i in inds_e]
35+ function default_message (datatype:: Type{<:AbstractArray} , inds_e)
36+ return [adapt (datatype, denseblocks (delta (i))) for i in inds_e]
37+ end
38+
39+ function default_message (elt:: Type{<:Number} , inds_e)
40+ return default_message (Vector{elt}, inds_e)
41+ end
4542default_messages (ptn:: PartitionedGraph ) = Dictionary ()
4643@traitfn default_bp_maxiter (g: :: :(! IsDirected)) = is_tree (g) ? 1 : nothing
4744@traitfn function default_bp_maxiter (g: :: :IsDirected )
@@ -59,15 +56,13 @@ function default_message(
5956)
6057 return not_implemented ()
6158end
59+ default_update_alg (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
6260default_message_update_alg (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
6361Base. copy (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
6462default_bp_maxiter (alg:: Algorithm , bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
6563function default_edge_sequence (alg:: Algorithm , bpc:: AbstractBeliefPropagationCache )
6664 return not_implemented ()
6765end
68- function default_message_update_kwargs (alg:: Algorithm , bpc:: AbstractBeliefPropagationCache )
69- return not_implemented ()
70- end
7166function environment (bpc:: AbstractBeliefPropagationCache , verts:: Vector ; kwargs... )
7267 return not_implemented ()
7368end
8075partitions (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
8176PartitionedGraphs. partitionedges (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
8277
83- function default_edge_sequence (
84- bpc:: AbstractBeliefPropagationCache ; alg= default_message_update_alg (bpc)
85- )
86- return default_edge_sequence (Algorithm (alg), bpc)
87- end
88- function default_bp_maxiter (
89- bpc:: AbstractBeliefPropagationCache ; alg= default_message_update_alg (bpc)
90- )
91- return default_bp_maxiter (Algorithm (alg), bpc)
92- end
93- function default_message_update_kwargs (
94- bpc:: AbstractBeliefPropagationCache ; alg= default_message_update_alg (bpc)
95- )
96- return default_message_update_kwargs (Algorithm (alg), bpc)
97- end
78+ default_bp_edge_sequence (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
79+ default_bp_maxiter (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
9880
9981function tensornetwork (bpc:: AbstractBeliefPropagationCache )
10082 return unpartitioned_graph (partitioned_tensornetwork (bpc))
@@ -144,6 +126,36 @@ function incoming_messages(
144126 return incoming_messages (bpc, [partition_vertex]; kwargs... )
145127end
146128
129+ # Adapt interface for changing device
130+ function map_messages (
131+ f, bpc:: AbstractBeliefPropagationCache , pes= collect (keys (messages (bpc)))
132+ )
133+ bpc = copy (bpc)
134+ for pe in pes
135+ set_message! (bpc, pe, f .(message (bpc, pe)))
136+ end
137+ return bpc
138+ end
139+ function map_factors (f, bpc:: AbstractBeliefPropagationCache , vs= vertices (bpc))
140+ bpc = copy (bpc)
141+ for v in vs
142+ @preserve_graph bpc[v] = f (bpc[v])
143+ end
144+ return bpc
145+ end
146+ function adapt_messages (to, bpc:: AbstractBeliefPropagationCache , args... )
147+ return map_messages (adapt (to), bpc, args... )
148+ end
149+ function adapt_factors (to, bpc:: AbstractBeliefPropagationCache , args... )
150+ return map_factors (adapt (to), bpc, args... )
151+ end
152+
153+ function Adapt. adapt_structure (to, bpc:: AbstractBeliefPropagationCache )
154+ bpc = adapt_messages (to, bpc)
155+ bpc = adapt_factors (to, bpc)
156+ return bpc
157+ end
158+
147159# Forward from partitioned graph
148160for f in [
149161 :(PartitionedGraphs. partitioned_graph),
@@ -234,44 +246,63 @@ function delete_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
234246 return delete_messages (bpc, [pe])
235247end
236248
237- """
238- Compute message tensor as product of incoming mts and local state
239- """
240249function updated_message (
241- bpc:: AbstractBeliefPropagationCache ,
242- edge:: PartitionEdge ;
243- message_update_function= default_message_update,
244- message_update_function_kwargs= (;),
250+ alg:: Algorithm"contract" , bpc:: AbstractBeliefPropagationCache , edge:: PartitionEdge
245251)
246252 vertex = src (edge)
247253 incoming_ms = incoming_messages (bpc, vertex; ignore_edges= PartitionEdge[reverse (edge)])
248254 state = factors (bpc, vertex)
255+ contract_list = ITensor[incoming_ms; state]
256+ sequence = contraction_sequence (contract_list; alg= alg. kwargs. sequence_alg)
257+ updated_messages = contract (contract_list; sequence)
258+ message_norm = norm (updated_messages)
259+ if alg. kwargs. normalize && ! iszero (message_norm)
260+ updated_messages /= message_norm
261+ end
262+ return ITensor[updated_messages]
263+ end
249264
250- return message_update_function (
251- ITensor[incoming_ms; state]; message_update_function_kwargs...
265+ function updated_message (
266+ alg:: Algorithm"adapt_update" , bpc:: AbstractBeliefPropagationCache , edge:: PartitionEdge
267+ )
268+ incoming_pes = setdiff (
269+ boundary_partitionedges (bpc, [src (edge)]; dir= :in ), [reverse (edge)]
252270 )
271+ adapted_bpc = adapt_messages (alg. kwargs. adapt, bpc, incoming_pes)
272+ adapted_bpc = adapt_factors (alg. kwargs. adapt, bpc, vertices (bpc, src (edge)))
273+ updated_messages = updated_message (alg. kwargs. alg, adapted_bpc, edge)
274+ dtype = mapreduce (datatype, promote_type, message (bpc, edge))
275+ return map (adapt (dtype), updated_messages)
276+ end
277+
278+ function updated_message (
279+ bpc:: AbstractBeliefPropagationCache ,
280+ edge:: PartitionEdge ;
281+ alg= default_message_update_alg (bpc),
282+ kwargs... ,
283+ )
284+ return updated_message (set_default_kwargs (Algorithm (alg; kwargs... )), bpc, edge)
253285end
254286
255- function update (
256- alg :: Algorithm"bp" , bpc:: AbstractBeliefPropagationCache , edge:: PartitionEdge ; kwargs ...
287+ function update_message (
288+ message_update_alg :: Algorithm , bpc:: AbstractBeliefPropagationCache , edge:: PartitionEdge
257289)
258- return set_message (bpc, edge, updated_message (bpc, edge; kwargs ... ))
290+ return set_message (bpc, edge, updated_message (message_update_alg, bpc, edge))
259291end
260292
261293"""
262294Do a sequential update of the message tensors on `edges`
263295"""
264- function update (
265- alg:: Algorithm ,
296+ function update_iteration (
297+ alg:: Algorithm"bp" ,
266298 bpc:: AbstractBeliefPropagationCache ,
267299 edges:: Vector ;
268300 (update_diff!)= nothing ,
269- kwargs... ,
270301)
271302 bpc = copy (bpc)
272303 for e in edges
273304 prev_message = ! isnothing (update_diff!) ? message (bpc, e) : nothing
274- bpc = update (alg, bpc, e; kwargs ... )
305+ bpc = update_message (alg. kwargs . message_update_alg , bpc, e)
275306 if ! isnothing (update_diff!)
276307 update_diff![] += message_diff (message (bpc, e), prev_message)
277308 end
@@ -284,15 +315,15 @@ Do parallel updates between groups of edges of all message tensors
284315Currently we send the full message tensor data struct to update for each edge_group. But really we only need the
285316mts relevant to that group.
286317"""
287- function update (
288- alg:: Algorithm ,
318+ function update_iteration (
319+ alg:: Algorithm"bp" ,
289320 bpc:: AbstractBeliefPropagationCache ,
290321 edge_groups:: Vector{<:Vector{<:PartitionEdge}} ;
291- kwargs ... ,
322+ (update_diff!) = nothing ,
292323)
293324 new_mts = empty (messages (bpc))
294325 for edges in edge_groups
295- bpc_t = update (alg, bpc, edges; kwargs ... )
326+ bpc_t = update_iteration (alg. kwargs . message_update_alg , bpc, edges; (update_diff!) )
296327 for e in edges
297328 set! (new_mts, e, message (bpc_t, e))
298329 end
@@ -303,24 +334,16 @@ end
303334"""
304335More generic interface for update, with default params
305336"""
306- function update (
307- alg:: Algorithm ,
308- bpc:: AbstractBeliefPropagationCache ;
309- edges= default_edge_sequence (alg, bpc),
310- maxiter= default_bp_maxiter (alg, bpc),
311- message_update_kwargs= default_message_update_kwargs (alg, bpc),
312- tol= nothing ,
313- verbose= false ,
314- )
315- compute_error = ! isnothing (tol)
316- if isnothing (maxiter)
337+ function update (alg:: Algorithm"bp" , bpc:: AbstractBeliefPropagationCache )
338+ compute_error = ! isnothing (alg. kwargs. tol)
339+ if isnothing (alg. kwargs. maxiter)
317340 error (" You need to specify a number of iterations for BP!" )
318341 end
319- for i in 1 : maxiter
342+ for i in 1 : alg . kwargs . maxiter
320343 diff = compute_error ? Ref (0.0 ) : nothing
321- bpc = update (alg, bpc, edges ; (update_diff!)= diff, message_update_kwargs ... )
322- if compute_error && (diff. x / length (edges )) <= tol
323- if verbose
344+ bpc = update_iteration (alg, bpc, alg . kwargs . edge_sequence ; (update_diff!)= diff)
345+ if compute_error && (diff. x / length (alg . kwargs . edge_sequence )) <= alg . kwargs . tol
346+ if alg . kwargs . verbose
324347 println (" BP converged to desired precision after $i iterations." )
325348 end
326349 break
@@ -329,12 +352,8 @@ function update(
329352 return bpc
330353end
331354
332- function update (
333- bpc:: AbstractBeliefPropagationCache ;
334- alg:: String = default_message_update_alg (bpc),
335- kwargs... ,
336- )
337- return update (Algorithm (alg), bpc; kwargs... )
355+ function update (bpc:: AbstractBeliefPropagationCache ; alg= default_update_alg (bpc), kwargs... )
356+ return update (set_default_kwargs (Algorithm (alg; kwargs... ), bpc), bpc)
338357end
339358
340359function rescale_messages (
0 commit comments