Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorNetworks"
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
version = "0.13.7"
version = "0.13.8"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
77 changes: 59 additions & 18 deletions src/caches/abstractbeliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,18 @@ using NamedGraphs.PartitionedGraphs:
partitionedges,
unpartitioned_graph
using SimpleTraits: SimpleTraits, Not, @traitfn
using NamedGraphs.SimilarType: SimilarType
using NDTensors: NDTensors

abstract type AbstractBeliefPropagationCache end
abstract type AbstractBeliefPropagationCache{V,PV} <: AbstractITensorNetwork{V} end

function SimilarType.similar_type(bpc::AbstractBeliefPropagationCache)
return typeof(tensornetwork(bpc))
end
function data_graph_type(bpc::AbstractBeliefPropagationCache)
return data_graph_type(tensornetwork(bpc))
end
data_graph(bpc::AbstractBeliefPropagationCache) = data_graph(tensornetwork(bpc))

function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...)
sequence = contraction_sequence(contract_list; alg="optimal")
Expand All @@ -40,6 +49,9 @@ default_messages(ptn::PartitionedGraph) = Dictionary()
end
default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ))

function Base.setindex!(bpc::AbstractBeliefPropagationCache, factor::ITensor, vertex)
return not_implemented()
end
partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented()
messages(bpc::AbstractBeliefPropagationCache) = not_implemented()
function default_message(
Expand Down Expand Up @@ -88,12 +100,8 @@ function tensornetwork(bpc::AbstractBeliefPropagationCache)
return unpartitioned_graph(partitioned_tensornetwork(bpc))
end

function setindex_preserve_graph!(bpc::AbstractBeliefPropagationCache, args...)
return setindex_preserve_graph!(tensornetwork(bpc), args...)
end

function factors(bpc::AbstractBeliefPropagationCache, verts::Vector)
return ITensor[tensornetwork(bpc)[v] for v in verts]
return ITensor[bpc[v] for v in verts]
end

function factors(
Expand Down Expand Up @@ -143,7 +151,6 @@ for f in [
:(PartitionedGraphs.partitionvertices),
:(PartitionedGraphs.vertices),
:(PartitionedGraphs.boundary_partitionedges),
:(linkinds),
]
@eval begin
function $f(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
Expand All @@ -152,23 +159,28 @@ for f in [
end
end

function linkinds(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
return linkinds(partitioned_tensornetwork(bpc), pe)
end

NDTensors.scalartype(bpc::AbstractBeliefPropagationCache) = scalartype(tensornetwork(bpc))

"""
Update the tensornetwork inside the cache
Update the tensornetwork inside the cache out-of-place
"""
function update_factors(bpc::AbstractBeliefPropagationCache, factors)
bpc = copy(bpc)
tn = tensornetwork(bpc)
for vertex in eachindex(factors)
# TODO: Add a check that this preserves the graph structure.
setindex_preserve_graph!(tn, factors[vertex], vertex)
setindex_preserve_graph!(bpc, factors[vertex], vertex)
end
return bpc
end

function update_factor(bpc, vertex, factor)
return update_factors(bpc, Dictionary([vertex], [factor]))
bpc = copy(bpc)
setindex_preserve_graph!(bpc, factor, vertex)
return bpc
end

function message(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...)
Expand All @@ -178,12 +190,45 @@ end
function messages(bpc::AbstractBeliefPropagationCache, edges; kwargs...)
return map(edge -> message(bpc, edge; kwargs...), edges)
end
function set_messages!(bpc::AbstractBeliefPropagationCache, partitionedges_messages)
ms = messages(bpc)
for pe in eachindex(partitionedges_messages)
# TODO: Add a check that this preserves the graph structure.
set!(ms, pe, partitionedges_messages[pe])
end
return bpc
end
function set_message!(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge, message)
ms = messages(bpc)
set!(ms, pe, message)
return bpc
end

function set_messages(bpc::AbstractBeliefPropagationCache, partitionedges_messages)
bpc = copy(bpc)
return set_messages!(bpc, partitionedges_messages)
end
function set_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge, message)
bpc = copy(bpc)
return set_message!(bpc, pe, message)
end
function delete_messages!(bpc::AbstractBeliefPropagationCache, pes::Vector{<:PartitionEdge})
ms = messages(bpc)
set!(ms, pe, message)
for pe in pes
delete!(ms, pe)
end
return bpc
end
function delete_message!(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
return delete_message!(bpc, [pe])
end
function delete_messages(bpc::AbstractBeliefPropagationCache, pes::Vector{<:PartitionEdge})
bpc = copy(bpc)
return delete_messages!(bpc, pes)
end
function delete_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge)
return delete_message(bpc, [pe])
end

"""
Compute message tensor as product of incoming mts and local state
Expand Down Expand Up @@ -241,11 +286,11 @@ function update(
edge_groups::Vector{<:Vector{<:PartitionEdge}};
kwargs...,
)
new_mts = copy(messages(bpc))
new_mts = empty(messages(bpc))
for edges in edge_groups
bpc_t = update(alg, bpc, edges; kwargs...)
for e in edges
new_mts[e] = message(bpc_t, e)
set!(new_mts, e, message(bpc_t, e))
end
end
return set_messages(bpc, new_mts)
Expand Down Expand Up @@ -288,10 +333,6 @@ function update(
return update(Algorithm(alg), bpc; kwargs...)
end

function scale!(bp_cache::AbstractBeliefPropagationCache, args...)
return scale!(tensornetwork(bp_cache), args...)
end

function rescale_messages(
bp_cache::AbstractBeliefPropagationCache, partitionedge::PartitionEdge
)
Expand Down
9 changes: 4 additions & 5 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using SplitApplyCombine: group
using LinearAlgebra: diag, dot
using ITensors: dir
using NamedGraphs.PartitionedGraphs:
AbstractPartitionedGraph,
PartitionedGraphs,
PartitionedGraph,
PartitionVertex,
Expand All @@ -23,7 +24,8 @@ function default_cache_construction_kwargs(alg::Algorithm"bp", pg::PartitionedGr
return (;)
end

struct BeliefPropagationCache{PTN,MTS} <: AbstractBeliefPropagationCache
struct BeliefPropagationCache{V,PV,PTN<:AbstractPartitionedGraph{V,PV},MTS} <:
AbstractBeliefPropagationCache{V,PV}
partitioned_tensornetwork::PTN
messages::MTS
end
Expand Down Expand Up @@ -81,15 +83,12 @@ function default_message_update_kwargs(
return (;)
end

Base.setindex!(bpc::BeliefPropagationCache, factor::ITensor, vertex) = not_implemented()
partitions(bpc::BeliefPropagationCache) = partitionvertices(partitioned_tensornetwork(bpc))
function PartitionedGraphs.partitionedges(bpc::BeliefPropagationCache)
partitionedges(partitioned_tensornetwork(bpc))
end

function set_messages(cache::BeliefPropagationCache, messages)
return BeliefPropagationCache(partitioned_tensornetwork(cache), messages)
end

function environment(bpc::BeliefPropagationCache, verts::Vector; kwargs...)
partition_verts = partitionvertices(bpc, verts)
messages = incoming_messages(bpc, partition_verts; kwargs...)
Expand Down
14 changes: 14 additions & 0 deletions test/test_belief_propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ITensorNetworks:
ITensorNetworks,
BeliefPropagationCache,
⊗,
@preserve_graph,
combine_linkinds,
contract,
contraction_sequence,
Expand Down Expand Up @@ -48,6 +49,19 @@ using Test: @test, @testset
ψ = random_tensornetwork(rng, elt, s; link_space=χ)
ψψ = ψ ⊗ prime(dag(ψ); sites=[])
bpc = BeliefPropagationCache(ψψ, group(v -> first(v), vertices(ψψ)))

#Test updating the tensors in the cache
vket, vbra = ((1, 1), 1), ((1, 1), 2)
A = bpc[vket]
new_A = random_itensor(elt, inds(A))
new_A_dag = ITensors.replaceind(
dag(prime(new_A)), only(s[first(vket)])', only(s[first(vket)])
)
@preserve_graph bpc[vket] = new_A
@preserve_graph bpc[vbra] = new_A_dag
@test bpc[vket] == new_A
@test bpc[vbra] == new_A_dag

bpc = update(bpc; maxiter=25, tol=eps(real(elt)))
#Test messages are converged
for pe in partitionedges(bpc)
Expand Down
Loading