Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/ITensorNetworksNext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@ include("abstract_problem.jl")
include("iterators.jl")
include("adapters.jl")

include("beliefpropagation/abstractbeliefpropagationcache.jl")
include("beliefpropagation/beliefpropagationcache.jl")
include("beliefpropagation/beliefpropagationproblem.jl")

end
104 changes: 53 additions & 51 deletions src/abstracttensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,23 @@ using LinearAlgebra: LinearAlgebra, factorize
using MacroTools: @capture
using NamedDimsArrays: dimnames, inds
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree
using NamedGraphs.GraphsExtensions: ⊔, directed_graph, incident_edges, rem_edges!,
rename_vertices, vertextype
using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger
using NamedGraphs.GraphsExtensions:
⊔,
directed_graph,
incident_edges,
rem_edges!,
rename_vertices,
vertextype
using SplitApplyCombine: flatten
using NamedGraphs.SimilarType: similar_type

abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end

function Graphs.rem_edge!(tn::AbstractTensorNetwork, e)
rem_edge!(underlying_graph(tn), e)
return tn
end
# Need to be careful about removing edges from tensor networks in case there is a bond
Graphs.rem_edge!(::AbstractTensorNetwork, edge) = not_implemented()

# TODO: Define a generic fallback for `AbstractDataGraph`?
DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge data")
DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = not_implemented()

# Graphs.jl overloads
function Graphs.weights(graph::AbstractTensorNetwork)
Expand All @@ -36,7 +40,7 @@ function Graphs.weights(graph::AbstractTensorNetwork)
end

# Copy
Base.copy(tn::AbstractTensorNetwork) = error("Not implemented")
Base.copy(::AbstractTensorNetwork) = not_implemented()

# Iteration
Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...)
Expand All @@ -49,20 +53,11 @@ Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn))
# Overload if needed
Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false

# Derived interface, may need to be overloaded
function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork})
return underlying_graph_type(data_graph_type(G))
end

# AbstractDataGraphs overloads
function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...)
return error("Not implemented")
end
function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...)
return error("Not implemented")
end
DataGraphs.vertex_data(::AbstractTensorNetwork) = not_implemented()
DataGraphs.edge_data(::AbstractTensorNetwork) = not_implemented()

DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented")
DataGraphs.underlying_graph(::AbstractTensorNetwork) = not_implemented()
function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork)
return NamedGraphs.vertex_positions(underlying_graph(tn))
end
Expand All @@ -81,49 +76,46 @@ function Adapt.adapt_structure(to, tn::AbstractTensorNetwork)
return map_vertex_data_preserve_graph(adapt(to), tn)
end

function linkinds(tn::AbstractTensorNetwork, edge::Pair)
return linkinds(tn, edgetype(tn)(edge))
end
function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge)
return inds(tn[src(edge)]) ∩ inds(tn[dst(edge)])
end
function linkaxes(tn::AbstractTensorNetwork, edge::Pair)
linkinds(tn::AbstractGraph, edge::Pair) = linkinds(tn, edgetype(tn)(edge))
linkinds(tn::AbstractGraph, edge::AbstractEdge) = inds(tn[src(edge)]) ∩ inds(tn[dst(edge)])

function linkaxes(tn::AbstractGraph, edge::Pair)
return linkaxes(tn, edgetype(tn)(edge))
end
function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge)
function linkaxes(tn::AbstractGraph, edge::AbstractEdge)
return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)])
end
function linknames(tn::AbstractTensorNetwork, edge::Pair)
function linknames(tn::AbstractGraph, edge::Pair)
return linknames(tn, edgetype(tn)(edge))
end
function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge)
function linknames(tn::AbstractGraph, edge::AbstractEdge)
return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)])
end

function siteinds(tn::AbstractTensorNetwork, v)
function siteinds(tn::AbstractGraph, v)
s = inds(tn[v])
for v′ in neighbors(tn, v)
s = setdiff(s, inds(tn[v′]))
end
return s
end
function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge)
function siteaxes(tn::AbstractGraph, edge::AbstractEdge)
s = axes(tn[src(edge)]) ∩ axes(tn[dst(edge)])
for v′ in neighbors(tn, v)
s = setdiff(s, axes(tn[v′]))
end
return s
end
function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge)
function sitenames(tn::AbstractGraph, edge::AbstractEdge)
s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)])
for v′ in neighbors(tn, v)
s = setdiff(s, dimnames(tn[v′]))
end
return s
end

function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex)
vertex_data(tn)[vertex] = value
function setindex_preserve_graph!(tn::AbstractGraph, value, vertex)
set!(vertex_data(tn), vertex, value)
return tn
end

Expand Down Expand Up @@ -153,15 +145,15 @@ end

# Update the graph of the TensorNetwork `tn` to include
# edges that should exist based on the tensor connectivity.
function add_missing_edges!(tn::AbstractTensorNetwork)
function add_missing_edges!(tn::AbstractGraph)
foreach(v -> add_missing_edges!(tn, v), vertices(tn))
return tn
end

# Update the graph of the TensorNetwork `tn` to include
# edges that should be incident to the vertex `v`
# based on the tensor connectivity.
function add_missing_edges!(tn::AbstractTensorNetwork, v)
function add_missing_edges!(tn::AbstractGraph, v)
for v′ in vertices(tn)
if v ≠ v′
e = v => v′
Expand All @@ -175,13 +167,13 @@ end

