diff --git a/Project.toml b/Project.toml index 01a7775d..fdb515bc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworks" uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7" authors = ["Matthew Fishman , Joseph Tindall and contributors"] -version = "0.13.7" +version = "0.13.8" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/caches/abstractbeliefpropagationcache.jl b/src/caches/abstractbeliefpropagationcache.jl index cbdba7cd..640bdadc 100644 --- a/src/caches/abstractbeliefpropagationcache.jl +++ b/src/caches/abstractbeliefpropagationcache.jl @@ -11,9 +11,18 @@ using NamedGraphs.PartitionedGraphs: partitionedges, unpartitioned_graph using SimpleTraits: SimpleTraits, Not, @traitfn +using NamedGraphs.SimilarType: SimilarType using NDTensors: NDTensors -abstract type AbstractBeliefPropagationCache end +abstract type AbstractBeliefPropagationCache{V,PV} <: AbstractITensorNetwork{V} end + +function SimilarType.similar_type(bpc::AbstractBeliefPropagationCache) + return typeof(tensornetwork(bpc)) +end +function data_graph_type(bpc::AbstractBeliefPropagationCache) + return data_graph_type(tensornetwork(bpc)) +end +data_graph(bpc::AbstractBeliefPropagationCache) = data_graph(tensornetwork(bpc)) function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...) sequence = contraction_sequence(contract_list; alg="optimal") @@ -40,6 +49,9 @@ default_messages(ptn::PartitionedGraph) = Dictionary() end default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) +function Base.setindex!(bpc::AbstractBeliefPropagationCache, factor::ITensor, vertex) + return not_implemented() +end partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented() messages(bpc::AbstractBeliefPropagationCache) = not_implemented() function default_message( @@ -88,12 +100,8 @@ function tensornetwork(bpc::AbstractBeliefPropagationCache) return unpartitioned_graph(partitioned_tensornetwork(bpc)) end -function setindex_preserve_graph!(bpc::AbstractBeliefPropagationCache, args...) - return setindex_preserve_graph!(tensornetwork(bpc), args...) -end - function factors(bpc::AbstractBeliefPropagationCache, verts::Vector) - return ITensor[tensornetwork(bpc)[v] for v in verts] + return ITensor[bpc[v] for v in verts] end function factors( @@ -143,7 +151,6 @@ for f in [ :(PartitionedGraphs.partitionvertices), :(PartitionedGraphs.vertices), :(PartitionedGraphs.boundary_partitionedges), - :(linkinds), ] @eval begin function $f(bpc::AbstractBeliefPropagationCache, args...; kwargs...) @@ -152,23 +159,28 @@ for f in [ end end +function linkinds(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge) + return linkinds(partitioned_tensornetwork(bpc), pe) +end + NDTensors.scalartype(bpc::AbstractBeliefPropagationCache) = scalartype(tensornetwork(bpc)) """ -Update the tensornetwork inside the cache +Update the tensornetwork inside the cache out-of-place """ function update_factors(bpc::AbstractBeliefPropagationCache, factors) bpc = copy(bpc) - tn = tensornetwork(bpc) for vertex in eachindex(factors) # TODO: Add a check that this preserves the graph structure. - setindex_preserve_graph!(tn, factors[vertex], vertex) + setindex_preserve_graph!(bpc, factors[vertex], vertex) end return bpc end function update_factor(bpc, vertex, factor) - return update_factors(bpc, Dictionary([vertex], [factor])) + bpc = copy(bpc) + setindex_preserve_graph!(bpc, factor, vertex) + return bpc end function message(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...) @@ -178,12 +190,45 @@ end function messages(bpc::AbstractBeliefPropagationCache, edges; kwargs...) return map(edge -> message(bpc, edge; kwargs...), edges) end +function set_messages!(bpc::AbstractBeliefPropagationCache, partitionedges_messages) + ms = messages(bpc) + for pe in eachindex(partitionedges_messages) + # TODO: Add a check that this preserves the graph structure. + set!(ms, pe, partitionedges_messages[pe]) + end + return bpc +end +function set_message!(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge, message) + ms = messages(bpc) + set!(ms, pe, message) + return bpc +end + +function set_messages(bpc::AbstractBeliefPropagationCache, partitionedges_messages) + bpc = copy(bpc) + return set_messages!(bpc, partitionedges_messages) +end function set_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge, message) bpc = copy(bpc) + return set_message!(bpc, pe, message) +end +function delete_messages!(bpc::AbstractBeliefPropagationCache, pes::Vector{<:PartitionEdge}) ms = messages(bpc) - set!(ms, pe, message) + for pe in pes + delete!(ms, pe) + end return bpc end +function delete_message!(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge) + return delete_message!(bpc, [pe]) +end +function delete_messages(bpc::AbstractBeliefPropagationCache, pes::Vector{<:PartitionEdge}) + bpc = copy(bpc) + return delete_messages!(bpc, pes) +end +function delete_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge) + return delete_message(bpc, [pe]) +end """ Compute message tensor as product of incoming mts and local state @@ -241,11 +286,11 @@ function update( edge_groups::Vector{<:Vector{<:PartitionEdge}}; kwargs..., ) - new_mts = copy(messages(bpc)) + new_mts = empty(messages(bpc)) for edges in edge_groups bpc_t = update(alg, bpc, edges; kwargs...) for e in edges - new_mts[e] = message(bpc_t, e) + set!(new_mts, e, message(bpc_t, e)) end end return set_messages(bpc, new_mts) @@ -288,10 +333,6 @@ function update( return update(Algorithm(alg), bpc; kwargs...) end -function scale!(bp_cache::AbstractBeliefPropagationCache, args...) - return scale!(tensornetwork(bp_cache), args...) -end - function rescale_messages( bp_cache::AbstractBeliefPropagationCache, partitionedge::PartitionEdge ) diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index 92fd7f2c..8e4ace5e 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -3,6 +3,7 @@ using SplitApplyCombine: group using LinearAlgebra: diag, dot using ITensors: dir using NamedGraphs.PartitionedGraphs: + AbstractPartitionedGraph, PartitionedGraphs, PartitionedGraph, PartitionVertex, @@ -23,7 +24,8 @@ function default_cache_construction_kwargs(alg::Algorithm"bp", pg::PartitionedGr return (;) end -struct BeliefPropagationCache{PTN,MTS} <: AbstractBeliefPropagationCache +struct BeliefPropagationCache{V,PV,PTN<:AbstractPartitionedGraph{V,PV},MTS} <: + AbstractBeliefPropagationCache{V,PV} partitioned_tensornetwork::PTN messages::MTS end @@ -81,15 +83,12 @@ function default_message_update_kwargs( return (;) end +Base.setindex!(bpc::BeliefPropagationCache, factor::ITensor, vertex) = not_implemented() partitions(bpc::BeliefPropagationCache) = partitionvertices(partitioned_tensornetwork(bpc)) function PartitionedGraphs.partitionedges(bpc::BeliefPropagationCache) partitionedges(partitioned_tensornetwork(bpc)) end -function set_messages(cache::BeliefPropagationCache, messages) - return BeliefPropagationCache(partitioned_tensornetwork(cache), messages) -end - function environment(bpc::BeliefPropagationCache, verts::Vector; kwargs...) partition_verts = partitionvertices(bpc, verts) messages = incoming_messages(bpc, partition_verts; kwargs...) diff --git a/test/test_belief_propagation.jl b/test/test_belief_propagation.jl index 5877dd45..bfe67da5 100644 --- a/test/test_belief_propagation.jl +++ b/test/test_belief_propagation.jl @@ -6,6 +6,7 @@ using ITensorNetworks: ITensorNetworks, BeliefPropagationCache, ⊗, + @preserve_graph, combine_linkinds, contract, contraction_sequence, @@ -48,6 +49,19 @@ using Test: @test, @testset ψ = random_tensornetwork(rng, elt, s; link_space=χ) ψψ = ψ ⊗ prime(dag(ψ); sites=[]) bpc = BeliefPropagationCache(ψψ, group(v -> first(v), vertices(ψψ))) + + #Test updating the tensors in the cache + vket, vbra = ((1, 1), 1), ((1, 1), 2) + A = bpc[vket] + new_A = random_itensor(elt, inds(A)) + new_A_dag = ITensors.replaceind( + dag(prime(new_A)), only(s[first(vket)])', only(s[first(vket)]) + ) + @preserve_graph bpc[vket] = new_A + @preserve_graph bpc[vbra] = new_A_dag + @test bpc[vket] == new_A + @test bpc[vbra] == new_A_dag + bpc = update(bpc; maxiter=25, tol=eps(real(elt))) #Test messages are converged for pe in partitionedges(bpc)