diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 19c4109..cca4b6d 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -9,4 +9,8 @@ include("abstract_problem.jl") include("iterators.jl") include("adapters.jl") +include("beliefpropagation/abstractbeliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationproblem.jl") + end diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index e566752..b02c789 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -9,19 +9,23 @@ using LinearAlgebra: LinearAlgebra, factorize using MacroTools: @capture using NamedDimsArrays: dimnames, inds using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree -using NamedGraphs.GraphsExtensions: ⊔, directed_graph, incident_edges, rem_edges!, - rename_vertices, vertextype +using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +using NamedGraphs.GraphsExtensions: + ⊔, + directed_graph, + incident_edges, + rem_edges!, + rename_vertices, + vertextype using SplitApplyCombine: flatten +using NamedGraphs.SimilarType: similar_type abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end -function Graphs.rem_edge!(tn::AbstractTensorNetwork, e) - rem_edge!(underlying_graph(tn), e) - return tn -end +# Need to be careful about removing edges from tensor networks in case there is a bond +Graphs.rem_edge!(::AbstractTensorNetwork, edge) = not_implemented() -# TODO: Define a generic fallback for `AbstractDataGraph`? -DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge data") +DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = not_implemented() # Graphs.jl overloads function Graphs.weights(graph::AbstractTensorNetwork) @@ -36,7 +40,7 @@ function Graphs.weights(graph::AbstractTensorNetwork) end # Copy -Base.copy(tn::AbstractTensorNetwork) = error("Not implemented") +Base.copy(::AbstractTensorNetwork) = not_implemented() # Iteration Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...) @@ -49,20 +53,11 @@ Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn)) # Overload if needed Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false -# Derived interface, may need to be overloaded -function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork}) - return underlying_graph_type(data_graph_type(G)) -end - # AbstractDataGraphs overloads -function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") -end -function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") -end +DataGraphs.vertex_data(::AbstractTensorNetwork) = not_implemented() +DataGraphs.edge_data(::AbstractTensorNetwork) = not_implemented() -DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented") +DataGraphs.underlying_graph(::AbstractTensorNetwork) = not_implemented() function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork) return NamedGraphs.vertex_positions(underlying_graph(tn)) end @@ -81,40 +76,37 @@ function Adapt.adapt_structure(to, tn::AbstractTensorNetwork) return map_vertex_data_preserve_graph(adapt(to), tn) end -function linkinds(tn::AbstractTensorNetwork, edge::Pair) - return linkinds(tn, edgetype(tn)(edge)) -end -function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge) - return inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) -end -function linkaxes(tn::AbstractTensorNetwork, edge::Pair) +linkinds(tn::AbstractGraph, edge::Pair) = linkinds(tn, edgetype(tn)(edge)) +linkinds(tn::AbstractGraph, edge::AbstractEdge) = inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) + +function linkaxes(tn::AbstractGraph, edge::Pair) return linkaxes(tn, edgetype(tn)(edge)) end -function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) +function linkaxes(tn::AbstractGraph, edge::AbstractEdge) return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) end -function linknames(tn::AbstractTensorNetwork, edge::Pair) +function linknames(tn::AbstractGraph, edge::Pair) return linknames(tn, edgetype(tn)(edge)) end -function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge) +function linknames(tn::AbstractGraph, edge::AbstractEdge) return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) end -function siteinds(tn::AbstractTensorNetwork, v) +function siteinds(tn::AbstractGraph, v) s = inds(tn[v]) for v′ in neighbors(tn, v) s = setdiff(s, inds(tn[v′])) end return s end -function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) +function siteaxes(tn::AbstractGraph, edge::AbstractEdge) s = axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) for v′ in neighbors(tn, v) s = setdiff(s, axes(tn[v′])) end return s end -function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) +function sitenames(tn::AbstractGraph, edge::AbstractEdge) s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) for v′ in neighbors(tn, v) s = setdiff(s, dimnames(tn[v′])) @@ -122,8 +114,8 @@ function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) return s end -function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex) - vertex_data(tn)[vertex] = value +function setindex_preserve_graph!(tn::AbstractGraph, value, vertex) + set!(vertex_data(tn), vertex, value) return tn end @@ -153,7 +145,7 @@ end # Update the graph of the TensorNetwork `tn` to include # edges that should exist based on the tensor connectivity. -function add_missing_edges!(tn::AbstractTensorNetwork) +function add_missing_edges!(tn::AbstractGraph) foreach(v -> add_missing_edges!(tn, v), vertices(tn)) return tn end @@ -161,7 +153,7 @@ end # Update the graph of the TensorNetwork `tn` to include # edges that should be incident to the vertex `v` # based on the tensor connectivity. -function add_missing_edges!(tn::AbstractTensorNetwork, v) +function add_missing_edges!(tn::AbstractGraph, v) for v′ in vertices(tn) if v ≠ v′ e = v => v′ @@ -175,13 +167,13 @@ end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity. -function fix_edges!(tn::AbstractTensorNetwork) +function fix_edges!(tn::AbstractGraph) foreach(v -> fix_edges!(tn, v), vertices(tn)) return tn end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity at vertex `v`. -function fix_edges!(tn::AbstractTensorNetwork, v) +function fix_edges!(tn::AbstractGraph, v) rem_edges!(tn, incident_edges(tn, v)) add_missing_edges!(tn, v) return tn @@ -215,28 +207,20 @@ function Base.setindex!(tn::AbstractTensorNetwork, value, v) fix_edges!(tn, v) return tn end -using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger # Fix ambiguity error. function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger) graph[vertices(graph)[vertex]] = value return graph end -# Fix ambiguity error. -function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) - return error("No edge data.") -end -# Fix ambiguity error. -function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) - return error("No edge data.") -end -using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) = not_implemented() +Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) = not_implemented() # Fix ambiguity error. function Base.setindex!( tn::AbstractTensorNetwork, value, edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger}, ) - return error("No edge data.") + return not_implemented() end function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) @@ -255,3 +239,21 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) end Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) + +function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices::AbstractVector{V}) where {V <: Int} + return tensornetwork_induced_subgraph(graph, subvertices) +end +function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices) + return tensornetwork_induced_subgraph(graph, subvertices) +end + +function tensornetwork_induced_subgraph(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) + subgraph = similar_type(graph)(underlying_subgraph) + for v in vertices(subgraph) + if isassigned(graph, v) + set!(vertex_data(subgraph), v, graph[v]) + end + end + return subgraph, vlist +end diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl new file mode 100644 index 0000000..8c6b3dd --- /dev/null +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -0,0 +1,133 @@ +using Graphs: AbstractGraph, AbstractEdge +using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype +using NamedGraphs.GraphsExtensions: boundary_edges +using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent + +messages(::AbstractGraph) = not_implemented() +messages(bp_cache::AbstractDataGraph) = edge_data(bp_cache) +messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] + +message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] + +deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented() +function deletemessage!(bp_cache::AbstractDataGraph, edge) + ms = messages(bp_cache) + delete!(ms, edge) + return bp_cache +end + +function deletemessages!(bp_cache::AbstractGraph, edges = edges(bp_cache)) + for e in edges + deletemessage!(bp_cache, e) + end + return bp_cache +end + +setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented() +function setmessage!(bp_cache::AbstractDataGraph, edge, message) + ms = messages(bp_cache) + set!(ms, edge, message) + return bp_cache +end +function setmessage!(bp_cache::QuotientView, edge, message) + setmessages!(parent(bp_cache), QuotientEdge(edge), message) + return bp_cache +end + +function setmessages!(bp_cache::AbstractGraph, edge::QuotientEdge, message) + for e in edges(bp_cache, edge) + setmessage!(parent(bp_cache), e, message[e]) + end + return bp_cache +end +function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges) + for e in edges + setmessage!(bpc_dst, e, message(bpc_src, e)) + end + return bpc_dst +end + +factors(bpc::AbstractGraph) = vertex_data(bpc) +factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices] +factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex]) + +factor(bpc::AbstractGraph, vertex) = factors(bpc)[vertex] + +setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented() +function setfactor!(bpc::AbstractDataGraph, vertex, factor) + fs = factors(bpc) + set!(fs, vertex, factor) + return bpc +end + +function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge) + return message(bp_cache, edge) * message(bp_cache, reverse(edge)) +end + +function region_scalar(bp_cache::AbstractGraph, vertex) + + messages = incoming_messages(bp_cache, vertex) + state = factors(bp_cache, vertex) + + return reduce(*, messages) * reduce(*, state) +end + +message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) +message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G)) +message_type(type::Type{<:AbstractDataGraph}) = edge_data_eltype(type) + +function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) + return map(v -> region_scalar(bp_cache, v), vertices) +end + +function edge_scalars(bp_cache::AbstractGraph, edges = edges(bp_cache)) + return map(e -> region_scalar(bp_cache, e), edges) +end + +function scalar_factors_quotient(bp_cache::AbstractGraph) + return vertex_scalars(bp_cache), edge_scalars(bp_cache) +end + +function incoming_messages(bp_cache::AbstractGraph, vertices; ignore_edges = []) + b_edges = boundary_edges(bp_cache, [vertices;]; dir = :in) + b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges + return messages(bp_cache, b_edges) +end + +default_messages(::AbstractGraph) = not_implemented() + +#Adapt interface for changing device +map_messages(f, bp_cache, es = edges(bp_cache)) = map_messages!(f, copy(bp_cache), es) +function map_messages!(f, bp_cache, es = edges(bp_cache)) + for e in es + setmessage!(bp_cache, e, f(message(bp_cache, e))) + end + return bp_cache +end + +map_factors(f, bp_cache, vs = vertices(bp_cache)) = map_factors!(f, copy(bp_cache), vs) +function map_factors!(f, bp_cache, vs = vertices(bp_cache)) + for v in vs + setfactor!(bp_cache, v, f(factor(bp_cache, v))) + end + return bp_cache +end + +adapt_messages(to, bp_cache, es = edges(bp_cache)) = map_messages(adapt(to), bp_cache, es) +adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp_cache, vs) + +abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end + +function free_energy(bp_cache::AbstractBeliefPropagationCache) + numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) + if any(t -> real(t) < 0, numerator_terms) + numerator_terms = complex.(numerator_terms) + end + if any(t -> real(t) < 0, denominator_terms) + denominator_terms = complex.(denominator_terms) + end + + any(iszero, denominator_terms) && return -Inf + return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) +end +partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache)) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl new file mode 100644 index 0000000..5d8fa35 --- /dev/null +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -0,0 +1,92 @@ +using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph +using Dictionaries: Dictionary, set!, delete! +using Graphs: AbstractGraph, is_tree, connected_components +using NamedGraphs: convert_vertextype +using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, is_path_graph +using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, quotient_graph + +struct BeliefPropagationCache{V, N <: AbstractGraph{V}, ET, MT} <: + AbstractBeliefPropagationCache{V, MT} + network::N + messages::Dictionary{ET, MT} +end + +network(bp_cache) = underlying_graph(bp_cache) + +DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :network) +DataGraphs.edge_data(bpc::BeliefPropagationCache) = getfield(bpc, :messages) +DataGraphs.vertex_data(bpc::BeliefPropagationCache) = vertex_data(network(bpc)) +function DataGraphs.underlying_graph_type(type::Type{<:BeliefPropagationCache}) + return fieldtype(type, :network) +end + +message_type(::Type{<:BeliefPropagationCache{V, N, ET, MT}}) where {V, N, ET, MT} = MT + +function BeliefPropagationCache(alg, network::AbstractGraph) + es = collect(edges(network)) + es = vcat(es, reverse.(es)) + messages = map(edge -> default_message(alg, network, edge), es) + return BeliefPropagationCache(network, Dictionary(es, messages)) +end + +function Base.copy(bp_cache::BeliefPropagationCache) + return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) +end + +# TODO: This needs to go in DataGraphsGraphsExtensionsExt +# +# This function is problematic when `ng isa TensorNetwork` as it relies on deleting edges +# and taking subgraphs, which is not always well-defined for the `TensorNetwork` type, +# hence we just strip off any `AbstractDataGraph` data to avoid this. +function forest_cover_edge_sequence(g::AbstractDataGraph; kwargs...) + return forest_cover_edge_sequence(underlying_graph(g); kwargs...) +end +# TODO: This needs to go in PartitionedGraphsGraphsExtensionsExt +# +# While it is not at all necessary to explictly instantiate the `QuotientView`, it allows the +# data of a data graph to be removed using the above method if `parent_type(g)` is an +# `AbstractDataGraph`. +function forest_cover_edge_sequence(g::QuotientView; kwargs...) + return forest_cover_edge_sequence(quotient_graph(parent(g)); kwargs...) +end +# TODO: This needs to go in GraphsExtensions +function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) + add_edges!(g, edges(g)) + forests = forest_cover(g) + rv = edgetype(g)[] + for forest in forests + trees = [forest[vs] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) + end + end + return rv +end + +function bpcache_induced_subgraph(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(network(graph), subvertices) + subgraph = BeliefPropagationCache(underlying_subgraph, typeof(edge_data(graph))()) + for e in edges(subgraph) + if isassigned(graph, e) + set!(edge_data(subgraph), e, graph[e]) + end + end + return subgraph, vlist +end + +function Graphs.induced_subgraph(graph::BeliefPropagationCache, subvertices) + return bpcache_induced_subgraph(graph, subvertices) +end +# For method ambiguity +function Graphs.induced_subgraph(graph::BeliefPropagationCache{V}, subvertices::AbstractVector{V}) where {V <: Int} + return bpcache_induced_subgraph(graph, subvertices) +end + +## PartitionedGraphs + +function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) + qview = QuotientView(network(bpc)) + messages = edge_data(QuotientView(bpc)) + return BeliefPropagationCache(qview, messages) +end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl new file mode 100644 index 0000000..49d0ef8 --- /dev/null +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -0,0 +1,134 @@ +using Graphs: SimpleGraph, vertices, edges, has_edge +using NamedGraphs: AbstractNamedGraph, position_graph +using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices +using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices + +abstract type AbstractBeliefPropagationProblem{Alg} <: AbstractProblem end + +mutable struct BeliefPropagationProblem{Alg, Cache} <: AbstractBeliefPropagationProblem{Alg} + const alg::Alg + const cache::Cache + diff::Union{Nothing, Float64} +end + +BeliefPropagationProblem(alg, cache) = BeliefPropagationProblem(alg, cache, nothing) + +function default_algorithm( + ::Type{<:Algorithm"bp"}, + bpc; + verbose = false, + tol = nothing, + edge_sequence = forest_cover_edge_sequence(bpc), + message_update_alg = default_algorithm(Algorithm"contract"), + maxiter = is_tree(bpc) ? 1 : nothing, + ) + return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter) +end + +function region_plan(prob::BeliefPropagationProblem{<:Algorithm"bp"}; sweep_kwargs...) + edges = prob.alg.edge_sequence + + plan = map(edges) do e + return e => (; sweep_kwargs...) + end + + return plan +end + +function compute!(iter::RegionIterator{<:BeliefPropagationProblem{<:Algorithm"bp"}}) + prob = iter.problem + + edge, _ = current_region_plan(iter) + new_message = updated_message(prob.alg.message_update_alg, prob.cache, edge) + setmessage!(prob.cache, edge, new_message) + + return iter +end + +default_message(alg, network, edge) = default_message(typeof(alg), network, edge) + +default_message(::Type{<:Algorithm}, network, edge) = not_implemented() +function default_message(::Type{<:Algorithm"bp"}, network, edge) + + #TODO: Get datatype working on tensornetworks so we can support GPU, etc... + links = linkinds(network, edge) + data = ones(Tuple(links)) + return data +end + +updated_message(alg, bpc, edge) = not_implemented() +function updated_message(alg::Algorithm"contract", bpc, edge) + vertex = src(edge) + + incoming_ms = incoming_messages( + bpc, vertex; ignore_edges = typeof(edge)[reverse(edge)] + ) + + updated_message = contract_messages(alg.contraction_alg, factors(bpc, vertex), incoming_ms) + + if alg.normalize + message_norm = LinearAlgebra.norm(updated_message) + if !iszero(message_norm) + updated_message /= message_norm + end + end + return updated_message +end + +contract_messages(alg, factors, messages) = not_implemented() +function contract_messages( + alg, + factors::Vector{<:AbstractArray}, + messages::Vector{<:AbstractArray}, + ) + return contract_network(alg, vcat(factors, messages)) +end + +function default_algorithm( + ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = Algorithm("exact") + ) + return Algorithm("contract"; normalize, contraction_alg) +end +function default_algorithm( + ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") + ) + return Algorithm("adapt_update"; adapt, alg) +end + +function update_message!( + message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge + ) + return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) +end + +function update(bpc::AbstractBeliefPropagationCache; kwargs...) + return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc) +end + +function update(alg, bpc) + compute_error = !isnothing(alg.tol) + + diff = compute_error ? 0.0 : nothing + + prob = BeliefPropagationProblem(alg, bpc, diff) + + iter = SweepIterator(prob, alg.maxiter; compute_error) + + for _ in iter + if compute_error && prob.diff <= alg.tol + break + end + end + + if alg.verbose && compute_error + if prob.diff <= alg.tol + println("BP converged to desired precision after $(iter.which_sweep) iterations.") + else + println( + "BP failed to converge to precision $(alg.tol), got $(prob.diff) after $(iter.which_sweep) iterations", + ) + end + end + + return bpc +end diff --git a/src/iterators.jl b/src/iterators.jl index 62d5b21..568cb5d 100644 --- a/src/iterators.jl +++ b/src/iterators.jl @@ -54,8 +54,8 @@ mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator end end -function RegionIterator(problem; sweep, sweep_kwargs...) - plan = region_plan(problem; sweep_kwargs...) +function RegionIterator(problem, prevplan; sweep, sweep_kwargs...) + plan = region_plan(problem, prevplan; sweep_kwargs...) return RegionIterator(problem, plan, sweep) end @@ -109,6 +109,8 @@ function compute!(iter::RegionIterator) return iter end +# Default behaviour: +region_plan(problem, ::Any; sweep_kwargs...) = region_plan(problem; sweep_kwargs...) region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...) # @@ -129,7 +131,7 @@ mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator end first_kwargs, _ = first_state - region_iter = RegionIterator(problem; sweep = 1, first_kwargs...) + region_iter = RegionIterator(problem, nothing; sweep = 1, first_kwargs...) return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1) end @@ -151,7 +153,10 @@ end function update_region_iterator!(iterator::SweepIterator; kwargs...) sweep = state(iterator) - iterator.region_iter = RegionIterator(problem(iterator); sweep, kwargs...) + + previous_plan = iterator.region_iter.region_plan + + iterator.region_iter = RegionIterator(problem(iterator), previous_plan; sweep, kwargs...) return iterator end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 582eec6..0681da5 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,10 +1,21 @@ using Combinatorics: combinations using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph using Dictionaries: AbstractDictionary, Indices, dictionary -using Graphs: AbstractSimpleGraph +using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! using NamedDimsArrays: AbstractNamedDimsArray, dimnames using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype -using NamedGraphs.GraphsExtensions: add_edges!, arrange_edge, arranged_edges, vertextype +using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, arrange_edge, vertextype +using NamedGraphs.PartitionedGraphs: + AbstractPartitionedGraph, + PartitionedGraphs, + departition, + partitioned_vertices, + partitionedgraph, + quotient_graph, + quotient_graph_type +using .LazyNamedDimsArrays: lazy, Mul +using DataGraphs: vertex_data_eltype, vertex_data, edge_data +using DataGraphs.DataGraphsPartitionedGraphsExt function _TensorNetwork end @@ -24,8 +35,14 @@ function _TensorNetwork(graph::AbstractGraph, tensors) return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) end +function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: AbstractGraph{V}, Tensors} + return _TensorNetwork(graph, Tensors()) +end + DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph) DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors) +DataGraphs.edge_data(tn::TensorNetwork) = Dictionary{edgetype(tn), Nothing}() +DataGraphs.vertex_data_eltype(T::Type{<:TensorNetwork}) = eltype(fieldtype(T, :tensors)) function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) return fieldtype(type, :underlying_graph) end @@ -49,8 +66,7 @@ end tensornetwork_edges(tensors) = tensornetwork_edges(NamedEdge, tensors) function TensorNetwork(f::Base.Callable, graph::AbstractGraph) - tensors = Dictionary(vertices(graph), f.(vertices(graph))) - return TensorNetwork(graph, tensors) + return TensorNetwork(graph, Dictionary(map(f, vertices(graph)))) end function TensorNetwork(graph::AbstractGraph, tensors) tn = _TensorNetwork(graph, tensors) @@ -93,3 +109,56 @@ end NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn) + +Graphs.connected_components(tn::TensorNetwork) = Graphs.connected_components(underlying_graph(tn)) + +function Graphs.rem_edge!(tn::TensorNetwork, e) + if !has_edge(underlying_graph(tn), e) + return false + end + if !isempty(linkinds(tn, e)) + throw(ArgumentError("cannot remove edge $e due to tensor indices existing on this edge.")) + end + rem_edge!(underlying_graph(tn), e) + return true +end + +function GraphsExtensions.graph_from_vertices(type::Type{<:TensorNetwork}, vertices) + DT = fieldtype(type, :tensors) + empty_dict = DT() + return TensorNetwork(similar_graph(underlying_graph_type(type), vertices), empty_dict) +end + +## PartitionedGraphs +function PartitionedGraphs.quotient_graph(tn::TensorNetwork) + ug = quotient_graph(underlying_graph(tn)) + return TensorNetwork(ug, vertex_data(QuotientView(tn))) +end +function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork}) + UG = quotient_graph_type(underlying_graph_type(type)) + VD = Vector{vertex_data_eltype(type)} + V = vertextype(UG) + return TensorNetwork{V, VD, UG, Dictionary{V, VD}} +end + +function PartitionedGraphs.partitionedgraph(tn::TensorNetwork, parts) + pg = partitionedgraph(underlying_graph(tn), parts) + return TensorNetwork(pg, vertex_data(tn)) +end + +PartitionedGraphs.departition(tn::TensorNetwork) = tn +function PartitionedGraphs.departition( + tn::TensorNetwork{<:Any, <:Any, <:AbstractPartitionedGraph} + ) + return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) +end + +function DataGraphsPartitionedGraphsExt.to_quotient_vertex_data(::TensorNetwork, data) + return mapreduce(lazy, *, collect(last(data))) +end + +function PartitionedGraphs.quotientview(tn::TensorNetwork) + qview = QuotientView(underlying_graph(tn)) + tensors = vertex_data(QuotientView(tn)) + return TensorNetwork(qview, tensors) +end diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl new file mode 100644 index 0000000..a39e1a6 --- /dev/null +++ b/test/test_beliefpropagation.jl @@ -0,0 +1,53 @@ +using Dictionaries: Dictionary +using ITensorBase: Index +using ITensorNetworksNext: + BeliefPropagationCache, + ITensorNetworksNext, + TensorNetwork, + adapt_messages, + default_message, + default_messages, + edge_scalars, + factors, + messages, + partitionfunction, + setmessages! +using Graphs: edges, vertices +using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree +using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges +using Test: @test, @testset + +@testset "BeliefPropagation" begin + + #Chain of tensors + dims = (4, 1) + g = named_grid(dims) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.update(bpc; maxiter = 1) + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test abs(z_bp - z_exact) <= 1.0e-14 + + #Tree of tensors + dims = (4, 3) + g = named_comb_tree(dims) + l = Dict(e => Index(3) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.update(bpc; maxiter = 10) + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test abs(z_bp - z_exact) <= 1.0e-14 +end