Skip to content

Commit ccf8384

Browse files
committed
Make AbstractBPCache an AbstractITensorNetwork
1 parent 59ef115 commit ccf8384

File tree

4 files changed

+78
-24
lines changed

4 files changed

+78
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
4-
version = "0.13.7"
4+
version = "0.13.8"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,18 @@ using NamedGraphs.PartitionedGraphs:
1111
partitionedges,
1212
unpartitioned_graph
1313
using SimpleTraits: SimpleTraits, Not, @traitfn
14+
using NamedGraphs.SimilarType: SimilarType
1415
using NDTensors: NDTensors
1516

16-
abstract type AbstractBeliefPropagationCache end
17+
abstract type AbstractBeliefPropagationCache{V,PV} <: AbstractITensorNetwork{V} end
18+
19+
function SimilarType.similar_type(bpc::AbstractBeliefPropagationCache)
20+
return typeof(tensornetwork(bpc))
21+
end
22+
function data_graph_type(bpc::AbstractBeliefPropagationCache)
23+
return data_graph_type(tensornetwork(bpc))
24+
end
25+
data_graph(bpc::AbstractBeliefPropagationCache) = data_graph(tensornetwork(bpc))
1726

1827
function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...)
1928
sequence = contraction_sequence(contract_list; alg="optimal")
@@ -40,6 +49,9 @@ default_messages(ptn::PartitionedGraph) = Dictionary()
4049
end
4150
default_partitioned_vertices::AbstractITensorNetwork) = group(v -> v, vertices(ψ))
4251

