@@ -25,34 +25,21 @@ function data_graph_type(bpc::AbstractBeliefPropagationCache)
2525end
2626data_graph (bpc:: AbstractBeliefPropagationCache ) = data_graph (tensornetwork (bpc))
2727
28- function message_update_function (
29- alg:: Algorithm"contract" ,
30- contract_list:: Vector{ITensor} ;
31- normalize= alg. kwargs. normalize,
32- sequence_alg= alg. kwargs. sequence_alg,
33- )
34- sequence = contraction_sequence (contract_list; alg= sequence_alg)
28+ function message_update (alg:: Algorithm"contract" , contract_list:: Vector{ITensor} ;)
29+ sequence = contraction_sequence (contract_list; alg= alg. kwargs. sequence_alg)
3530 updated_messages = contract (contract_list; sequence)
3631 message_norm = norm (updated_messages)
37- if normalize && ! iszero (message_norm)
32+ if alg . kwargs . normalize && ! iszero (message_norm)
3833 updated_messages /= message_norm
3934 end
4035 return ITensor[updated_messages]
4136end
4237
43- function message_update_function (
44- alg:: Algorithm"contract_custom_device" ,
45- contract_list:: Vector{ITensor} ;
46- normalize= alg. kwargs. normalize,
47- sequence_alg= alg. kwargs. sequence_alg,
48- custom_device_adapt= alg. kwargs. adapt,
49- )
50- adapted_contract_list = custom_device_adapt .(contract_list)
51- updated_messages = message_update_function (
52- Algorithm (" contract" ), adapted_contract_list; normalize, sequence_alg
53- )
38+ function message_update (alg:: Algorithm"adapt_update" , contract_list:: Vector{ITensor} ;)
39+ adapted_contract_list = alg. kwargs. adapt .(contract_list)
40+ updated_messages = message_update (alg. kwargs. alg, adapted_contract_list)
5441 dtype = datatype (first (contract_list))
55- return ITensor[ adapt (dtype, updated_message) for updated_message in updated_messages]
42+ return map ( adapt (dtype), updated_messages)
5643end
5744
5845# TODO : Take `dot` without precontracting the messages to allow scaling to more complex messages
@@ -62,8 +49,12 @@ function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
6249 return 1 - f
6350end
6451
65- function default_message (datatype, elt, inds_e)
66- ITensor[adapt (datatype, denseblocks (delta (elt, i))) for i in inds_e]
52+ function default_message (datatype:: Type{<:AbstractArray} , inds_e)
53+ return [adapt (datatype, denseblocks (delta (i))) for i in inds_e]
54+ end
55+
56+ function default_message (elt:: Type{<:Number} , inds_e)
57+ return default_message (Vector{elt}, inds_e)
6758end
6859default_messages (ptn:: PartitionedGraph ) = Dictionary ()
6960@traitfn default_bp_maxiter (g: :: :(! IsDirected)) = is_tree (g) ? 1 : 30
@@ -153,20 +144,22 @@ function incoming_messages(
153144end
154145
155146# Adapt interface for changing device
156- function adapt_messages (to , bpc:: AbstractBeliefPropagationCache )
147+ function map_messages (map , bpc:: AbstractBeliefPropagationCache )
157148 bpc = copy (bpc)
158149 for pe in keys (messages (bpc))
159- set_message! (bpc, pe, adapt (to) .(message (bpc, pe)))
150+ set_message! (bpc, pe, map .(message (bpc, pe)))
160151 end
161152 return bpc
162153end
163- function adapt_factors (to, bpc:: AbstractBeliefPropagationCache )
154+ function map_factors (to, bpc:: AbstractBeliefPropagationCache )
164155 bpc = copy (bpc)
165156 for v in vertices (bpc)
166- @preserve_graph bpc[v] = adapt (to). (bpc[v])
157+ @preserve_graph bpc[v] = map (bpc[v])
167158 end
168159 return bpc
169160end
161+ adapt_messages (to, bpc:: AbstractBeliefPropagationCache ) = map_messages (adapt (to), bpc)
162+ adapt_factors (to, bpc:: AbstractBeliefPropagationCache ) = map_factors (adapt (to), bpc)
170163
171164function Adapt. adapt_structure (to, bpc:: AbstractBeliefPropagationCache )
172165 bpc = adapt_messages (to, bpc)
@@ -277,7 +270,7 @@ function updated_message(
277270 incoming_ms = incoming_messages (bpc, vertex; ignore_edges= PartitionEdge[reverse (edge)])
278271 state = factors (bpc, vertex)
279272
280- return message_update_function (message_update_alg, ITensor[incoming_ms; state]; kwargs... )
273+ return message_update (message_update_alg, ITensor[incoming_ms; state]; kwargs... )
281274end
282275
283276function update (
0 commit comments