@@ -11,9 +11,18 @@ using NamedGraphs.PartitionedGraphs:
1111 partitionedges,
1212 unpartitioned_graph
1313using SimpleTraits: SimpleTraits, Not, @traitfn
14+ using NamedGraphs. SimilarType: SimilarType
1415using NDTensors: NDTensors
1516
16- abstract type AbstractBeliefPropagationCache end
17+ abstract type AbstractBeliefPropagationCache{V,PV} <: AbstractITensorNetwork{V} end
18+
19+ function SimilarType. similar_type (bpc:: AbstractBeliefPropagationCache )
20+ return typeof (tensornetwork (bpc))
21+ end
22+ function data_graph_type (bpc:: AbstractBeliefPropagationCache )
23+ return data_graph_type (tensornetwork (bpc))
24+ end
25+ data_graph (bpc:: AbstractBeliefPropagationCache ) = data_graph (tensornetwork (bpc))
1726
1827function default_message_update (contract_list:: Vector{ITensor} ; normalize= true , kwargs... )
1928 sequence = contraction_sequence (contract_list; alg= " optimal" )
@@ -40,6 +49,9 @@ default_messages(ptn::PartitionedGraph) = Dictionary()
4049end
4150default_partitioned_vertices (ψ:: AbstractITensorNetwork ) = group (v -> v, vertices (ψ))
4251
52+ function Base. setindex! (bpc:: AbstractBeliefPropagationCache , factor:: ITensor , vertex)
53+ not_implemented ()
54+ end
4355partitioned_tensornetwork (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
4456messages (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
4557function default_message (
@@ -88,12 +100,8 @@ function tensornetwork(bpc::AbstractBeliefPropagationCache)
88100 return unpartitioned_graph (partitioned_tensornetwork (bpc))
89101end
90102
91- function setindex_preserve_graph! (bpc:: AbstractBeliefPropagationCache , args... )
92- return setindex_preserve_graph! (tensornetwork (bpc), args... )
93- end
94-
95103function factors (bpc:: AbstractBeliefPropagationCache , verts:: Vector )
96- return ITensor[tensornetwork (bpc) [v] for v in verts]
104+ return ITensor[copy (bpc[v]) for v in verts]
97105end
98106
99107function factors (
@@ -143,7 +151,6 @@ for f in [
143151 :(PartitionedGraphs. partitionvertices),
144152 :(PartitionedGraphs. vertices),
145153 :(PartitionedGraphs. boundary_partitionedges),
146- :(linkinds),
147154]
148155 @eval begin
149156 function $f (bpc:: AbstractBeliefPropagationCache , args... ; kwargs... )
@@ -152,23 +159,28 @@ for f in [
152159 end
153160end
154161
162+ function linkinds (bpc:: AbstractBeliefPropagationCache , pe:: PartitionEdge )
163+ return linkinds (partitioned_tensornetwork (bpc), pe)
164+ end
165+
155166NDTensors. scalartype (bpc:: AbstractBeliefPropagationCache ) = scalartype (tensornetwork (bpc))
156167
157168"""
158- Update the tensornetwork inside the cache
169+ Update the tensornetwork inside the cache out-of-place
159170"""
160171function update_factors (bpc:: AbstractBeliefPropagationCache , factors)
161172 bpc = copy (bpc)
162- tn = tensornetwork (bpc)
163173 for vertex in eachindex (factors)
164174 # TODO : Add a check that this preserves the graph structure.
165- setindex_preserve_graph! (tn , factors[vertex], vertex)
175+ setindex_preserve_graph! (bpc , factors[vertex], vertex)
166176 end
167177 return bpc
168178end
169179
170180function update_factor (bpc, vertex, factor)
171- return update_factors (bpc, Dictionary ([vertex], [factor]))
181+ bpc = copy (bpc)
182+ setindex_preserve_graph! (bpc, factor, vertex)
183+ return bpc
172184end
173185
174186function message (bpc:: AbstractBeliefPropagationCache , edge:: PartitionEdge ; kwargs... )
@@ -178,12 +190,45 @@ end
178190function messages (bpc:: AbstractBeliefPropagationCache , edges; kwargs... )
179191 return map (edge -> message (bpc, edge; kwargs... ), edges)
180192end
193+ function set_messages! (bpc:: AbstractBeliefPropagationCache , partitionedges_messages)
194+ ms = messages (bpc)
195+ for pe in eachindex (partitionedges_messages)
196+ # TODO : Add a check that this preserves the graph structure.
197+ set! (ms, pe, partitionedges_messages[pe])
198+ end
199+ return bpc
200+ end
201+ function set_message! (bpc:: AbstractBeliefPropagationCache , pe:: PartitionEdge , message)
202+ ms = messages (bpc)
203+ set! (ms, pe, message)
204+ return bpc
205+ end
206+
207+ function set_messages (bpc:: AbstractBeliefPropagationCache , partitionedges_messages)
208+ bpc = copy (bpc)
209+ return set_messages! (bpc, partitionedges_messages)
210+ end
181211function set_message (bpc:: AbstractBeliefPropagationCache , pe:: PartitionEdge , message)
182212 bpc = copy (bpc)
213+ return set_message! (bpc, pe, message)
214+ end
215+ function delete_messages! (bpc:: AbstractBeliefPropagationCache , pes:: Vector{<:PartitionEdge} )
183216 ms = messages (bpc)
184- set! (ms, pe, message)
217+ for pe in pes
218+ delete! (ms, pe)
219+ end
185220 return bpc
186221end
222+ function delete_message! (bpc:: AbstractBeliefPropagationCache , pe:: PartitionEdge )
223+ return delete_message! (bpc, [pe])
224+ end
225+ function delete_messages (bpc:: AbstractBeliefPropagationCache , pes:: Vector{<:PartitionEdge} )
226+ bpc = copy (bpc)
227+ return delete_messages! (bpc, pes)
228+ end
229+ function delete_message (bpc:: AbstractBeliefPropagationCache , pe:: PartitionEdge )
230+ return delete_message (bpc, [pe])
231+ end
187232
188233"""
189234Compute message tensor as product of incoming mts and local state
@@ -241,11 +286,11 @@ function update(
241286 edge_groups:: Vector{<:Vector{<:PartitionEdge}} ;
242287 kwargs... ,
243288)
244- new_mts = copy ( messages (bpc) )
289+ new_mts = Dictionary ( )
245290 for edges in edge_groups
246291 bpc_t = update (alg, bpc, edges; kwargs... )
247292 for e in edges
248- new_mts[e] = message (bpc_t, e)
293+ set! ( new_mts, e, message (bpc_t, e) )
249294 end
250295 end
251296 return set_messages (bpc, new_mts)
@@ -288,10 +333,6 @@ function update(
288333 return update (Algorithm (alg), bpc; kwargs... )
289334end
290335
291- function scale! (bp_cache:: AbstractBeliefPropagationCache , args... )
292- return scale! (tensornetwork (bp_cache), args... )
293- end
294-
295336function rescale_messages (
296337 bp_cache:: AbstractBeliefPropagationCache , partitionedge:: PartitionEdge
297338)
0 commit comments