52+
function Base.setindex!(bpc::AbstractBeliefPropagationCache, factor::ITensor, vertex)
53+
not_implemented()
54+
end
4355
partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented()
4456
messages(bpc::AbstractBeliefPropagationCache) = not_implemented()
4557
function default_message(
@@ -88,12 +100,8 @@ function tensornetwork(bpc::AbstractBeliefPropagationCache)
88100
return unpartitioned_graph(partitioned_tensornetwork(bpc))
89101
end
90102

91-
function setindex_preserve_graph!(bpc::AbstractBeliefPropagationCache, args...)
92-
return setindex_preserve_graph!(tensornetwork(bpc), args...)
93-
end
94-
95103
function factors(bpc::AbstractBeliefPropagationCache, verts::Vector)
96-
return ITensor[tensornetwork(bpc)[v] for v in verts]
104+
return ITensor[copy(bpc[v]) for v in verts]
97105
end
98106

99107
function factors(
@@ -143,7 +151,6 @@ for f in [
143151
:(PartitionedGraphs.partitionvertices),
144152
:(PartitionedGraphs.vertices),
145153
:(PartitionedGraphs.boundary_partitionedges),
146-
:(linkinds),
147154
]
148155
@eval begin
149156
function $f(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
@@ -152,23 +159,28 @@ for f in [
152159
end
153160
end
154161

162+
function linkinds(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
163+
return linkinds(partitioned_tensornetwork(bpc), pe)
164+
end
165+
155166
NDTensors.scalartype(bpc::AbstractBeliefPropagationCache) = scalartype(tensornetwork(bpc))
156167

157168
"""
158-
Update the tensornetwork inside the cache
169+
Update the tensornetwork inside the cache out-of-place
159170
"""
160171
function update_factors(bpc::AbstractBeliefPropagationCache, factors)
161172
bpc = copy(bpc)
162-
tn = tensornetwork(bpc)
163173
for vertex in eachindex(factors)
164174
# TODO: Add a check that this preserves the graph structure.
165-
setindex_preserve_graph!(tn, factors[vertex], vertex)
175+
setindex_preserve_graph!(bpc, factors[vertex], vertex)
166176
end
167177
return bpc
168178
end
169179

170180
function update_factor(bpc, vertex, factor)
171-
return update_factors(bpc, Dictionary([vertex], [factor]))
181+
bpc = copy(bpc)
182+
setindex_preserve_graph!(bpc, factor, vertex)
183+
return bpc
172184
end
173185

174186
function message(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...)
@@ -178,12 +190,45 @@ end
178190
function messages(bpc::AbstractBeliefPropagationCache, edges; kwargs...)
179191
return map(edge -> message(bpc, edge; kwargs...), edges)
180192
end
193+
function set_messages!(bpc::AbstractBeliefPropagationCache, partitionedges_messages)
194+
ms = messages(bpc)
195+
for pe in eachindex(partitionedges_messages)
196+
# TODO: Add a check that this preserves the graph structure.
197+
set!(ms, pe, partitionedges_messages[pe])
198+
end
199+
return bpc
200+
end
201+
function set_message!(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge, message)
202+
ms = messages(bpc)
203+
set!(ms, pe, message)
204+
return bpc
205+
end
206+
207+
function set_messages(bpc::AbstractBeliefPropagationCache, partitionedges_messages)
208+
bpc = copy(bpc)
209+
return set_messages!(bpc, partitionedges_messages)
210+
end
181211
function set_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge, message)
182212
bpc = copy(bpc)
213+
return set_message!(bpc, pe, message)
214+
end
215+
function delete_messages!(bpc::AbstractBeliefPropagationCache, pes::Vector{<:PartitionEdge})
183216
ms = messages(bpc)
184-
set!(ms, pe, message)
217+
for pe in pes
218+
delete!(ms, pe)
219+
end
185220
return bpc
186221
end
222+
function delete_message!(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
223+
return delete_message!(bpc, [pe])
224+
end
225+
function delete_messages(bpc::AbstractBeliefPropagationCache, pes::Vector{<:PartitionEdge})
226+
bpc = copy(bpc)
227+
return delete_messages!(bpc, pes)
228+
end
229+
function delete_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
230+
return delete_message(bpc, [pe])
231+
end
187232

188233
"""
189234
Compute message tensor as product of incoming mts and local state
@@ -241,11 +286,11 @@ function update(
241286
edge_groups::Vector{<:Vector{<:PartitionEdge}};
242287
kwargs...,
243288
)
244-
new_mts = copy(messages(bpc))
289+
new_mts = Dictionary()
245290
for edges in edge_groups
246291
bpc_t = update(alg, bpc, edges; kwargs...)
247292
for e in edges
248-
new_mts[e] = message(bpc_t, e)
293+
set!(new_mts, e, message(bpc_t, e))
249294
end
250295
end
251296
return set_messages(bpc, new_mts)
@@ -288,10 +333,6 @@ function update(
288333
return update(Algorithm(alg), bpc; kwargs...)
289334
end
290335

291-
function scale!(bp_cache::AbstractBeliefPropagationCache, args...)
292-
return scale!(tensornetwork(bp_cache), args...)
293-
end
294-
295336
function rescale_messages(
296337
bp_cache::AbstractBeliefPropagationCache, partitionedge::PartitionEdge
297338
)

src/caches/beliefpropagationcache.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using SplitApplyCombine: group
33
using LinearAlgebra: diag, dot
44
using ITensors: dir
55
using NamedGraphs.PartitionedGraphs:
6+
AbstractPartitionedGraph,
67
PartitionedGraphs,
78
PartitionedGraph,
89
PartitionVertex,
@@ -23,7 +24,8 @@ function default_cache_construction_kwargs(alg::Algorithm"bp", pg::PartitionedGr
2324
return (;)
2425
end
2526

26-
struct BeliefPropagationCache{PTN,MTS} <: AbstractBeliefPropagationCache
27+
struct BeliefPropagationCache{V,PV,PTN<:AbstractPartitionedGraph{V,PV},MTS} <:
28+
AbstractBeliefPropagationCache{V,PV}
2729
partitioned_tensornetwork::PTN
2830
messages::MTS
2931
end
@@ -81,15 +83,12 @@ function default_message_update_kwargs(
8183
return (;)
8284
end
8385

86+
Base.setindex!(bpc::BeliefPropagationCache, factor::ITensor, vertex) = not_implemented()
8487
partitions(bpc::BeliefPropagationCache) = partitionvertices(partitioned_tensornetwork(bpc))
8588
function PartitionedGraphs.partitionedges(bpc::BeliefPropagationCache)
8689
partitionedges(partitioned_tensornetwork(bpc))
8790
end
8891

89-
function set_messages(cache::BeliefPropagationCache, messages)
90-
return BeliefPropagationCache(partitioned_tensornetwork(cache), messages)
91-
end
92-
9392
function environment(bpc::BeliefPropagationCache, verts::Vector; kwargs...)
9493
partition_verts = partitionvertices(bpc, verts)
9594
messages = incoming_messages(bpc, partition_verts; kwargs...)

test/test_belief_propagation.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using ITensorNetworks:
66
ITensorNetworks,
77
BeliefPropagationCache,
88
,
9+
@preserve_graph,
910
combine_linkinds,
1011
contract,
1112
contraction_sequence,
@@ -48,6 +49,19 @@ using Test: @test, @testset
4849
ψ = random_tensornetwork(rng, elt, s; link_space=χ)
4950
ψψ = ψ prime(dag(ψ); sites=[])
5051
bpc = BeliefPropagationCache(ψψ, group(v -> first(v), vertices(ψψ)))
52+
53+
#Test updating the tensors in the cache
54+
vket, vbra = ((1, 1), 1), ((1, 1), 2)
55+
A = bpc[vket]
56+
new_A = random_itensor(elt, inds(A))
57+
new_A_dag = ITensors.replaceind(
58+
dag(prime(new_A)), only(s[first(vket)])', only(s[first(vket)])
59+
)
60+
@preserve_graph bpc[vket] = new_A
61+
@preserve_graph bpc[vbra] = new_A_dag
62+
@test bpc[vket] == new_A
63+
@test bpc[vbra] == new_A_dag
64+
5165
bpc = update(bpc; maxiter=25, tol=eps(real(elt)))
5266
#Test messages are converged
5367
for pe in partitionedges(bpc)

0 commit comments

Comments
 (0)