# Fix the edges of the TensorNetwork `tn` to match
# the tensor connectivity.
function fix_edges!(tn::AbstractTensorNetwork)
function fix_edges!(tn::AbstractGraph)
foreach(v -> fix_edges!(tn, v), vertices(tn))
return tn
end
# Fix the edges of the TensorNetwork `tn` to match
# the tensor connectivity at vertex `v`.
function fix_edges!(tn::AbstractTensorNetwork, v)
function fix_edges!(tn::AbstractGraph, v)
rem_edges!(tn, incident_edges(tn, v))
add_missing_edges!(tn, v)
return tn
Expand Down Expand Up @@ -215,28 +207,20 @@ function Base.setindex!(tn::AbstractTensorNetwork, value, v)
fix_edges!(tn, v)
return tn
end
using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger
# Fix ambiguity error.
function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger)
graph[vertices(graph)[vertex]] = value
return graph
end
# Fix ambiguity error.
function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge)
return error("No edge data.")
end
# Fix ambiguity error.
function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair)
return error("No edge data.")
end
using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger
Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) = not_implemented()
Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) = not_implemented()
# Fix ambiguity error.
function Base.setindex!(
tn::AbstractTensorNetwork,
value,
edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger},
)
return error("No edge data.")
return not_implemented()
end

function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork)
Expand All @@ -255,3 +239,21 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork)
end

Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph)

function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices::AbstractVector{V}) where {V <: Int}
return tensornetwork_induced_subgraph(graph, subvertices)
end
function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices)
return tensornetwork_induced_subgraph(graph, subvertices)
end

function tensornetwork_induced_subgraph(graph, subvertices)
underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices)
subgraph = similar_type(graph)(underlying_subgraph)
for v in vertices(subgraph)
if isassigned(graph, v)
set!(vertex_data(subgraph), v, graph[v])
end
end
return subgraph, vlist
end
133 changes: 133 additions & 0 deletions src/beliefpropagation/abstractbeliefpropagationcache.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
using Graphs: AbstractGraph, AbstractEdge
using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype
using NamedGraphs.GraphsExtensions: boundary_edges
using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent

messages(::AbstractGraph) = not_implemented()
messages(bp_cache::AbstractDataGraph) = edge_data(bp_cache)
messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges]

message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge]

deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented()
function deletemessage!(bp_cache::AbstractDataGraph, edge)
ms = messages(bp_cache)
delete!(ms, edge)
return bp_cache
end

function deletemessages!(bp_cache::AbstractGraph, edges = edges(bp_cache))
for e in edges
deletemessage!(bp_cache, e)
end
return bp_cache
end

setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented()
function setmessage!(bp_cache::AbstractDataGraph, edge, message)
ms = messages(bp_cache)
set!(ms, edge, message)
return bp_cache
end
function setmessage!(bp_cache::QuotientView, edge, message)
setmessages!(parent(bp_cache), QuotientEdge(edge), message)
return bp_cache
end

function setmessages!(bp_cache::AbstractGraph, edge::QuotientEdge, message)
for e in edges(bp_cache, edge)
setmessage!(parent(bp_cache), e, message[e])
end
return bp_cache
end
function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges)
for e in edges
setmessage!(bpc_dst, e, message(bpc_src, e))
end
return bpc_dst
end

factors(bpc::AbstractGraph) = vertex_data(bpc)
factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices]
factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex])

factor(bpc::AbstractGraph, vertex) = factors(bpc)[vertex]

setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented()
function setfactor!(bpc::AbstractDataGraph, vertex, factor)
fs = factors(bpc)
set!(fs, vertex, factor)
return bpc
end

function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge)
return message(bp_cache, edge) * message(bp_cache, reverse(edge))
end

function region_scalar(bp_cache::AbstractGraph, vertex)

messages = incoming_messages(bp_cache, vertex)
state = factors(bp_cache, vertex)

return reduce(*, messages) * reduce(*, state)
end

message_type(bpc::AbstractGraph) = message_type(typeof(bpc))
message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G))
message_type(type::Type{<:AbstractDataGraph}) = edge_data_eltype(type)

function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache))
return map(v -> region_scalar(bp_cache, v), vertices)
end

function edge_scalars(bp_cache::AbstractGraph, edges = edges(bp_cache))
return map(e -> region_scalar(bp_cache, e), edges)
end

function scalar_factors_quotient(bp_cache::AbstractGraph)
return vertex_scalars(bp_cache), edge_scalars(bp_cache)
end

function incoming_messages(bp_cache::AbstractGraph, vertices; ignore_edges = [])
b_edges = boundary_edges(bp_cache, [vertices;]; dir = :in)
b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges
return messages(bp_cache, b_edges)
end

default_messages(::AbstractGraph) = not_implemented()

#Adapt interface for changing device
map_messages(f, bp_cache, es = edges(bp_cache)) = map_messages!(f, copy(bp_cache), es)
function map_messages!(f, bp_cache, es = edges(bp_cache))
for e in es
setmessage!(bp_cache, e, f(message(bp_cache, e)))
end
return bp_cache
end

map_factors(f, bp_cache, vs = vertices(bp_cache)) = map_factors!(f, copy(bp_cache), vs)
function map_factors!(f, bp_cache, vs = vertices(bp_cache))
for v in vs
setfactor!(bp_cache, v, f(factor(bp_cache, v)))
end
return bp_cache
end

adapt_messages(to, bp_cache, es = edges(bp_cache)) = map_messages(adapt(to), bp_cache, es)
adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp_cache, vs)

abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end

function free_energy(bp_cache::AbstractBeliefPropagationCache)
numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache)
if any(t -> real(t) < 0, numerator_terms)
numerator_terms = complex.(numerator_terms)
end
if any(t -> real(t) < 0, denominator_terms)
denominator_terms = complex.(denominator_terms)
end

any(iszero, denominator_terms) && return -Inf
return sum(log.(numerator_terms)) - sum(log.((denominator_terms)))
end
partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache))
Loading
Loading