Skip to content

Commit 4402c50

Browse files
committed
BP Code
1 parent eeac23b commit 4402c50

File tree

3 files changed

+407
-1
lines changed

3 files changed

+407
-1
lines changed
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
abstract type AbstractBeliefPropagationCache{V} <: AbstractGraph{V} end
2+
3+
#Interface
4+
factor(bp_cache::AbstractBeliefPropagationCache, vertex) = not_implemented()
5+
setfactor!(bp_cache::AbstractBeliefPropagationCache, vertex, factor) = not_implemented()
6+
messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented()
7+
message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) = not_implemented()
8+
function default_message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge)
9+
return not_implemented()
10+
end
11+
default_messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented()
12+
function setmessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge, message)
13+
return not_implemented()
14+
end
15+
function deletemessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge)
16+
return not_implemented()
17+
end
18+
function rescale_messages(
19+
bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}; kwargs...
20+
)
21+
return not_implemented()
22+
end
23+
function rescale_vertices(
24+
bp_cache::AbstractBeliefPropagationCache, vertices::Vector; kwargs...
25+
)
26+
return not_implemented()
27+
end
28+
29+
function vertex_scalar(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...)
30+
return not_implemented()
31+
end
32+
function edge_scalar(
33+
bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...
34+
)
35+
return not_implemented()
36+
end
37+
38+
#Graph functionality needed
39+
Graphs.vertices(bp_cache::AbstractBeliefPropagationCache) = not_implemented()
40+
Graphs.edges(bp_cache::AbstractBeliefPropagationCache) = not_implemented()
41+
function NamedGraphs.GraphsExtensions.boundary_edges(
42+
bp_cache::AbstractBeliefPropagationCache, vertices; kwargs...
43+
)
44+
return not_implemented()
45+
end
46+
47+
#Functions derived from the interface
48+
function setmessages!(bp_cache::AbstractBeliefPropagationCache, edges, messages)
49+
for (e, m) in zip(edges)
50+
setmessage!(bp_cache, e, m)
51+
end
52+
return
53+
end
54+
55+
function deletemessages!(
56+
bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge} = edges(bp_cache)
57+
)
58+
for e in edges
59+
deletemessage!(bp_cache, e)
60+
end
61+
return bp_cache
62+
end
63+
64+
function vertex_scalars(
65+
bp_cache::AbstractBeliefPropagationCache, vertices = Graphs.vertices(bp_cache); kwargs...
66+
)
67+
return map(v -> region_scalar(bp_cache, v; kwargs...), vertices)
68+
end
69+
70+
function edge_scalars(
71+
bp_cache::AbstractBeliefPropagationCache, edges = Graphs.edges(bp_cache); kwargs...
72+
)
73+
return map(e -> region_scalar(bp_cache, e; kwargs...), edges)
74+
end
75+
76+
function scalar_factors_quotient(bp_cache::AbstractBeliefPropagationCache)
77+
return vertex_scalars(bp_cache), edge_scalars(bp_cache)
78+
end
79+
80+
function incoming_messages(
81+
bp_cache::AbstractBeliefPropagationCache, vertices::Vector{<:Any}; ignore_edges = []
82+
)
83+
b_edges = NamedGraphs.GraphsExtensions.boundary_edges(bp_cache, vertices; dir = :in)
84+
b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges
85+
return messages(bp_cache, b_edges)
86+
end
87+
88+
function incoming_messages(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...)
89+
return incoming_messages(bp_cache, [vertex]; kwargs...)
90+
end
91+
92+
#Adapt interface for changing device
93+
function map_messages(f, bp_cache::AbstractBeliefPropagationCache, es = edges(bp_cache))
94+
bp_cache = copy(bp_cache)
95+
for e in es
96+
setmessage!(bp_cache, e, f(message(bp_cache, e)))
97+
end
98+
return bp_cache
99+
end
100+
function map_factors(f, bp_cache::AbstractBeliefPropagationCache, vs = vertices(bp_cache))
101+
bp_cache = copy(bp_cache)
102+
for v in vs
103+
setfactor!(bp_cache, v, f(factor(bp_cache, v)))
104+
end
105+
return bp_cache
106+
end
107+
function adapt_messages(to, bp_cache::AbstractBeliefPropagationCache, args...)
108+
return map_messages(adapt(to), bp_cache, args...)
109+
end
110+
function adapt_factors(to, bp_cache::AbstractBeliefPropagationCache, args...)
111+
return map_factors(adapt(to), bp_cache, args...)
112+
end
113+
114+
function freenergy(bp_cache::AbstractBeliefPropagationCache)
115+
numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache)
116+
if any(t -> real(t) < 0, numerator_terms)
117+
numerator_terms = complex.(numerator_terms)
118+
end
119+
if any(t -> real(t) < 0, denominator_terms)
120+
denominator_terms = complex.(denominator_terms)
121+
end
122+
123+
any(iszero, denominator_terms) && return -Inf
124+
return sum(log.(numerator_terms)) - sum(log.((denominator_terms)))
125+
end
126+
127+
function partitionfunction(bp_cache::AbstractBeliefPropagationCache)
128+
return exp(freenergy(bp_cache))
129+
end
130+
131+
function rescale_messages(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge)
132+
return rescale_messages(bp_cache, [edge])
133+
end
134+
135+
function rescale_messages(bp_cache::AbstractBeliefPropagationCache)
136+
return rescale_messages(bp_cache, edges(bp_cache))
137+
end
138+
139+
function rescale_vertices(bpc::AbstractBeliefPropagationCache; kwargs...)
140+
return rescale_vertices(bpc, collect(vertices(bpc)); kwargs...)
141+
end
142+
143+
function rescale_vertex(bpc::AbstractBeliefPropagationCache, vertex; kwargs...)
144+
return rescale_vertices(bpc, [vertex]; kwargs...)
145+
end
146+
147+
function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
148+
bpc = rescale_messages(bpc)
149+
bpc = rescale_partitions(bpc, args...; kwargs...)
150+
return bpc
151+
end
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
using DiagonalArrays: delta
2+
using Dictionaries: Dictionary, set!, delete!
3+
using Graphs: AbstractGraph, is_tree, connected_components
4+
using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges
5+
using ITensorBase: ITensor, dim
6+
using TypeParameterAccessors: unwrap_array_type, unwrap_array, parenttype
7+
8+
struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <:
9+
AbstractBeliefPropagationCache{V}
10+
network::N
11+
messages::Dictionary
12+
end
13+
14+
messages(bp_cache::BeliefPropagationCache) = bp_cache.messages
15+
network(bp_cache::BeliefPropagationCache) = bp_cache.network
16+
default_messages() = Dictionary()
17+
18+
BeliefPropagationCache(network) = BeliefPropagationCache(network, default_messages())
19+
20+
function Base.copy(bp_cache::BeliefPropagationCache)
21+
return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache)))
22+
end
23+
24+
function deletemessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge)
25+
ms = messages(bp_cache)
26+
delete!(ms, e)
27+
return bp_cache
28+
end
29+
30+
function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message)
31+
ms = messages(bp_cache)
32+
set!(ms, e, message)
33+
return bp_cache
34+
end
35+
36+
function message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...)
37+
ms = messages(bp_cache)
38+
return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge)
39+
end
40+
41+
function messages(bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge})
42+
return [message(bp_cache, e) for e in edges]
43+
end
44+
45+
default_bp_maxiter(g::AbstractGraph) = is_tree(g) ? 1 : nothing
46+
#Forward onto the network
47+
for f in [
48+
:(Graphs.vertices),
49+
:(Graphs.edges),
50+
:(Graphs.is_tree),
51+
:(NamedGraphs.GraphsExtensions.boundary_edges),
52+
:(factors),
53+
:(default_bp_maxiter),
54+
:(ITensorNetworksNext.setfactor!),
55+
:(ITensorNetworksNext.linkinds),
56+
:(ITensorNetworksNext.underlying_graph),
57+
]
58+
@eval begin
59+
function $f(bp_cache::BeliefPropagationCache, args...; kwargs...)
60+
return $f(network(bp_cache), args...; kwargs...)
61+
end
62+
end
63+
end
64+
65+
#TODO: Get subgraph working on an ITensorNetwork to overload this directly
66+
function default_bp_edge_sequence(bp_cache::BeliefPropagationCache)
67+
return forest_cover_edge_sequence(underlying_graph(bp_cache))
68+
end
69+
70+
function factors(tn::AbstractTensorNetwork, vertex)
71+
return [tn[vertex]]
72+
end
73+
74+
function region_scalar(bp_cache::BeliefPropagationCache, edge::AbstractEdge)
75+
return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[]
76+
end
77+
78+
function region_scalar(bp_cache::BeliefPropagationCache, vertex)
79+
incoming_ms = incoming_messages(bp_cache, vertex)
80+
state = factors(bp_cache, vertex)
81+
return (reduce(*, incoming_ms) * reduce(*, state))[]
82+
end
83+
84+
function default_message(bp_cache::BeliefPropagationCache, edge::AbstractEdge)
85+
return default_message(network(bp_cache), edge::AbstractEdge)
86+
end
87+
88+
function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge)
89+
t = ITensor(ones(dim.(linkinds(tn, edge))...), linkinds(tn, edge)...)
90+
#TODO: Get datatype working on tensornetworks so we can support GPU, etc...
91+
return t
92+
end
93+
94+
#Algorithmic defaults
95+
default_update_alg(bp_cache::BeliefPropagationCache) = "bp"
96+
default_message_update_alg(bp_cache::BeliefPropagationCache) = "contract"
97+
default_normalize(::Algorithm"contract") = true
98+
default_sequence_alg(::Algorithm"contract") = "optimal"
99+
function set_default_kwargs(alg::Algorithm"contract")
100+
normalize = get(alg, :normalize, default_normalize(alg))
101+
sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg))
102+
return Algorithm("contract"; normalize, sequence_alg)
103+
end
104+
function set_default_kwargs(alg::Algorithm"adapt_update")
105+
_alg = set_default_kwargs(get(alg, :alg, Algorithm("contract")))
106+
return Algorithm("adapt_update"; adapt = alg.adapt, alg = _alg)
107+
end
108+
default_verbose(::Algorithm"bp") = false
109+
default_tol(::Algorithm"bp") = nothing
110+
function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache)
111+
verbose = get(alg, :verbose, default_verbose(alg))
112+
maxiter = get(alg, :maxiter, default_bp_maxiter(bp_cache))
113+
edge_sequence = get(alg, :edge_sequence, default_bp_edge_sequence(bp_cache))
114+
tol = get(alg, :tol, default_tol(alg))
115+
message_update_alg = set_default_kwargs(
116+
get(alg, :message_update_alg, Algorithm(default_message_update_alg(bp_cache)))
117+
)
118+
return Algorithm("bp"; verbose, maxiter, edge_sequence, tol, message_update_alg)
119+
end
120+
121+
#TODO: Update message etc should go here...
122+
function updated_message(
123+
alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge
124+
)
125+
vertex = src(edge)
126+
incoming_ms = incoming_messages(
127+
bp_cache, vertex; ignore_edges = typeof(edge)[reverse(edge)]
128+
)
129+
state = factors(bp_cache, vertex)
130+
#contract_list = ITensor[incoming_ms; state]
131+
#sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg)
132+
#updated_messages = contract(contract_list; sequence)
133+
updated_message =
134+
!isempty(incoming_ms) ? reduce(*, state) * reduce(*, incoming_ms) : reduce(*, state)
135+
if alg.normalize
136+
message_norm = LinearAlgebra.norm(updated_message)
137+
if !iszero(message_norm)
138+
updated_message /= message_norm
139+
end
140+
end
141+
return updated_message
142+
end
143+
144+
function updated_message(
145+
bp_cache::BeliefPropagationCache,
146+
edge::AbstractEdge;
147+
alg = default_message_update_alg(bpc),
148+
kwargs...,
149+
)
150+
return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bp_cache, edge)
151+
end
152+
153+
function update_message!(
154+
message_update_alg::Algorithm, bp_cache::BeliefPropagationCache, edge::AbstractEdge
155+
)
156+
return setmessage!(bp_cache, edge, updated_message(message_update_alg, bp_cache, edge))
157+
end
158+
159+
"""
160+
Do a sequential update of the message tensors on `edges`
161+
"""
162+
function update_iteration(
163+
alg::Algorithm"bp",
164+
bpc::AbstractBeliefPropagationCache,
165+
edges::Vector;
166+
(update_diff!) = nothing,
167+
)
168+
bpc = copy(bpc)
169+
for e in edges
170+
prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing
171+
update_message!(alg.message_update_alg, bpc, e)
172+
if !isnothing(update_diff!)
173+
update_diff![] += message_diff(message(bpc, e), prev_message)
174+
end
175+
end
176+
return bpc
177+
end
178+
179+
"""
180+
Do parallel updates between groups of edges of all message tensors
181+
Currently we send the full message tensor data struct to update for each edge_group. But really we only need the
182+
mts relevant to that group.
183+
"""
184+
function update_iteration(
185+
alg::Algorithm"bp",
186+
bpc::AbstractBeliefPropagationCache,
187+
edge_groups::Vector{<:Vector{<:AbstractEdge}};
188+
(update_diff!) = nothing,
189+
)
190+
new_mts = empty(messages(bpc))
191+
for edges in edge_groups
192+
bpc_t = update_iteration(alg.kwargs.message_update_alg, bpc, edges; (update_diff!))
193+
for e in edges
194+
set!(new_mts, e, message(bpc_t, e))
195+
end
196+
end
197+
return set_messages(bpc, new_mts)
198+
end
199+
200+
"""
201+
More generic interface for update, with default params
202+
"""
203+
function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache)
204+
compute_error = !isnothing(alg.tol)
205+
if isnothing(alg.maxiter)
206+
error("You need to specify a number of iterations for BP!")
207+
end
208+
for i in 1:alg.maxiter
209+
diff = compute_error ? Ref(0.0) : nothing
210+
bpc = update_iteration(alg, bpc, alg.edge_sequence; (update_diff!) = diff)
211+
if compute_error && (diff.x / length(alg.edge_sequence)) <= alg.tol
212+
if alg.verbose
213+
println("BP converged to desired precision after $i iterations.")
214+
end
215+
break
216+
end
217+
end
218+
return bpc
219+
end
220+
221+
function update(bpc::AbstractBeliefPropagationCache; alg = default_update_alg(bpc), kwargs...)
222+
return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc)
223+
end
224+
225+
#Edge sequence stuff
226+
function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex)
227+
forests = forest_cover(g)
228+
edges = edgetype(g)[]
229+
for forest in forests
230+
trees = [forest[vs] for vs in connected_components(forest)]
231+
for tree in trees
232+
tree_edges = post_order_dfs_edges(tree, root_vertex(tree))
233+
push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...)
234+
end
235+
end
236+
return edges
237+
end

0 commit comments

Comments
 (0)