1
+ using Adapt: Adapt, adapt, adapt_structure
1
2
using Graphs: Graphs, IsDirected
2
3
using SplitApplyCombine: group
3
4
using LinearAlgebra: diag, dot
@@ -24,24 +25,20 @@ function data_graph_type(bpc::AbstractBeliefPropagationCache)
24
25
end
25
26
data_graph (bpc:: AbstractBeliefPropagationCache ) = data_graph (tensornetwork (bpc))
26
27
27
- function default_message_update (contract_list:: Vector{ITensor} ; normalize= true , kwargs... )
28
- sequence = contraction_sequence (contract_list; alg= " optimal" )
29
- updated_messages = contract (contract_list; sequence, kwargs... )
30
- message_norm = norm (updated_messages)
31
- if normalize && ! iszero (message_norm)
32
- updated_messages /= message_norm
33
- end
34
- return ITensor[updated_messages]
35
- end
36
-
37
28
# TODO : Take `dot` without precontracting the messages to allow scaling to more complex messages
38
29
function message_diff (message_a:: Vector{ITensor} , message_b:: Vector{ITensor} )
39
30
lhs, rhs = contract (message_a), contract (message_b)
40
31
f = abs2 (dot (lhs / norm (lhs), rhs / norm (rhs)))
41
32
return 1 - f
42
33
end
43
34
44
- default_message (elt, inds_e) = ITensor[denseblocks (delta (elt, i)) for i in inds_e]
35
+ function default_message (datatype:: Type{<:AbstractArray} , inds_e)
36
+ return [adapt (datatype, denseblocks (delta (i))) for i in inds_e]
37
+ end
38
+
39
+ function default_message (elt:: Type{<:Number} , inds_e)
40
+ return default_message (Vector{elt}, inds_e)
41
+ end
45
42
default_messages (ptn:: PartitionedGraph ) = Dictionary ()
46
43
@traitfn default_bp_maxiter (g: :: :(! IsDirected)) = is_tree (g) ? 1 : nothing
47
44
@traitfn function default_bp_maxiter (g: :: :IsDirected )
@@ -59,15 +56,13 @@ function default_message(
59
56
)
60
57
return not_implemented ()
61
58
end
59
+ default_update_alg (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
62
60
default_message_update_alg (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
63
61
Base. copy (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
64
62
default_bp_maxiter (alg:: Algorithm , bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
65
63
function default_edge_sequence (alg:: Algorithm , bpc:: AbstractBeliefPropagationCache )
66
64
return not_implemented ()
67
65
end
68
- function default_message_update_kwargs (alg:: Algorithm , bpc:: AbstractBeliefPropagationCache )
69
- return not_implemented ()
70
- end
71
66
function environment (bpc:: AbstractBeliefPropagationCache , verts:: Vector ; kwargs... )
72
67
return not_implemented ()
73
68
end
80
75
partitions (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
81
76
PartitionedGraphs. partitionedges (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
82
77
83
- function default_edge_sequence (
84
- bpc:: AbstractBeliefPropagationCache ; alg= default_message_update_alg (bpc)
85
- )
86
- return default_edge_sequence (Algorithm (alg), bpc)
87
- end
88
- function default_bp_maxiter (
89
- bpc:: AbstractBeliefPropagationCache ; alg= default_message_update_alg (bpc)
90
- )
91
- return default_bp_maxiter (Algorithm (alg), bpc)
92
- end
93
- function default_message_update_kwargs (
94
- bpc:: AbstractBeliefPropagationCache ; alg= default_message_update_alg (bpc)
95
- )
96
- return default_message_update_kwargs (Algorithm (alg), bpc)
97
- end
78
+ default_bp_edge_sequence (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
79
+ default_bp_maxiter (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
98
80
99
81
function tensornetwork (bpc:: AbstractBeliefPropagationCache )
100
82
return unpartitioned_graph (partitioned_tensornetwork (bpc))
@@ -144,6 +126,36 @@ function incoming_messages(
144
126
return incoming_messages (bpc, [partition_vertex]; kwargs... )
145
127
end
146
128
129
+ # Adapt interface for changing device
130
+ function map_messages (
131
+ f, bpc:: AbstractBeliefPropagationCache , pes= collect (keys (messages (bpc)))
132
+ )
133
+ bpc = copy (bpc)
134
+ for pe in pes
135
+ set_message! (bpc, pe, f .(message (bpc, pe)))
136
+ end
137
+ return bpc
138
+ end
139
+ function map_factors (f, bpc:: AbstractBeliefPropagationCache , vs= vertices (bpc))
140
+ bpc = copy (bpc)
141
+ for v in vs
142
+ @preserve_graph bpc[v] = f (bpc[v])
143
+ end
144
+ return bpc
145
+ end
146
+ function adapt_messages (to, bpc:: AbstractBeliefPropagationCache , args... )
147
+ return map_messages (adapt (to), bpc, args... )
148
+ end
149
+ function adapt_factors (to, bpc:: AbstractBeliefPropagationCache , args... )
150
+ return map_factors (adapt (to), bpc, args... )
151
+ end
152
+
153
+ function Adapt. adapt_structure (to, bpc:: AbstractBeliefPropagationCache )
154
+ bpc = adapt_messages (to, bpc)
155
+ bpc = adapt_factors (to, bpc)
156
+ return bpc
157
+ end
158
+
147
159
# Forward from partitioned graph
148
160
for f in [
149
161
:(PartitionedGraphs. partitioned_graph),
@@ -234,44 +246,63 @@ function delete_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
234
246
return delete_messages (bpc, [pe])
235
247
end
236
248
237
- """
238
- Compute message tensor as product of incoming mts and local state
239
- """
240
249
function updated_message (
241
- bpc:: AbstractBeliefPropagationCache ,
242
- edge:: PartitionEdge ;
243
- message_update_function= default_message_update,
244
- message_update_function_kwargs= (;),
250
+ alg:: Algorithm"contract" , bpc:: AbstractBeliefPropagationCache , edge:: PartitionEdge
245
251
)
246
252
vertex = src (edge)
247
253
incoming_ms = incoming_messages (bpc, vertex; ignore_edges= PartitionEdge[reverse (edge)])
248
254
state = factors (bpc, vertex)
255
+ contract_list = ITensor[incoming_ms; state]
256
+ sequence = contraction_sequence (contract_list; alg= alg. kwargs. sequence_alg)
257
+ updated_messages = contract (contract_list; sequence)
258
+ message_norm = norm (updated_messages)
259
+ if alg. kwargs. normalize && ! iszero (message_norm)
260
+ updated_messages /= message_norm
261
+ end
262
+ return ITensor[updated_messages]
263
+ end
249
264
250
- return message_update_function (
251
- ITensor[incoming_ms; state]; message_update_function_kwargs...
265
+ function updated_message (
266
+ alg:: Algorithm"adapt_update" , bpc:: AbstractBeliefPropagationCache , edge:: PartitionEdge
267
+ )
268
+ incoming_pes = setdiff (
269
+ boundary_partitionedges (bpc, [src (edge)]; dir= :in ), [reverse (edge)]
252
270
)
271
+ adapted_bpc = adapt_messages (alg. kwargs. adapt, bpc, incoming_pes)
272
+ adapted_bpc = adapt_factors (alg. kwargs. adapt, bpc, vertices (bpc, src (edge)))
273
+ updated_messages = updated_message (alg. kwargs. alg, adapted_bpc, edge)
274
+ dtype = mapreduce (datatype, promote_type, message (bpc, edge))
275
+ return map (adapt (dtype), updated_messages)
276
+ end
277
+
278
+ function updated_message (
279
+ bpc:: AbstractBeliefPropagationCache ,
280
+ edge:: PartitionEdge ;
281
+ alg= default_message_update_alg (bpc),
282
+ kwargs... ,
283
+ )
284
+ return updated_message (set_default_kwargs (Algorithm (alg; kwargs... )), bpc, edge)
253
285
end
254
286
255
- function update (
256
- alg :: Algorithm"bp" , bpc:: AbstractBeliefPropagationCache , edge:: PartitionEdge ; kwargs ...
287
+ function update_message (
288
+ message_update_alg :: Algorithm , bpc:: AbstractBeliefPropagationCache , edge:: PartitionEdge
257
289
)
258
- return set_message (bpc, edge, updated_message (bpc, edge; kwargs ... ))
290
+ return set_message (bpc, edge, updated_message (message_update_alg, bpc, edge))
259
291
end
260
292
261
293
"""
262
294
Do a sequential update of the message tensors on `edges`
263
295
"""
264
- function update (
265
- alg:: Algorithm ,
296
+ function update_iteration (
297
+ alg:: Algorithm"bp" ,
266
298
bpc:: AbstractBeliefPropagationCache ,
267
299
edges:: Vector ;
268
300
(update_diff!)= nothing ,
269
- kwargs... ,
270
301
)
271
302
bpc = copy (bpc)
272
303
for e in edges
273
304
prev_message = ! isnothing (update_diff!) ? message (bpc, e) : nothing
274
- bpc = update (alg, bpc, e; kwargs ... )
305
+ bpc = update_message (alg. kwargs . message_update_alg , bpc, e)
275
306
if ! isnothing (update_diff!)
276
307
update_diff![] += message_diff (message (bpc, e), prev_message)
277
308
end
@@ -284,15 +315,15 @@ Do parallel updates between groups of edges of all message tensors
284
315
Currently we send the full message tensor data struct to update for each edge_group. But really we only need the
285
316
mts relevant to that group.
286
317
"""
287
- function update (
288
- alg:: Algorithm ,
318
+ function update_iteration (
319
+ alg:: Algorithm"bp" ,
289
320
bpc:: AbstractBeliefPropagationCache ,
290
321
edge_groups:: Vector{<:Vector{<:PartitionEdge}} ;
291
- kwargs ... ,
322
+ (update_diff!) = nothing ,
292
323
)
293
324
new_mts = empty (messages (bpc))
294
325
for edges in edge_groups
295
- bpc_t = update (alg, bpc, edges; kwargs ... )
326
+ bpc_t = update_iteration (alg. kwargs . message_update_alg , bpc, edges; (update_diff!) )
296
327
for e in edges
297
328
set! (new_mts, e, message (bpc_t, e))
298
329
end
@@ -303,24 +334,16 @@ end
303
334
"""
304
335
More generic interface for update, with default params
305
336
"""
306
- function update (
307
- alg:: Algorithm ,
308
- bpc:: AbstractBeliefPropagationCache ;
309
- edges= default_edge_sequence (alg, bpc),
310
- maxiter= default_bp_maxiter (alg, bpc),
311
- message_update_kwargs= default_message_update_kwargs (alg, bpc),
312
- tol= nothing ,
313
- verbose= false ,
314
- )
315
- compute_error = ! isnothing (tol)
316
- if isnothing (maxiter)
337
+ function update (alg:: Algorithm"bp" , bpc:: AbstractBeliefPropagationCache )
338
+ compute_error = ! isnothing (alg. kwargs. tol)
339
+ if isnothing (alg. kwargs. maxiter)
317
340
error (" You need to specify a number of iterations for BP!" )
318
341
end
319
- for i in 1 : maxiter
342
+ for i in 1 : alg . kwargs . maxiter
320
343
diff = compute_error ? Ref (0.0 ) : nothing
321
- bpc = update (alg, bpc, edges ; (update_diff!)= diff, message_update_kwargs ... )
322
- if compute_error && (diff. x / length (edges )) <= tol
323
- if verbose
344
+ bpc = update_iteration (alg, bpc, alg . kwargs . edge_sequence ; (update_diff!)= diff)
345
+ if compute_error && (diff. x / length (alg . kwargs . edge_sequence )) <= alg . kwargs . tol
346
+ if alg . kwargs . verbose
324
347
println (" BP converged to desired precision after $i iterations." )
325
348
end
326
349
break
@@ -329,12 +352,8 @@ function update(
329
352
return bpc
330
353
end
331
354
332
- function update (
333
- bpc:: AbstractBeliefPropagationCache ;
334
- alg:: String = default_message_update_alg (bpc),
335
- kwargs... ,
336
- )
337
- return update (Algorithm (alg), bpc; kwargs... )
355
+ function update (bpc:: AbstractBeliefPropagationCache ; alg= default_update_alg (bpc), kwargs... )
356
+ return update (set_default_kwargs (Algorithm (alg; kwargs... ), bpc), bpc)
338
357
end
339
358
340
359
function rescale_messages (
0 commit comments