Skip to content

Commit 013b8b4

Browse files
author
Jack Dunham
committed
Express BP in terms of SweepIterator interface
Introduce `BeliefPropagationProblem` wrapper to hold the cache and the error `diff` field. Also simplifies some kwargs wrangling.
1 parent 0bbf584 commit 013b8b4

File tree

4 files changed

+101
-113
lines changed

4 files changed

+101
-113
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
1010
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
1111
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
1212
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
13+
ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
1314
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1415
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1516
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
@@ -33,6 +34,7 @@ BackendSelection = "0.1.6"
3334
DataGraphs = "0.2.7"
3435
Dictionaries = "0.4.5"
3536
Graphs = "1.13.1"
37+
ITensorBase = "0.2.14"
3638
LinearAlgebra = "1.10"
3739
MacroTools = "0.5.16"
3840
NamedDimsArrays = "0.8"

src/ITensorNetworksNext.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ include("iterators.jl")
99

1010
include("beliefpropagation/abstractbeliefpropagationcache.jl")
1111
include("beliefpropagation/beliefpropagationcache.jl")
12+
include("beliefpropagation/beliefpropagationproblem.jl")
1213

1314
end
Lines changed: 13 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
using DiagonalArrays: delta
21
using Dictionaries: Dictionary, set!, delete!
32
using Graphs: AbstractGraph, is_tree, connected_components
43
using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges
54
using ITensorBase: ITensor, dim
6-
using TypeParameterAccessors: unwrap_array_type, unwrap_array, parenttype
75

86
struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <:
97
AbstractBeliefPropagationCache{V}
@@ -13,9 +11,8 @@ end
1311

1412
messages(bp_cache::BeliefPropagationCache) = bp_cache.messages
1513
network(bp_cache::BeliefPropagationCache) = bp_cache.network
16-
default_messages() = Dictionary()
1714

18-
BeliefPropagationCache(network) = BeliefPropagationCache(network, default_messages())
15+
BeliefPropagationCache(network) = BeliefPropagationCache(network, Dictionary())
1916

2017
function Base.copy(bp_cache::BeliefPropagationCache)
2118
return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache)))
@@ -33,16 +30,15 @@ function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message)
3330
return bp_cache
3431
end
3532

36-
function message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...)
33+
function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...)
3734
ms = messages(bp_cache)
3835
return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge)
3936
end
4037

41-
function messages(bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge})
38+
function messages(bp_cache::BeliefPropagationCache, edges::Vector{<:AbstractEdge})
4239
return [message(bp_cache, e) for e in edges]
4340
end
4441

45-
default_bp_maxiter(g::AbstractGraph) = is_tree(g) ? 1 : nothing
4642
#Forward onto the network
4743
for f in [
4844
:(Graphs.vertices),
@@ -62,11 +58,6 @@ for f in [
6258
end
6359
end
6460

65-
#TODO: Get subgraph working on an ITensorNetwork to overload this directly
66-
function default_bp_edge_sequence(bp_cache::BeliefPropagationCache)
67-
return forest_cover_edge_sequence(underlying_graph(bp_cache))
68-
end
69-
7061
function factors(tn::AbstractTensorNetwork, vertex)
7162
return [tn[vertex]]
7263
end
@@ -91,33 +82,6 @@ function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge)
9182
return t
9283
end
9384

94-
#Algorithmic defaults
95-
default_update_alg(bp_cache::BeliefPropagationCache) = "bp"
96-
default_message_update_alg(bp_cache::BeliefPropagationCache) = "contract"
97-
default_normalize(::Algorithm"contract") = true
98-
default_sequence_alg(::Algorithm"contract") = "optimal"
99-
function set_default_kwargs(alg::Algorithm"contract")
100-
normalize = get(alg, :normalize, default_normalize(alg))
101-
sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg))
102-
return Algorithm("contract"; normalize, sequence_alg)
103-
end
104-
function set_default_kwargs(alg::Algorithm"adapt_update")
105-
_alg = set_default_kwargs(get(alg, :alg, Algorithm("contract")))
106-
return Algorithm("adapt_update"; adapt = alg.adapt, alg = _alg)
107-
end
108-
default_verbose(::Algorithm"bp") = false
109-
default_tol(::Algorithm"bp") = nothing
110-
function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache)
111-
verbose = get(alg, :verbose, default_verbose(alg))
112-
maxiter = get(alg, :maxiter, default_bp_maxiter(bp_cache))
113-
edge_sequence = get(alg, :edge_sequence, default_bp_edge_sequence(bp_cache))
114-
tol = get(alg, :tol, default_tol(alg))
115-
message_update_alg = set_default_kwargs(
116-
get(alg, :message_update_alg, Algorithm(default_message_update_alg(bp_cache)))
117-
)
118-
return Algorithm("bp"; verbose, maxiter, edge_sequence, tol, message_update_alg)
119-
end
120-
12185
#TODO: Update message etc should go here...
12286
function updated_message(
12387
alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge
@@ -141,85 +105,21 @@ function updated_message(
141105
return updated_message
142106
end
143107

144-
function updated_message(
145-
bp_cache::BeliefPropagationCache,
146-
edge::AbstractEdge;
147-
alg = default_message_update_alg(bpc),
148-
kwargs...,
108+
function default_algorithm(
109+
::Type{<:Algorithm"contract"}; normalize = true, sequence_alg = "optimal"
149110
)
150-
return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bp_cache, edge)
111+
return Algorithm("contract"; normalize, sequence_alg)
151112
end
152-
153-
function update_message!(
154-
message_update_alg::Algorithm, bp_cache::BeliefPropagationCache, edge::AbstractEdge
113+
function default_algorithm(
114+
::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract")
155115
)
156-
return setmessage!(bp_cache, edge, updated_message(message_update_alg, bp_cache, edge))
116+
return Algorithm("adapt_update"; adapt, alg)
157117
end
158118

159-
"""
160-
Do a sequential update of the message tensors on `edges`
161-
"""
162-
function update_iteration(
163-
alg::Algorithm"bp",
164-
bpc::AbstractBeliefPropagationCache,
165-
edges::Vector;
166-
(update_diff!) = nothing,
167-
)
168-
bpc = copy(bpc)
169-
for e in edges
170-
prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing
171-
update_message!(alg.message_update_alg, bpc, e)
172-
if !isnothing(update_diff!)
173-
update_diff![] += message_diff(message(bpc, e), prev_message)
174-
end
175-
end
176-
return bpc
177-
end
178-
179-
"""
180-
Do parallel updates between groups of edges of all message tensors
181-
Currently we send the full message tensor data struct to update for each edge_group. But really we only need the
182-
mts relevant to that group.
183-
"""
184-
function update_iteration(
185-
alg::Algorithm"bp",
186-
bpc::AbstractBeliefPropagationCache,
187-
edge_groups::Vector{<:Vector{<:AbstractEdge}};
188-
(update_diff!) = nothing,
119+
function update_message!(
120+
message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge
189121
)
190-
new_mts = empty(messages(bpc))
191-
for edges in edge_groups
192-
bpc_t = update_iteration(alg.kwargs.message_update_alg, bpc, edges; (update_diff!))
193-
for e in edges
194-
set!(new_mts, e, message(bpc_t, e))
195-
end
196-
end
197-
return set_messages(bpc, new_mts)
198-
end
199-
200-
"""
201-
More generic interface for update, with default params
202-
"""
203-
function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache)
204-
compute_error = !isnothing(alg.tol)
205-
if isnothing(alg.maxiter)
206-
error("You need to specify a number of iterations for BP!")
207-
end
208-
for i in 1:alg.maxiter
209-
diff = compute_error ? Ref(0.0) : nothing
210-
bpc = update_iteration(alg, bpc, alg.edge_sequence; (update_diff!) = diff)
211-
if compute_error && (diff.x / length(alg.edge_sequence)) <= alg.tol
212-
if alg.verbose
213-
println("BP converged to desired precision after $i iterations.")
214-
end
215-
break
216-
end
217-
end
218-
return bpc
219-
end
220-
221-
function update(bpc::AbstractBeliefPropagationCache; alg = default_update_alg(bpc), kwargs...)
222-
return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc)
122+
return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge))
223123
end
224124

