From ed58440aa34a118d608122ba4930dd56f05ff7ed Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 2 Oct 2025 13:25:12 -0400 Subject: [PATCH 01/15] Working BP Commit --- src/ITensorNetworksNext.jl | 3 +++ src/abstracttensornetwork.jl | 2 +- test/test_beliefpropagation.jl | 25 +++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 test/test_beliefpropagation.jl diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 19c4109..905d783 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -9,4 +9,7 @@ include("abstract_problem.jl") include("iterators.jl") include("adapters.jl") +include("beliefpropagation/abstractbeliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationcache.jl") + end diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index e566752..1ecbffa 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -254,4 +254,4 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) return nothing end -Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) +Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) \ No newline at end of file diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl new file mode 100644 index 0000000..4b179fb --- /dev/null +++ b/test/test_beliefpropagation.jl @@ -0,0 +1,25 @@ +using Dictionaries: Dictionary +using ITensorBase: Index +using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy, + partitionfunction +using Graphs: edges, vertices +using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges +using Test: @test, @testset + +@testset "BeliefPropagation" begin + dims = (4, 1) + g = named_grid(dims) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.update(bpc; maxiter = 10) + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test abs(z_bp - z_exact) <= 1e-14 +end \ No newline at end of file From fe027a190ae63b146740ed81d6b225a74af45d74 Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 23 Oct 2025 18:23:27 -0400 Subject: [PATCH 02/15] BP Code --- .../abstractbeliefpropagationcache.jl | 151 +++++++++++ .../beliefpropagationcache.jl | 237 ++++++++++++++++++ test/test_beliefpropagation.jl | 20 +- 3 files changed, 407 insertions(+), 1 deletion(-) create mode 100644 src/beliefpropagation/abstractbeliefpropagationcache.jl create mode 100644 src/beliefpropagation/beliefpropagationcache.jl diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl new file mode 100644 index 0000000..5eae283 --- /dev/null +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -0,0 +1,151 @@ +abstract type AbstractBeliefPropagationCache{V} <: AbstractGraph{V} end + +#Interface +factor(bp_cache::AbstractBeliefPropagationCache, vertex) = not_implemented() +setfactor!(bp_cache::AbstractBeliefPropagationCache, vertex, factor) = not_implemented() +messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() 
+message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) = not_implemented() +function default_message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) + return not_implemented() +end +default_messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +function setmessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge, message) + return not_implemented() +end +function deletemessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) + return not_implemented() +end +function rescale_messages( + bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}; kwargs... + ) + return not_implemented() +end +function rescale_vertices( + bp_cache::AbstractBeliefPropagationCache, vertices::Vector; kwargs... + ) + return not_implemented() +end + +function vertex_scalar(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) + return not_implemented() +end +function edge_scalar( + bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs... + ) + return not_implemented() +end + +#Graph functionality needed +Graphs.vertices(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +Graphs.edges(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +function NamedGraphs.GraphsExtensions.boundary_edges( + bp_cache::AbstractBeliefPropagationCache, vertices; kwargs... + ) + return not_implemented() +end + +#Functions derived from the interface +function setmessages!(bp_cache::AbstractBeliefPropagationCache, edges, messages) + for (e, m) in zip(edges) + setmessage!(bp_cache, e, m) + end + return +end + +function deletemessages!( + bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge} = edges(bp_cache) + ) + for e in edges + deletemessage!(bp_cache, e) + end + return bp_cache +end + +function vertex_scalars( + bp_cache::AbstractBeliefPropagationCache, vertices = Graphs.vertices(bp_cache); kwargs... + ) + return map(v -> region_scalar(bp_cache, v; kwargs...), vertices) +end + +function edge_scalars( + bp_cache::AbstractBeliefPropagationCache, edges = Graphs.edges(bp_cache); kwargs... + ) + return map(e -> region_scalar(bp_cache, e; kwargs...), edges) +end + +function scalar_factors_quotient(bp_cache::AbstractBeliefPropagationCache) + return vertex_scalars(bp_cache), edge_scalars(bp_cache) +end + +function incoming_messages( + bp_cache::AbstractBeliefPropagationCache, vertices::Vector{<:Any}; ignore_edges = [] + ) + b_edges = NamedGraphs.GraphsExtensions.boundary_edges(bp_cache, vertices; dir = :in) + b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges + return messages(bp_cache, b_edges) +end + +function incoming_messages(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) + return incoming_messages(bp_cache, [vertex]; kwargs...) +end + +#Adapt interface for changing device +function map_messages(f, bp_cache::AbstractBeliefPropagationCache, es = edges(bp_cache)) + bp_cache = copy(bp_cache) + for e in es + setmessage!(bp_cache, e, f(message(bp_cache, e))) + end + return bp_cache +end +function map_factors(f, bp_cache::AbstractBeliefPropagationCache, vs = vertices(bp_cache)) + bp_cache = copy(bp_cache) + for v in vs + setfactor!(bp_cache, v, f(factor(bp_cache, v))) + end + return bp_cache +end +function adapt_messages(to, bp_cache::AbstractBeliefPropagationCache, args...) + return map_messages(adapt(to), bp_cache, args...) +end +function adapt_factors(to, bp_cache::AbstractBeliefPropagationCache, args...) + return map_factors(adapt(to), bp_cache, args...) 
+end + +function freenergy(bp_cache::AbstractBeliefPropagationCache) + numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) + if any(t -> real(t) < 0, numerator_terms) + numerator_terms = complex.(numerator_terms) + end + if any(t -> real(t) < 0, denominator_terms) + denominator_terms = complex.(denominator_terms) + end + + any(iszero, denominator_terms) && return -Inf + return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) +end + +function partitionfunction(bp_cache::AbstractBeliefPropagationCache) + return exp(freenergy(bp_cache)) +end + +function rescale_messages(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) + return rescale_messages(bp_cache, [edge]) +end + +function rescale_messages(bp_cache::AbstractBeliefPropagationCache) + return rescale_messages(bp_cache, edges(bp_cache)) +end + +function rescale_vertices(bpc::AbstractBeliefPropagationCache; kwargs...) + return rescale_vertices(bpc, collect(vertices(bpc)); kwargs...) +end + +function rescale_vertex(bpc::AbstractBeliefPropagationCache, vertex; kwargs...) + return rescale_vertices(bpc, [vertex]; kwargs...) +end + +function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...) + bpc = rescale_messages(bpc) + bpc = rescale_partitions(bpc, args...; kwargs...) + return bpc +end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl new file mode 100644 index 0000000..295502a --- /dev/null +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -0,0 +1,237 @@ +using DiagonalArrays: delta +using Dictionaries: Dictionary, set!, delete! +using Graphs: AbstractGraph, is_tree, connected_components +using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges +using ITensorBase: ITensor, dim +using TypeParameterAccessors: unwrap_array_type, unwrap_array, parenttype + +struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: + AbstractBeliefPropagationCache{V} + network::N + messages::Dictionary +end + +messages(bp_cache::BeliefPropagationCache) = bp_cache.messages +network(bp_cache::BeliefPropagationCache) = bp_cache.network +default_messages() = Dictionary() + +BeliefPropagationCache(network) = BeliefPropagationCache(network, default_messages()) + +function Base.copy(bp_cache::BeliefPropagationCache) + return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) +end + +function deletemessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge) + ms = messages(bp_cache) + delete!(ms, e) + return bp_cache +end + +function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) + ms = messages(bp_cache) + set!(ms, e, message) + return bp_cache +end + +function message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...) + ms = messages(bp_cache) + return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) +end + +function messages(bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}) + return [message(bp_cache, e) for e in edges] +end + +default_bp_maxiter(g::AbstractGraph) = is_tree(g) ? 1 : nothing +#Forward onto the network +for f in [ + :(Graphs.vertices), + :(Graphs.edges), + :(Graphs.is_tree), + :(NamedGraphs.GraphsExtensions.boundary_edges), + :(factors), + :(default_bp_maxiter), + :(ITensorNetworksNext.setfactor!), + :(ITensorNetworksNext.linkinds), + :(ITensorNetworksNext.underlying_graph), + ] + @eval begin + function $f(bp_cache::BeliefPropagationCache, args...; kwargs...) 
+ return $f(network(bp_cache), args...; kwargs...) + end + end +end + +#TODO: Get subgraph working on an ITensorNetwork to overload this directly +function default_bp_edge_sequence(bp_cache::BeliefPropagationCache) + return forest_cover_edge_sequence(underlying_graph(bp_cache)) +end + +function factors(tn::AbstractTensorNetwork, vertex) + return [tn[vertex]] +end + +function region_scalar(bp_cache::BeliefPropagationCache, edge::AbstractEdge) + return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] +end + +function region_scalar(bp_cache::BeliefPropagationCache, vertex) + incoming_ms = incoming_messages(bp_cache, vertex) + state = factors(bp_cache, vertex) + return (reduce(*, incoming_ms) * reduce(*, state))[] +end + +function default_message(bp_cache::BeliefPropagationCache, edge::AbstractEdge) + return default_message(network(bp_cache), edge::AbstractEdge) +end + +function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) + t = ITensor(ones(dim.(linkinds(tn, edge))...), linkinds(tn, edge)...) + #TODO: Get datatype working on tensornetworks so we can support GPU, etc... + return t +end + +#Algorithmic defaults +default_update_alg(bp_cache::BeliefPropagationCache) = "bp" +default_message_update_alg(bp_cache::BeliefPropagationCache) = "contract" +default_normalize(::Algorithm"contract") = true +default_sequence_alg(::Algorithm"contract") = "optimal" +function set_default_kwargs(alg::Algorithm"contract") + normalize = get(alg, :normalize, default_normalize(alg)) + sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg)) + return Algorithm("contract"; normalize, sequence_alg) +end +function set_default_kwargs(alg::Algorithm"adapt_update") + _alg = set_default_kwargs(get(alg, :alg, Algorithm("contract"))) + return Algorithm("adapt_update"; adapt = alg.adapt, alg = _alg) +end +default_verbose(::Algorithm"bp") = false +default_tol(::Algorithm"bp") = nothing +function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache) + verbose = get(alg, :verbose, default_verbose(alg)) + maxiter = get(alg, :maxiter, default_bp_maxiter(bp_cache)) + edge_sequence = get(alg, :edge_sequence, default_bp_edge_sequence(bp_cache)) + tol = get(alg, :tol, default_tol(alg)) + message_update_alg = set_default_kwargs( + get(alg, :message_update_alg, Algorithm(default_message_update_alg(bp_cache))) + ) + return Algorithm("bp"; verbose, maxiter, edge_sequence, tol, message_update_alg) +end + +#TODO: Update message etc should go here... +function updated_message( + alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge + ) + vertex = src(edge) + incoming_ms = incoming_messages( + bp_cache, vertex; ignore_edges = typeof(edge)[reverse(edge)] + ) + state = factors(bp_cache, vertex) + #contract_list = ITensor[incoming_ms; state] + #sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg) + #updated_messages = contract(contract_list; sequence) + updated_message = + !isempty(incoming_ms) ? 
reduce(*, state) * reduce(*, incoming_ms) : reduce(*, state) + if alg.normalize + message_norm = LinearAlgebra.norm(updated_message) + if !iszero(message_norm) + updated_message /= message_norm + end + end + return updated_message +end + +function updated_message( + bp_cache::BeliefPropagationCache, + edge::AbstractEdge; + alg = default_message_update_alg(bpc), + kwargs..., + ) + return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bp_cache, edge) +end + +function update_message!( + message_update_alg::Algorithm, bp_cache::BeliefPropagationCache, edge::AbstractEdge + ) + return setmessage!(bp_cache, edge, updated_message(message_update_alg, bp_cache, edge)) +end + +""" +Do a sequential update of the message tensors on `edges` +""" +function update_iteration( + alg::Algorithm"bp", + bpc::AbstractBeliefPropagationCache, + edges::Vector; + (update_diff!) = nothing, + ) + bpc = copy(bpc) + for e in edges + prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing + update_message!(alg.message_update_alg, bpc, e) + if !isnothing(update_diff!) + update_diff![] += message_diff(message(bpc, e), prev_message) + end + end + return bpc +end + +""" +Do parallel updates between groups of edges of all message tensors +Currently we send the full message tensor data struct to update for each edge_group. But really we only need the +mts relevant to that group. +""" +function update_iteration( + alg::Algorithm"bp", + bpc::AbstractBeliefPropagationCache, + edge_groups::Vector{<:Vector{<:AbstractEdge}}; + (update_diff!) = nothing, + ) + new_mts = empty(messages(bpc)) + for edges in edge_groups + bpc_t = update_iteration(alg.kwargs.message_update_alg, bpc, edges; (update_diff!)) + for e in edges + set!(new_mts, e, message(bpc_t, e)) + end + end + return set_messages(bpc, new_mts) +end + +""" +More generic interface for update, with default params +""" +function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache) + compute_error = !isnothing(alg.tol) + if isnothing(alg.maxiter) + error("You need to specify a number of iterations for BP!") + end + for i in 1:alg.maxiter + diff = compute_error ? Ref(0.0) : nothing + bpc = update_iteration(alg, bpc, alg.edge_sequence; (update_diff!) = diff) + if compute_error && (diff.x / length(alg.edge_sequence)) <= alg.tol + if alg.verbose + println("BP converged to desired precision after $i iterations.") + end + break + end + end + return bpc +end + +function update(bpc::AbstractBeliefPropagationCache; alg = default_update_alg(bpc), kwargs...) + return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc) +end + +#Edge sequence stuff +function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) + forests = forest_cover(g) + edges = edgetype(g)[] + for forest in forests + trees = [forest[vs] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...) 
+ end + end + return edges +end \ No newline at end of file diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 4b179fb..81ee722 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -3,11 +3,13 @@ using ITensorBase: Index using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy, partitionfunction using Graphs: edges, vertices -using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using Test: @test, @testset @testset "BeliefPropagation" begin + + #Chain of tensors dims = (4, 1) g = named_grid(dims) l = Dict(e => Index(2) for e in edges(g)) @@ -17,6 +19,22 @@ using Test: @test, @testset return randn(Tuple(is)) end + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.update(bpc; maxiter = 1) + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test abs(z_bp - z_exact) <= 1e-14 + + #Tree of tensors + dims = (4, 3) + g = named_comb_tree(dims) + l = Dict(e => Index(3) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.update(bpc; maxiter = 10) z_bp = partitionfunction(bpc) From 993fc7a8b96860a7818b6057cb1227f84fab6f4a Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 28 Oct 2025 15:18:28 -0400 Subject: [PATCH 03/15] Express BP in terms of `SweepIterator` interface Introduce `BeliefPropagationProblem` wrapper to hold the cache and the error `diff` field. Also simplifies some kwargs wrangling. 
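
A rough sketch of the resulting control flow (illustrative only; `bpc`, `maxiter`, and `tol` stand in for the values normally taken from the `Algorithm"bp"` object, and keyword plumbing is omitted):

    prob = BeliefPropagationProblem(bpc, 0.0)   # cache plus a running message-diff accumulator
    iter = SweepIterator(prob, maxiter)         # each sweep builds a RegionIterator over edge groups
    for _ in iter                               # compute! updates one edge group per region
        prob.diff <= tol && break               # stop once the accumulated message change is small
    end

Each region update contracts the incoming messages with the local factors via `updated_message` and writes the result back into the cache with `setmessage!`.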
--- Project.toml | 2 + src/ITensorNetworksNext.jl | 1 + .../beliefpropagationcache.jl | 126 ++---------------- .../beliefpropagationproblem.jl | 85 ++++++++++++ 4 files changed, 101 insertions(+), 113 deletions(-) create mode 100644 src/beliefpropagation/beliefpropagationproblem.jl diff --git a/Project.toml b/Project.toml index 85efef2..f892a74 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" @@ -39,6 +40,7 @@ DerivableInterfaces = "0.5.5" DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" +ITensorBase = "0.2.14" LinearAlgebra = "1.10" MacroTools = "0.5.16" NamedDimsArrays = "0.8" diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 905d783..cca4b6d 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -11,5 +11,6 @@ include("adapters.jl") include("beliefpropagation/abstractbeliefpropagationcache.jl") include("beliefpropagation/beliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationproblem.jl") end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 295502a..cdae651 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,9 +1,7 @@ -using DiagonalArrays: delta using Dictionaries: Dictionary, set!, delete! using Graphs: AbstractGraph, is_tree, connected_components using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges using ITensorBase: ITensor, dim -using TypeParameterAccessors: unwrap_array_type, unwrap_array, parenttype struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: AbstractBeliefPropagationCache{V} @@ -13,9 +11,8 @@ end messages(bp_cache::BeliefPropagationCache) = bp_cache.messages network(bp_cache::BeliefPropagationCache) = bp_cache.network -default_messages() = Dictionary() -BeliefPropagationCache(network) = BeliefPropagationCache(network, default_messages()) +BeliefPropagationCache(network) = BeliefPropagationCache(network, Dictionary()) function Base.copy(bp_cache::BeliefPropagationCache) return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) @@ -33,16 +30,15 @@ function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) return bp_cache end -function message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...) +function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...) ms = messages(bp_cache) return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) end -function messages(bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}) +function messages(bp_cache::BeliefPropagationCache, edges::Vector{<:AbstractEdge}) return [message(bp_cache, e) for e in edges] end -default_bp_maxiter(g::AbstractGraph) = is_tree(g) ? 
1 : nothing #Forward onto the network for f in [ :(Graphs.vertices), @@ -62,11 +58,6 @@ for f in [ end end -#TODO: Get subgraph working on an ITensorNetwork to overload this directly -function default_bp_edge_sequence(bp_cache::BeliefPropagationCache) - return forest_cover_edge_sequence(underlying_graph(bp_cache)) -end - function factors(tn::AbstractTensorNetwork, vertex) return [tn[vertex]] end @@ -91,33 +82,6 @@ function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) return t end -#Algorithmic defaults -default_update_alg(bp_cache::BeliefPropagationCache) = "bp" -default_message_update_alg(bp_cache::BeliefPropagationCache) = "contract" -default_normalize(::Algorithm"contract") = true -default_sequence_alg(::Algorithm"contract") = "optimal" -function set_default_kwargs(alg::Algorithm"contract") - normalize = get(alg, :normalize, default_normalize(alg)) - sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg)) - return Algorithm("contract"; normalize, sequence_alg) -end -function set_default_kwargs(alg::Algorithm"adapt_update") - _alg = set_default_kwargs(get(alg, :alg, Algorithm("contract"))) - return Algorithm("adapt_update"; adapt = alg.adapt, alg = _alg) -end -default_verbose(::Algorithm"bp") = false -default_tol(::Algorithm"bp") = nothing -function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache) - verbose = get(alg, :verbose, default_verbose(alg)) - maxiter = get(alg, :maxiter, default_bp_maxiter(bp_cache)) - edge_sequence = get(alg, :edge_sequence, default_bp_edge_sequence(bp_cache)) - tol = get(alg, :tol, default_tol(alg)) - message_update_alg = set_default_kwargs( - get(alg, :message_update_alg, Algorithm(default_message_update_alg(bp_cache))) - ) - return Algorithm("bp"; verbose, maxiter, edge_sequence, tol, message_update_alg) -end - #TODO: Update message etc should go here... function updated_message( alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge @@ -141,85 +105,21 @@ function updated_message( return updated_message end -function updated_message( - bp_cache::BeliefPropagationCache, - edge::AbstractEdge; - alg = default_message_update_alg(bpc), - kwargs..., +function default_algorithm( + ::Type{<:Algorithm"contract"}; normalize = true, sequence_alg = "optimal" ) - return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bp_cache, edge) + return Algorithm("contract"; normalize, sequence_alg) end - -function update_message!( - message_update_alg::Algorithm, bp_cache::BeliefPropagationCache, edge::AbstractEdge +function default_algorithm( + ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") ) - return setmessage!(bp_cache, edge, updated_message(message_update_alg, bp_cache, edge)) + return Algorithm("adapt_update"; adapt, alg) end -""" -Do a sequential update of the message tensors on `edges` -""" -function update_iteration( - alg::Algorithm"bp", - bpc::AbstractBeliefPropagationCache, - edges::Vector; - (update_diff!) = nothing, - ) - bpc = copy(bpc) - for e in edges - prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing - update_message!(alg.message_update_alg, bpc, e) - if !isnothing(update_diff!) - update_diff![] += message_diff(message(bpc, e), prev_message) - end - end - return bpc -end - -""" -Do parallel updates between groups of edges of all message tensors -Currently we send the full message tensor data struct to update for each edge_group. But really we only need the -mts relevant to that group. 
-""" -function update_iteration( - alg::Algorithm"bp", - bpc::AbstractBeliefPropagationCache, - edge_groups::Vector{<:Vector{<:AbstractEdge}}; - (update_diff!) = nothing, +function update_message!( + message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge ) - new_mts = empty(messages(bpc)) - for edges in edge_groups - bpc_t = update_iteration(alg.kwargs.message_update_alg, bpc, edges; (update_diff!)) - for e in edges - set!(new_mts, e, message(bpc_t, e)) - end - end - return set_messages(bpc, new_mts) -end - -""" -More generic interface for update, with default params -""" -function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache) - compute_error = !isnothing(alg.tol) - if isnothing(alg.maxiter) - error("You need to specify a number of iterations for BP!") - end - for i in 1:alg.maxiter - diff = compute_error ? Ref(0.0) : nothing - bpc = update_iteration(alg, bpc, alg.edge_sequence; (update_diff!) = diff) - if compute_error && (diff.x / length(alg.edge_sequence)) <= alg.tol - if alg.verbose - println("BP converged to desired precision after $i iterations.") - end - break - end - end - return bpc -end - -function update(bpc::AbstractBeliefPropagationCache; alg = default_update_alg(bpc), kwargs...) - return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc) + return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) end #Edge sequence stuff @@ -234,4 +134,4 @@ function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root end end return edges -end \ No newline at end of file +end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl new file mode 100644 index 0000000..a497363 --- /dev/null +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -0,0 +1,85 @@ +mutable struct BeliefPropagationProblem{V, Cache <: AbstractBeliefPropagationCache{V}} <: + AbstractProblem + const cache::Cache + diff::Union{Nothing, Float64} +end + +function default_algorithm( + ::Type{<:Algorithm"bp"}, + bpc::BeliefPropagationCache; + verbose = false, + tol = nothing, + edge_sequence = forest_cover_edge_sequence(underlying_graph(bpc)), + message_update_alg = default_algorithm(Algorithm"contract"), + maxiter = is_tree(bpc) ? 1 : nothing, + ) + return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter) +end + +function compute!(iter::RegionIterator{<:BeliefPropagationProblem}) + prob = iter.problem + + edge_group, kwargs = current_region_plan(iter) + + new_message_tensors = map(edge_group) do edge + old_message = message(prob.cache, edge) + + new_message = updated_message(kwargs.message_update_alg, prob.cache, edge) + + if !isnothing(prob.diff) + # TODO: Define `message_diff` + prob.diff += message_diff(new_message, old_message) + end + + return new_message + end + + foreach(edge_group, new_message_tensors) do edge, new_message + setmessage!(prob.cache, edge, new_message) + end + + return iter +end + +function region_plan( + prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs... + ) + edges = forest_cover_edge_sequence(underlying_graph(prob.cache); root_vertex) + + plan = map(edges) do e + return [e] => (; sweep_kwargs...) + end + + return plan +end + +function update(bpc::AbstractBeliefPropagationCache; kwargs...) + return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc) +end +function update(alg::Algorithm"bp", bpc) + compute_error = !isnothing(alg.tol) + + diff = compute_error ? 
0.0 : nothing + + prob = BeliefPropagationProblem(bpc, diff) + + iter = SweepIterator(prob, alg.maxiter; compute_error, getfield(alg, :kwargs)...) + + for _ in iter + if compute_error && prob.diff <= alg.tol + break + end + end + + if alg.verbose && compute_error + if prob.diff <= alg.tol + println("BP converged to desired precision after $(iter.which_sweep) iterations.") + else + println( + "BP failed to converge to precision $(alg.tol), got $(prob.diff) after $(iter.which_sweep) iterations", + ) + end + end + + return bpc +end From b83ee3684f6c983c9fb9beed19effabc9ee01fd3 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 31 Oct 2025 12:46:03 -0400 Subject: [PATCH 04/15] Add method for `setmessages!` that allows messages from one cache to be set from another cache --- src/beliefpropagation/beliefpropagationcache.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index cdae651..b3a32b1 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -30,6 +30,14 @@ function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) return bp_cache end +function setmessages!(bpc_dst::BeliefPropagationCache, bpc_src::BeliefPropagationCache, edges) + ms_dst = messages(bpc_dst) + for e in edges + set!(ms_dst, e, message(bpc_src, e)) + end + return bpc_dst +end + function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...) ms = messages(bp_cache) return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) From 6161cb32cb31f5bc951eb7d1e51d5de4aec38c8a Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 10 Nov 2025 14:03:59 -0500 Subject: [PATCH 05/15] Network is now passed to `forest_cover_edge_sequence` directly. --- src/beliefpropagation/beliefpropagationproblem.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index a497363..967b454 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -9,7 +9,7 @@ function default_algorithm( bpc::BeliefPropagationCache; verbose = false, tol = nothing, - edge_sequence = forest_cover_edge_sequence(underlying_graph(bpc)), + edge_sequence = forest_cover_edge_sequence(network(bpc)), message_update_alg = default_algorithm(Algorithm"contract"), maxiter = is_tree(bpc) ? 1 : nothing, ) @@ -44,7 +44,8 @@ end function region_plan( prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs... ) - edges = forest_cover_edge_sequence(underlying_graph(prob.cache); root_vertex) + + edges = forest_cover_edge_sequence(network(prob.cache); root_vertex) plan = map(edges) do e return [e] => (; sweep_kwargs...) 
From 7440149b654b45f963b8f7837606754b426c28af Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 11:19:31 -0500 Subject: [PATCH 06/15] test file formatting --- test/test_beliefpropagation.jl | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 81ee722..fc657e7 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -1,7 +1,17 @@ using Dictionaries: Dictionary using ITensorBase: Index -using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy, - partitionfunction +using ITensorNetworksNext: + BeliefPropagationCache, + ITensorNetworksNext, + TensorNetwork, + adapt_messages, + default_message, + default_messages, + edge_scalars, + factors, + messages, + partitionfunction, + setmessages! using Graphs: edges, vertices using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges @@ -15,15 +25,15 @@ using Test: @test, @testset l = Dict(e => Index(2) for e in edges(g)) l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) end bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.update(bpc; maxiter = 1) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test abs(z_bp - z_exact) <= 1e-14 + @test abs(z_bp - z_exact) <= 1.0e-14 #Tree of tensors dims = (4, 3) @@ -31,13 +41,14 @@ using Test: @test, @testset l = Dict(e => Index(3) for e in edges(g)) l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) end bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.update(bpc; maxiter = 10) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test abs(z_bp - z_exact) <= 1e-14 -end \ No newline at end of file + @test abs(z_bp - z_exact) <= 1.0e-14 +end + From 43500089c39ba0858968547f68a97530f3c566c5 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 11:23:14 -0500 Subject: [PATCH 07/15] `region_plan` now takes the previous region plan as a second argument Upon construction, this defaults to `nothing`. --- src/iterators.jl | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/iterators.jl b/src/iterators.jl index 62d5b21..568cb5d 100644 --- a/src/iterators.jl +++ b/src/iterators.jl @@ -54,8 +54,8 @@ mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator end end -function RegionIterator(problem; sweep, sweep_kwargs...) - plan = region_plan(problem; sweep_kwargs...) +function RegionIterator(problem, prevplan; sweep, sweep_kwargs...) + plan = region_plan(problem, prevplan; sweep_kwargs...) return RegionIterator(problem, plan, sweep) end @@ -109,6 +109,8 @@ function compute!(iter::RegionIterator) return iter end +# Default behaviour: +region_plan(problem, ::Any; sweep_kwargs...) = region_plan(problem; sweep_kwargs...) region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...) 
# @@ -129,7 +131,7 @@ mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator end first_kwargs, _ = first_state - region_iter = RegionIterator(problem; sweep = 1, first_kwargs...) + region_iter = RegionIterator(problem, nothing; sweep = 1, first_kwargs...) return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1) end @@ -151,7 +153,10 @@ end function update_region_iterator!(iterator::SweepIterator; kwargs...) sweep = state(iterator) - iterator.region_iter = RegionIterator(problem(iterator); sweep, kwargs...) + + previous_plan = iterator.region_iter.region_plan + + iterator.region_iter = RegionIterator(problem(iterator), previous_plan; sweep, kwargs...) return iterator end From 0802355c08244360b99d55e5d1da5baa36120fc6 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 11:25:31 -0500 Subject: [PATCH 08/15] Add `DataGraphsPartitionedGraphsExt` glue for `TensorNetwork` type Also includes some fixes to the way `TensorNetwork` types are constructed based on index structure. --- src/tensornetwork.jl | 79 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 76 insertions(+), 3 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 582eec6..11c2e88 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,10 +1,21 @@ using Combinatorics: combinations using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph using Dictionaries: AbstractDictionary, Indices, dictionary -using Graphs: AbstractSimpleGraph +using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! using NamedDimsArrays: AbstractNamedDimsArray, dimnames using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype -using NamedGraphs.GraphsExtensions: add_edges!, arrange_edge, arranged_edges, vertextype +using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, vertextype +using NamedGraphs.PartitionedGraphs: + AbstractPartitionedGraph, + PartitionedGraphs, + departition, + partitioned_vertices, + partitionedgraph, + quotient_graph, + quotient_graph_type +using .LazyNamedDimsArrays: lazy, Mul +using DataGraphs: vertex_data_eltype, vertex_data, edge_data +using DataGraphs.DataGraphsPartitionedGraphsExt function _TensorNetwork end @@ -24,8 +35,14 @@ function _TensorNetwork(graph::AbstractGraph, tensors) return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) end +function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: AbstractGraph{V}, Tensors} + return _TensorNetwork(graph, Tensors()) +end + DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph) DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors) +DataGraphs.edge_data(tn::TensorNetwork) = Dictionary{edgetype(tn), Nothing}() +DataGraphs.vertex_data_eltype(T::Type{<:TensorNetwork}) = eltype(fieldtype(T, :tensors)) function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) return fieldtype(type, :underlying_graph) end @@ -70,7 +87,10 @@ function fix_links!(tn::AbstractTensorNetwork) for e in setdiff(arranged_edges(graph), tn_edges) insert_trivial_link!(tn, e) end - return tn + for edge in setdiff(arranged_edges(graph), arranged_edges(graph_structure)) + insert_trivial_link!(network, edge) + end + return network end # Determine the graph structure from the tensors. 
@@ -93,3 +113,56 @@ end NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn) + +Graphs.connected_components(tn::TensorNetwork) = Graphs.connected_components(underlying_graph(tn)) + +function Graphs.rem_edge!(tn::TensorNetwork, e) + if !has_edge(underlying_graph(tn), e) + return false + end + if !isempty(linkinds(tn, e)) + throw(ArgumentError("cannot remove edge $e due to tensor indices existing on this edge.")) + end + rem_edge!(underlying_graph(tn), e) + return true +end + +function GraphsExtensions.graph_from_vertices(type::Type{<:TensorNetwork}, vertices) + DT = fieldtype(type, :tensors) + empty_dict = DT() + return TensorNetwork(similar_graph(underlying_graph_type(type), vertices), empty_dict) +end + +## PartitionedGraphs +function PartitionedGraphs.quotient_graph(tn::TensorNetwork) + ug = quotient_graph(underlying_graph(tn)) + return TensorNetwork(ug, vertex_data(QuotientView(tn))) +end +function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork}) + UG = quotient_graph_type(underlying_graph_type(type)) + VD = Vector{vertex_data_eltype(type)} + V = vertextype(UG) + return TensorNetwork{V, VD, UG, Dictionary{V, VD}} +end + +function PartitionedGraphs.partitionedgraph(tn::TensorNetwork, parts) + pg = partitionedgraph(underlying_graph(tn), parts) + return TensorNetwork(pg, vertex_data(tn)) +end + +PartitionedGraphs.departition(tn::TensorNetwork) = tn +function PartitionedGraphs.departition( + tn::TensorNetwork{<:Any, <:Any, <:AbstractPartitionedGraph} + ) + return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) +end + +function DataGraphsPartitionedGraphsExt.to_quotient_vertex_data(::TensorNetwork, data) + return mapreduce(lazy, *, collect(last(data))) +end + +function PartitionedGraphs.quotientview(tn::TensorNetwork) + qview = QuotientView(underlying_graph(tn)) + tensors = vertex_data(QuotientView(tn)) + return TensorNetwork(qview, tensors) +end From 8f93ec816fa0801027211a145b42742745006650 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 12:27:20 -0500 Subject: [PATCH 09/15] Make abstract tensor network interface more generic. --- src/abstracttensornetwork.jl | 106 ++++++++++++++++++----------------- 1 file changed, 54 insertions(+), 52 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 1ecbffa..b02c789 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -9,19 +9,23 @@ using LinearAlgebra: LinearAlgebra, factorize using MacroTools: @capture using NamedDimsArrays: dimnames, inds using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree -using NamedGraphs.GraphsExtensions: ⊔, directed_graph, incident_edges, rem_edges!, - rename_vertices, vertextype +using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +using NamedGraphs.GraphsExtensions: + ⊔, + directed_graph, + incident_edges, + rem_edges!, + rename_vertices, + vertextype using SplitApplyCombine: flatten +using NamedGraphs.SimilarType: similar_type abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end -function Graphs.rem_edge!(tn::AbstractTensorNetwork, e) - rem_edge!(underlying_graph(tn), e) - return tn -end +# Need to be careful about removing edges from tensor networks in case there is a bond +Graphs.rem_edge!(::AbstractTensorNetwork, edge) = not_implemented() -# TODO: Define a generic fallback for `AbstractDataGraph`? 
-DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge data") +DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = not_implemented() # Graphs.jl overloads function Graphs.weights(graph::AbstractTensorNetwork) @@ -36,7 +40,7 @@ function Graphs.weights(graph::AbstractTensorNetwork) end # Copy -Base.copy(tn::AbstractTensorNetwork) = error("Not implemented") +Base.copy(::AbstractTensorNetwork) = not_implemented() # Iteration Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...) @@ -49,20 +53,11 @@ Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn)) # Overload if needed Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false -# Derived interface, may need to be overloaded -function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork}) - return underlying_graph_type(data_graph_type(G)) -end - # AbstractDataGraphs overloads -function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") -end -function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") -end +DataGraphs.vertex_data(::AbstractTensorNetwork) = not_implemented() +DataGraphs.edge_data(::AbstractTensorNetwork) = not_implemented() -DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented") +DataGraphs.underlying_graph(::AbstractTensorNetwork) = not_implemented() function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork) return NamedGraphs.vertex_positions(underlying_graph(tn)) end @@ -81,40 +76,37 @@ function Adapt.adapt_structure(to, tn::AbstractTensorNetwork) return map_vertex_data_preserve_graph(adapt(to), tn) end -function linkinds(tn::AbstractTensorNetwork, edge::Pair) - return linkinds(tn, edgetype(tn)(edge)) -end -function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge) - return inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) -end -function linkaxes(tn::AbstractTensorNetwork, edge::Pair) +linkinds(tn::AbstractGraph, edge::Pair) = linkinds(tn, edgetype(tn)(edge)) +linkinds(tn::AbstractGraph, edge::AbstractEdge) = inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) + +function linkaxes(tn::AbstractGraph, edge::Pair) return linkaxes(tn, edgetype(tn)(edge)) end -function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) +function linkaxes(tn::AbstractGraph, edge::AbstractEdge) return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) end -function linknames(tn::AbstractTensorNetwork, edge::Pair) +function linknames(tn::AbstractGraph, edge::Pair) return linknames(tn, edgetype(tn)(edge)) end -function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge) +function linknames(tn::AbstractGraph, edge::AbstractEdge) return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) end -function siteinds(tn::AbstractTensorNetwork, v) +function siteinds(tn::AbstractGraph, v) s = inds(tn[v]) for v′ in neighbors(tn, v) s = setdiff(s, inds(tn[v′])) end return s end -function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) +function siteaxes(tn::AbstractGraph, edge::AbstractEdge) s = axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) for v′ in neighbors(tn, v) s = setdiff(s, axes(tn[v′])) end return s end -function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) +function sitenames(tn::AbstractGraph, edge::AbstractEdge) s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) for v′ in neighbors(tn, v) s = setdiff(s, dimnames(tn[v′])) @@ -122,8 +114,8 @@ function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) return s end -function 
setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex) - vertex_data(tn)[vertex] = value +function setindex_preserve_graph!(tn::AbstractGraph, value, vertex) + set!(vertex_data(tn), vertex, value) return tn end @@ -153,7 +145,7 @@ end # Update the graph of the TensorNetwork `tn` to include # edges that should exist based on the tensor connectivity. -function add_missing_edges!(tn::AbstractTensorNetwork) +function add_missing_edges!(tn::AbstractGraph) foreach(v -> add_missing_edges!(tn, v), vertices(tn)) return tn end @@ -161,7 +153,7 @@ end # Update the graph of the TensorNetwork `tn` to include # edges that should be incident to the vertex `v` # based on the tensor connectivity. -function add_missing_edges!(tn::AbstractTensorNetwork, v) +function add_missing_edges!(tn::AbstractGraph, v) for v′ in vertices(tn) if v ≠ v′ e = v => v′ @@ -175,13 +167,13 @@ end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity. -function fix_edges!(tn::AbstractTensorNetwork) +function fix_edges!(tn::AbstractGraph) foreach(v -> fix_edges!(tn, v), vertices(tn)) return tn end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity at vertex `v`. -function fix_edges!(tn::AbstractTensorNetwork, v) +function fix_edges!(tn::AbstractGraph, v) rem_edges!(tn, incident_edges(tn, v)) add_missing_edges!(tn, v) return tn @@ -215,28 +207,20 @@ function Base.setindex!(tn::AbstractTensorNetwork, value, v) fix_edges!(tn, v) return tn end -using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger # Fix ambiguity error. function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger) graph[vertices(graph)[vertex]] = value return graph end -# Fix ambiguity error. -function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) - return error("No edge data.") -end -# Fix ambiguity error. -function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) - return error("No edge data.") -end -using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) = not_implemented() +Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) = not_implemented() # Fix ambiguity error. 
function Base.setindex!( tn::AbstractTensorNetwork, value, edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger}, ) - return error("No edge data.") + return not_implemented() end function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) @@ -254,4 +238,22 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) return nothing end -Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) \ No newline at end of file +Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) + +function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices::AbstractVector{V}) where {V <: Int} + return tensornetwork_induced_subgraph(graph, subvertices) +end +function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices) + return tensornetwork_induced_subgraph(graph, subvertices) +end + +function tensornetwork_induced_subgraph(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) + subgraph = similar_type(graph)(underlying_subgraph) + for v in vertices(subgraph) + if isassigned(graph, v) + set!(vertex_data(subgraph), v, graph[v]) + end + end + return subgraph, vlist +end From 8bd8c3507832b236a42459c1b730ecee6c15ad25 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 12:27:50 -0500 Subject: [PATCH 10/15] BP Caching overhauls --- .../abstractbeliefpropagationcache.jl | 184 ++++++++---------- .../beliefpropagationcache.jl | 178 ++++++----------- .../beliefpropagationproblem.jl | 109 ++++++++--- 3 files changed, 226 insertions(+), 245 deletions(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 5eae283..8c6b3dd 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -1,117 +1,124 @@ -abstract type AbstractBeliefPropagationCache{V} <: AbstractGraph{V} end +using Graphs: AbstractGraph, AbstractEdge +using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype +using NamedGraphs.GraphsExtensions: boundary_edges +using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent -#Interface -factor(bp_cache::AbstractBeliefPropagationCache, vertex) = not_implemented() -setfactor!(bp_cache::AbstractBeliefPropagationCache, vertex, factor) = not_implemented() -messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) = not_implemented() -function default_message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) - return not_implemented() -end -default_messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -function setmessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge, message) - return not_implemented() +messages(::AbstractGraph) = not_implemented() +messages(bp_cache::AbstractDataGraph) = edge_data(bp_cache) +messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] + +message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] + +deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented() +function deletemessage!(bp_cache::AbstractDataGraph, edge) + ms = messages(bp_cache) + delete!(ms, edge) + return bp_cache end -function deletemessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) - return not_implemented() + +function 
deletemessages!(bp_cache::AbstractGraph, edges = edges(bp_cache)) + for e in edges + deletemessage!(bp_cache, e) + end + return bp_cache end -function rescale_messages( - bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}; kwargs... - ) - return not_implemented() + +setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented() +function setmessage!(bp_cache::AbstractDataGraph, edge, message) + ms = messages(bp_cache) + set!(ms, edge, message) + return bp_cache end -function rescale_vertices( - bp_cache::AbstractBeliefPropagationCache, vertices::Vector; kwargs... - ) - return not_implemented() +function setmessage!(bp_cache::QuotientView, edge, message) + setmessages!(parent(bp_cache), QuotientEdge(edge), message) + return bp_cache end -function vertex_scalar(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) - return not_implemented() +function setmessages!(bp_cache::AbstractGraph, edge::QuotientEdge, message) + for e in edges(bp_cache, edge) + setmessage!(parent(bp_cache), e, message[e]) + end + return bp_cache end -function edge_scalar( - bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs... - ) - return not_implemented() +function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges) + for e in edges + setmessage!(bpc_dst, e, message(bpc_src, e)) + end + return bpc_dst end -#Graph functionality needed -Graphs.vertices(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -Graphs.edges(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -function NamedGraphs.GraphsExtensions.boundary_edges( - bp_cache::AbstractBeliefPropagationCache, vertices; kwargs... - ) - return not_implemented() +factors(bpc::AbstractGraph) = vertex_data(bpc) +factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices] +factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex]) + +factor(bpc::AbstractGraph, vertex) = factors(bpc)[vertex] + +setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented() +function setfactor!(bpc::AbstractDataGraph, vertex, factor) + fs = factors(bpc) + set!(fs, vertex, factor) + return bpc end -#Functions derived from the interface -function setmessages!(bp_cache::AbstractBeliefPropagationCache, edges, messages) - for (e, m) in zip(edges) - setmessage!(bp_cache, e, m) - end - return +function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge) + return message(bp_cache, edge) * message(bp_cache, reverse(edge)) end -function deletemessages!( - bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge} = edges(bp_cache) - ) - for e in edges - deletemessage!(bp_cache, e) - end - return bp_cache +function region_scalar(bp_cache::AbstractGraph, vertex) + + messages = incoming_messages(bp_cache, vertex) + state = factors(bp_cache, vertex) + + return reduce(*, messages) * reduce(*, state) end -function vertex_scalars( - bp_cache::AbstractBeliefPropagationCache, vertices = Graphs.vertices(bp_cache); kwargs... - ) - return map(v -> region_scalar(bp_cache, v; kwargs...), vertices) +message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) +message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G)) +message_type(type::Type{<:AbstractDataGraph}) = edge_data_eltype(type) + +function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) + return map(v -> region_scalar(bp_cache, v), vertices) end -function edge_scalars( - bp_cache::AbstractBeliefPropagationCache, edges = Graphs.edges(bp_cache); kwargs... 
- ) - return map(e -> region_scalar(bp_cache, e; kwargs...), edges) +function edge_scalars(bp_cache::AbstractGraph, edges = edges(bp_cache)) + return map(e -> region_scalar(bp_cache, e), edges) end -function scalar_factors_quotient(bp_cache::AbstractBeliefPropagationCache) +function scalar_factors_quotient(bp_cache::AbstractGraph) return vertex_scalars(bp_cache), edge_scalars(bp_cache) end -function incoming_messages( - bp_cache::AbstractBeliefPropagationCache, vertices::Vector{<:Any}; ignore_edges = [] - ) - b_edges = NamedGraphs.GraphsExtensions.boundary_edges(bp_cache, vertices; dir = :in) +function incoming_messages(bp_cache::AbstractGraph, vertices; ignore_edges = []) + b_edges = boundary_edges(bp_cache, [vertices;]; dir = :in) b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges return messages(bp_cache, b_edges) end -function incoming_messages(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) - return incoming_messages(bp_cache, [vertex]; kwargs...) -end +default_messages(::AbstractGraph) = not_implemented() #Adapt interface for changing device -function map_messages(f, bp_cache::AbstractBeliefPropagationCache, es = edges(bp_cache)) - bp_cache = copy(bp_cache) +map_messages(f, bp_cache, es = edges(bp_cache)) = map_messages!(f, copy(bp_cache), es) +function map_messages!(f, bp_cache, es = edges(bp_cache)) for e in es setmessage!(bp_cache, e, f(message(bp_cache, e))) end return bp_cache end -function map_factors(f, bp_cache::AbstractBeliefPropagationCache, vs = vertices(bp_cache)) - bp_cache = copy(bp_cache) + +map_factors(f, bp_cache, vs = vertices(bp_cache)) = map_factors!(f, copy(bp_cache), vs) +function map_factors!(f, bp_cache, vs = vertices(bp_cache)) for v in vs setfactor!(bp_cache, v, f(factor(bp_cache, v))) end return bp_cache end -function adapt_messages(to, bp_cache::AbstractBeliefPropagationCache, args...) - return map_messages(adapt(to), bp_cache, args...) -end -function adapt_factors(to, bp_cache::AbstractBeliefPropagationCache, args...) - return map_factors(adapt(to), bp_cache, args...) -end -function freenergy(bp_cache::AbstractBeliefPropagationCache) +adapt_messages(to, bp_cache, es = edges(bp_cache)) = map_messages(adapt(to), bp_cache, es) +adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp_cache, vs) + +abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end + +function free_energy(bp_cache::AbstractBeliefPropagationCache) numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) if any(t -> real(t) < 0, numerator_terms) numerator_terms = complex.(numerator_terms) @@ -123,29 +130,4 @@ function freenergy(bp_cache::AbstractBeliefPropagationCache) any(iszero, denominator_terms) && return -Inf return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) end - -function partitionfunction(bp_cache::AbstractBeliefPropagationCache) - return exp(freenergy(bp_cache)) -end - -function rescale_messages(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) - return rescale_messages(bp_cache, [edge]) -end - -function rescale_messages(bp_cache::AbstractBeliefPropagationCache) - return rescale_messages(bp_cache, edges(bp_cache)) -end - -function rescale_vertices(bpc::AbstractBeliefPropagationCache; kwargs...) - return rescale_vertices(bpc, collect(vertices(bpc)); kwargs...) -end - -function rescale_vertex(bpc::AbstractBeliefPropagationCache, vertex; kwargs...) - return rescale_vertices(bpc, [vertex]; kwargs...) 
-end - -function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...) - bpc = rescale_messages(bpc) - bpc = rescale_partitions(bpc, args...; kwargs...) - return bpc -end +partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache)) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index b3a32b1..4e441fb 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,145 +1,93 @@ +using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph using Dictionaries: Dictionary, set!, delete! using Graphs: AbstractGraph, is_tree, connected_components -using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges +using NamedGraphs: convert_vertextype +using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, is_path_graph using ITensorBase: ITensor, dim +using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, quotient_graph -struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: - AbstractBeliefPropagationCache{V} +struct BeliefPropagationCache{V, N <: AbstractGraph{V}, ET, MT} <: + AbstractBeliefPropagationCache{V, MT} network::N - messages::Dictionary + messages::Dictionary{ET, MT} end -messages(bp_cache::BeliefPropagationCache) = bp_cache.messages -network(bp_cache::BeliefPropagationCache) = bp_cache.network +network(bp_cache) = underlying_graph(bp_cache) -BeliefPropagationCache(network) = BeliefPropagationCache(network, Dictionary()) - -function Base.copy(bp_cache::BeliefPropagationCache) - return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) +DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :network) +DataGraphs.edge_data(bpc::BeliefPropagationCache) = getfield(bpc, :messages) +DataGraphs.vertex_data(bpc::BeliefPropagationCache) = vertex_data(network(bpc)) +function DataGraphs.underlying_graph_type(type::Type{<:BeliefPropagationCache}) + return fieldtype(type, :network) end -function deletemessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge) - ms = messages(bp_cache) - delete!(ms, e) - return bp_cache -end +message_type(::Type{<:BeliefPropagationCache{V, N, ET, MT}}) where {V, N, ET, MT} = MT -function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) - ms = messages(bp_cache) - set!(ms, e, message) - return bp_cache +function BeliefPropagationCache(alg, network::AbstractGraph) + es = collect(edges(network)) + es = vcat(es, reverse.(es)) + messages = map(edge -> default_message(alg, network, edge), es) + return BeliefPropagationCache(network, Dictionary(es, messages)) end -function setmessages!(bpc_dst::BeliefPropagationCache, bpc_src::BeliefPropagationCache, edges) - ms_dst = messages(bpc_dst) - for e in edges - set!(ms_dst, e, message(bpc_src, e)) - end - return bpc_dst +function Base.copy(bp_cache::BeliefPropagationCache) + return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) end -function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...) 
- ms = messages(bp_cache) - return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) +# TODO: This needs to go in DataGraphsGraphsExtensionsExt +# +# This function is problematic when `ng isa TensorNetwork` as it relies on deleting edges +# and taking subgraphs, which is not always well-defined for the `TensorNetwork` type, +# hence we just strip off any `AbstractDataGraph` data to avoid this. +function forest_cover_edge_sequence(g::AbstractDataGraph; kwargs...) + return forest_cover_edge_sequence(underlying_graph(g); kwargs...) end - -function messages(bp_cache::BeliefPropagationCache, edges::Vector{<:AbstractEdge}) - return [message(bp_cache, e) for e in edges] +# TODO: This needs to go in PartitionedGraphsGraphsExtensionsExt +# +# While it is not at all necessary to explictly instantiate the `QuotientView`, it allows the +# data of a data graph to be removed using the above method if `parent_type(g)` is an +# `AbstractDataGraph`. +function forest_cover_edge_sequence(g::QuotientView; kwargs...) + return forest_cover_edge_sequence(quotient_graph(parent(g)); kwargs...) end - -#Forward onto the network -for f in [ - :(Graphs.vertices), - :(Graphs.edges), - :(Graphs.is_tree), - :(NamedGraphs.GraphsExtensions.boundary_edges), - :(factors), - :(default_bp_maxiter), - :(ITensorNetworksNext.setfactor!), - :(ITensorNetworksNext.linkinds), - :(ITensorNetworksNext.underlying_graph), - ] - @eval begin - function $f(bp_cache::BeliefPropagationCache, args...; kwargs...) - return $f(network(bp_cache), args...; kwargs...) +# TODO: This needs to go in GraphsExtensions +function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) + add_edges!(g, edges(g)) + forests = forest_cover(g) + rv = edgetype(g)[] + for forest in forests + trees = [forest[vs] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) end end + return rv end -function factors(tn::AbstractTensorNetwork, vertex) - return [tn[vertex]] -end - -function region_scalar(bp_cache::BeliefPropagationCache, edge::AbstractEdge) - return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] -end - -function region_scalar(bp_cache::BeliefPropagationCache, vertex) - incoming_ms = incoming_messages(bp_cache, vertex) - state = factors(bp_cache, vertex) - return (reduce(*, incoming_ms) * reduce(*, state))[] -end - -function default_message(bp_cache::BeliefPropagationCache, edge::AbstractEdge) - return default_message(network(bp_cache), edge::AbstractEdge) -end - -function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) - t = ITensor(ones(dim.(linkinds(tn, edge))...), linkinds(tn, edge)...) - #TODO: Get datatype working on tensornetworks so we can support GPU, etc... - return t -end - -#TODO: Update message etc should go here... -function updated_message( - alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge - ) - vertex = src(edge) - incoming_ms = incoming_messages( - bp_cache, vertex; ignore_edges = typeof(edge)[reverse(edge)] - ) - state = factors(bp_cache, vertex) - #contract_list = ITensor[incoming_ms; state] - #sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg) - #updated_messages = contract(contract_list; sequence) - updated_message = - !isempty(incoming_ms) ? 
reduce(*, state) * reduce(*, incoming_ms) : reduce(*, state) - if alg.normalize - message_norm = LinearAlgebra.norm(updated_message) - if !iszero(message_norm) - updated_message /= message_norm +function bpcache_induced_subgraph(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(network(graph), subvertices) + subgraph = BeliefPropagationCache(underlying_subgraph, typeof(edge_data(graph))()) + for e in edges(subgraph) + if isassigned(graph, e) + set!(edge_data(subgraph), e, graph[e]) end end - return updated_message + return subgraph, vlist end -function default_algorithm( - ::Type{<:Algorithm"contract"}; normalize = true, sequence_alg = "optimal" - ) - return Algorithm("contract"; normalize, sequence_alg) +function Graphs.induced_subgraph(graph::BeliefPropagationCache, subvertices) + return bpcache_induced_subgraph(graph, subvertices) end -function default_algorithm( - ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") - ) - return Algorithm("adapt_update"; adapt, alg) +# For method ambiguity +function Graphs.induced_subgraph(graph::BeliefPropagationCache{V}, subvertices::AbstractVector{V}) where {V <: Int} + return bpcache_induced_subgraph(graph, subvertices) end -function update_message!( - message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge - ) - return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) -end +## PartitionedGraphs -#Edge sequence stuff -function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) - forests = forest_cover(g) - edges = edgetype(g)[] - for forest in forests - trees = [forest[vs] for vs in connected_components(forest)] - for tree in trees - tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) - push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...) - end - end - return edges +function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) + qview = QuotientView(network(bpc)) + messages = edge_data(QuotientView(bpc)) + return BeliefPropagationCache(qview, messages) end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 967b454..a05c97a 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,70 +1,121 @@ -mutable struct BeliefPropagationProblem{V, Cache <: AbstractBeliefPropagationCache{V}} <: - AbstractProblem +using Distributed: WorkerPool, @everywhere, remotecall, myid, waitall, workers +using Graphs: SimpleGraph, vertices, edges, has_edge +using NamedGraphs: AbstractNamedGraph, position_graph +using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices +using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices + +abstract type AbstractBeliefPropagationProblem{Alg} <: AbstractProblem end + +mutable struct BeliefPropagationProblem{Alg, Cache} <: AbstractBeliefPropagationProblem{Alg} + const alg::Alg const cache::Cache diff::Union{Nothing, Float64} end +BeliefPropagationProblem(alg, cache) = BeliefPropagationProblem(alg, cache, nothing) + function default_algorithm( ::Type{<:Algorithm"bp"}, - bpc::BeliefPropagationCache; + bpc; verbose = false, tol = nothing, - edge_sequence = forest_cover_edge_sequence(network(bpc)), + edge_sequence = forest_cover_edge_sequence(bpc), message_update_alg = default_algorithm(Algorithm"contract"), maxiter = is_tree(bpc) ? 
1 : nothing, ) return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter) end -function compute!(iter::RegionIterator{<:BeliefPropagationProblem}) - prob = iter.problem +function region_plan(prob::BeliefPropagationProblem{<:Algorithm"bp"}; sweep_kwargs...) + edges = prob.alg.edge_sequence - edge_group, kwargs = current_region_plan(iter) + plan = map(edges) do e + return e => (; sweep_kwargs...) + end - new_message_tensors = map(edge_group) do edge - old_message = message(prob.cache, edge) + return plan +end - new_message = updated_message(kwargs.message_update_alg, prob.cache, edge) +function compute!(iter::RegionIterator{<:BeliefPropagationProblem{<:Algorithm"bp"}}) + prob = iter.problem - if !isnothing(prob.diff) - # TODO: Define `message_diff` - prob.diff += message_diff(new_message, old_message) - end + edge, _ = current_region_plan(iter) + new_message = updated_message(prob.alg.message_update_alg, prob.cache, edge) + setmessage!(prob.cache, edge, new_message) - return new_message - end + return iter +end - foreach(edge_group, new_message_tensors) do edge, new_message - setmessage!(prob.cache, edge, new_message) - end +default_message(alg, network, edge) = default_message(typeof(alg), network, edge) - return iter +default_message(::Type{<:Algorithm}, network, edge) = not_implemented() +function default_message(::Type{<:Algorithm"bp"}, network, edge) + + #TODO: Get datatype working on tensornetworks so we can support GPU, etc... + links = linkinds(network, edge) + data = ones(dim.(links)...) + + t = ITensor(data, links) + return t end -function region_plan( - prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs... +updated_message(alg, bpc, edge) = not_implemented() +function updated_message(alg::Algorithm"contract", bpc, edge) + vertex = src(edge) + + incoming_ms = incoming_messages( + bpc, vertex; ignore_edges = typeof(edge)[reverse(edge)] ) - edges = forest_cover_edge_sequence(network(prob.cache); root_vertex) + updated_message = contract_messages(alg.contraction_alg, factors(bpc, vertex), incoming_ms) - plan = map(edges) do e - return [e] => (; sweep_kwargs...) + if alg.normalize + message_norm = LinearAlgebra.norm(updated_message) + if !iszero(message_norm) + updated_message /= message_norm + end end + return updated_message +end - return plan +contract_messages(alg, factors, messages) = not_implemented() +function contract_messages( + alg, + factors::Vector{<:AbstractArray}, + messages::Vector{<:AbstractArray}, + ) + return contract_network(alg, vcat(factors, messages)) +end + +function default_algorithm( + ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = default_algorithm(Algorithm"exact") + ) + return Algorithm("contract"; normalize, contraction_alg) +end +function default_algorithm( + ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") + ) + return Algorithm("adapt_update"; adapt, alg) +end + +function update_message!( + message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge + ) + return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) end function update(bpc::AbstractBeliefPropagationCache; kwargs...) return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc) end -function update(alg::Algorithm"bp", bpc) + +function update(alg, bpc) compute_error = !isnothing(alg.tol) diff = compute_error ? 
0.0 : nothing - prob = BeliefPropagationProblem(bpc, diff) + prob = BeliefPropagationProblem(alg, bpc, diff) - iter = SweepIterator(prob, alg.maxiter; compute_error, getfield(alg, :kwargs)...) + iter = SweepIterator(prob, alg.maxiter; compute_error) for _ in iter if compute_error && prob.diff <= alg.tol From a188ad9061032399b51cadc3d83542a5a8fc9296 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 12:47:19 -0500 Subject: [PATCH 11/15] Remove dead deps --- src/beliefpropagation/beliefpropagationproblem.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index a05c97a..f487ccc 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,4 +1,3 @@ -using Distributed: WorkerPool, @everywhere, remotecall, myid, waitall, workers using Graphs: SimpleGraph, vertices, edges, has_edge using NamedGraphs: AbstractNamedGraph, position_graph using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices From 5bc18db505950327a3b2e00d69dcd0bd17050c77 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 13:05:45 -0500 Subject: [PATCH 12/15] Fix merge --- src/beliefpropagation/beliefpropagationproblem.jl | 2 +- src/tensornetwork.jl | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index f487ccc..61c97df 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -87,7 +87,7 @@ function contract_messages( end function default_algorithm( - ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = default_algorithm(Algorithm"exact") + ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = Algorithm("exact") ) return Algorithm("contract"; normalize, contraction_alg) end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 11c2e88..44b883a 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -4,7 +4,7 @@ using Dictionaries: AbstractDictionary, Indices, dictionary using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! using NamedDimsArrays: AbstractNamedDimsArray, dimnames using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype -using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, vertextype +using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, arrange_edge, vertextype using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, PartitionedGraphs, @@ -87,10 +87,7 @@ function fix_links!(tn::AbstractTensorNetwork) for e in setdiff(arranged_edges(graph), tn_edges) insert_trivial_link!(tn, e) end - for edge in setdiff(arranged_edges(graph), arranged_edges(graph_structure)) - insert_trivial_link!(network, edge) - end - return network + return tn end # Determine the graph structure from the tensors. 
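Note on the free-energy code in this series (an annotation, not part of any patch): `free_energy` implements the Bethe estimate log Z ≈ sum_v log(q_v) - sum_e log(q_e), where q_v contracts the factor at vertex v with all of its incoming messages and q_e contracts the pair of opposing messages on edge e (this is what `scalar_factors_quotient` returns). On a tree this estimate is exact once the messages have converged. A minimal worked example on a two-tensor network A -- B with bond dimension 2, using plain vectors rather than the package's types (all names here are illustrative):

using LinearAlgebra: dot

A, B = randn(2), randn(2)

# On a tree the converged (unnormalized) messages are just the environment of
# each vertex: here the environment of B is A and vice versa.
mAB, mBA = A, B

qA = dot(A, mBA)      # vertex term at A: factor contracted with its incoming message
qB = dot(B, mAB)      # vertex term at B
qE = dot(mAB, mBA)    # edge term on the shared bond

z_bp = qA * qB / qE   # Bethe quotient: (A·B)^2 / (A·B)
z_exact = dot(A, B)
z_bp ≈ z_exact        # holds whenever A·B is nonzero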
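On the message update itself: the "contract" algorithm computes the message along an edge by contracting the factor at the source vertex with every incoming message except the one arriving along reverse(edge), then optionally normalizing, and `forest_cover_edge_sequence` schedules those updates as a sweep along each tree of the forest cover followed by the same edges reversed, so on a tree one sweep converges. A self-contained sketch of one such sweep on a three-site chain, again with plain arrays and made-up variable names rather than the package's API:

using LinearAlgebra: dot, normalize

T1, T3 = randn(2), randn(2)   # boundary factors: one bond index each
T2 = randn(2, 2)              # bulk factor: bonds (1,2) and (2,3)

# Forward sweep along the chain: 1 -> 2, then 2 -> 3.
m12 = normalize(T1)                   # a leaf has no other incoming messages
m23 = normalize(transpose(T2) * m12)  # contract T2 with the message on bond (1,2)

# Backward sweep: 3 -> 2, then 2 -> 1.
m32 = normalize(T3)
m21 = normalize(T2 * m32)

# Bethe quotient (as in the previous note) recovers the exact contraction.
q1, q2, q3 = dot(T1, m21), dot(m12, T2 * m32), dot(m23, T3)
q12, q23 = dot(m12, m21), dot(m23, m32)
z_bp = q1 * q2 * q3 / (q12 * q23)
z_exact = dot(T1, T2 * T3)
z_bp ≈ z_exact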
From 8aed2d9c8921dcb7881e8d4283fe6e879af23164 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 13:12:01 -0500 Subject: [PATCH 13/15] Fix type inference in TensorNetwork construction --- src/tensornetwork.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 44b883a..0681da5 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -66,8 +66,7 @@ end tensornetwork_edges(tensors) = tensornetwork_edges(NamedEdge, tensors) function TensorNetwork(f::Base.Callable, graph::AbstractGraph) - tensors = Dictionary(vertices(graph), f.(vertices(graph))) - return TensorNetwork(graph, tensors) + return TensorNetwork(graph, Dictionary(map(f, vertices(graph)))) end function TensorNetwork(graph::AbstractGraph, tensors) tn = _TensorNetwork(graph, tensors) From 50819fb8b7eb4e6497d213cbb6156b7808fd0252 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Nov 2025 18:16:04 +0000 Subject: [PATCH 14/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/test_beliefpropagation.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index fc657e7..a39e1a6 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -51,4 +51,3 @@ using Test: @test, @testset z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test abs(z_bp - z_exact) <= 1.0e-14 end - From 3573d9576138f34bb6e2683e615cb51aa14e8cdb Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 16:45:45 -0500 Subject: [PATCH 15/15] Remove `ITensorBase` dep --- Project.toml | 2 -- src/beliefpropagation/beliefpropagationcache.jl | 1 - src/beliefpropagation/beliefpropagationproblem.jl | 6 ++---- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index f892a74..85efef2 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,6 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" -ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" @@ -40,7 +39,6 @@ DerivableInterfaces = "0.5.5" DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" -ITensorBase = "0.2.14" LinearAlgebra = "1.10" MacroTools = "0.5.16" NamedDimsArrays = "0.8" diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 4e441fb..5d8fa35 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -3,7 +3,6 @@ using Dictionaries: Dictionary, set!, delete! 
using Graphs: AbstractGraph, is_tree, connected_components using NamedGraphs: convert_vertextype using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, is_path_graph -using ITensorBase: ITensor, dim using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, quotient_graph struct BeliefPropagationCache{V, N <: AbstractGraph{V}, ET, MT} <: diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 61c97df..49d0ef8 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -52,10 +52,8 @@ function default_message(::Type{<:Algorithm"bp"}, network, edge) #TODO: Get datatype working on tensornetworks so we can support GPU, etc... links = linkinds(network, edge) - data = ones(dim.(links)...) - - t = ITensor(data, links) - return t + data = ones(Tuple(links)) + return data end updated_message(alg, bpc, edge) = not_implemented()
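A note on the type-inference fix in PATCH 13 (my reading of the change, not text from the patch): `map(f, vertices(graph))` over the Dictionaries.jl-style vertex collection builds the values dictionary in a single step, keyed by the vertices themselves, so the value element type comes straight from inference of `f`; constructing `Dictionary(keys, values)` from two separately materialized collections can end up with a wider element type. A small hedged illustration using Dictionaries.jl directly (the vertex tuples here are made up):

using Dictionaries: Indices, Dictionary

verts = Indices([(1, 1), (2, 1), (3, 1)])
tensors = map(v -> randn(2, 2), verts)   # a Dictionary keyed by `verts`
tensors isa Dictionary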
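And on the `default_message` change in PATCH 15: with the ITensorBase dependency removed, the default message is built as `ones(Tuple(links))`, presumably a named all-ones array over the edge's link indices; per the remaining TODO it does not yet pick up the network's storage type, which is what the `adapt_messages`/`adapt_factors` helpers from earlier in the series are for. A uniform message is the usual flat initialization for BP: contracting a factor with an all-ones message just sums that factor over the corresponding bond. A tiny plain-array sketch of that fact (names illustrative, not from the patch):

T = randn(2, 3)                # a factor with bonds of dimension 2 and 3
m_in = ones(2)                 # uniform message entering on the first bond
m_out = transpose(T) * m_in    # first outgoing update before normalization
m_out ≈ vec(sum(T; dims = 1))  # identical to summing the factor over that bond

For moving the cache to another backend, something like `adapt_factors(CuArray, bpc)` with CUDA loaded would be the natural call, though that specific usage is an assumption on my part and is not exercised in these patches.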