225125
#Edge sequence stuff
@@ -234,4 +134,4 @@ function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root
234134
end
235135
end
236136
return edges
237-
end
137+
end
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
mutable struct BeliefPropagationProblem{V, Cache <: AbstractBeliefPropagationCache{V}} <:
2+
AbstractProblem
3+
const cache::Cache
4+
diff::Union{Nothing, Float64}
5+
end
6+
7+
function default_algorithm(
8+
::Type{<:Algorithm"bp"},
9+
bpc::BeliefPropagationCache;
10+
verbose = false,
11+
tol = nothing,
12+
edge_sequence = forest_cover_edge_sequence(underlying_graph(bpc)),
13+
message_update_alg = default_algorithm(Algorithm"contract"),
14+
maxiter = is_tree(bpc) ? 1 : nothing,
15+
)
16+
return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter)
17+
end
18+
19+
function compute!(iter::RegionIterator{<:BeliefPropagationProblem})
20+
prob = iter.problem
21+
22+
edge_group, kwargs = current_region_plan(iter)
23+
24+
new_message_tensors = map(edge_group) do edge
25+
old_message = message(prob.cache, edge)
26+
27+
new_message = updated_message(kwargs.message_update_alg, prob.cache, edge)
28+
29+
if !isnothing(prob.diff)
30+
# TODO: Define `message_diff`
31+
prob.diff += message_diff(new_message, old_message)
32+
end
33+
34+
return new_message
35+
end
36+
37+
foreach(edge_group, new_message_tensors) do edge, new_message
38+
setmessage!(prob.cache, edge, new_message)
39+
end
40+
41+
return iter
42+
end
43+
44+
function region_plan(
45+
prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs...
46+
)
47+
edges = forest_cover_edge_sequence(underlying_graph(prob.cache); root_vertex)
48+
49+
plan = map(edges) do e
50+
return [e] => (; sweep_kwargs...)
51+
end
52+
53+
return plan
54+
end
55+
56+
function update(bpc::AbstractBeliefPropagationCache; kwargs...)
57+
return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc)
58+
end
59+
function update(alg::Algorithm"bp", bpc)
60+
compute_error = !isnothing(alg.tol)
61+
62+
diff = compute_error ? 0.0 : nothing
63+
64+
prob = BeliefPropagationProblem(bpc, diff)
65+
66+
iter = SweepIterator(prob, alg.maxiter; compute_error, getfield(alg, :kwargs)...)
67+
68+
for _ in iter
69+
if compute_error && prob.diff <= alg.tol
70+
break
71+
end
72+
end
73+
74+
if alg.verbose && compute_error
75+
if prob.diff <= alg.tol
76+
println("BP converged to desired precision after $(iter.which_sweep) iterations.")
77+
else
78+
println(
79+
"BP failed to converge to precision $(alg.tol), got $(prob.diff) after $(iter.which_sweep) iterations",
80+
)
81+
end
82+
end
83+
84+
return bpc
85+
end

0 commit comments

Comments
 (0)