Skip to content

Commit 8bd8c35

Browse files
author
Jack Dunham
committed
BP Caching overhauls
1 parent 8f93ec8 commit 8bd8c35

File tree

3 files changed

+226
-245
lines changed

3 files changed

+226
-245
lines changed
Lines changed: 83 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,117 +1,124 @@
1-
abstract type AbstractBeliefPropagationCache{V} <: AbstractGraph{V} end
1+
using Graphs: AbstractGraph, AbstractEdge
2+
using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype
3+
using NamedGraphs.GraphsExtensions: boundary_edges
4+
using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent
25

3-
#Interface
4-
factor(bp_cache::AbstractBeliefPropagationCache, vertex) = not_implemented()
5-
setfactor!(bp_cache::AbstractBeliefPropagationCache, vertex, factor) = not_implemented()
6-
messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented()
7-
message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) = not_implemented()
8-
function default_message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge)
9-
return not_implemented()
10-
end
11-
default_messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented()
12-
function setmessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge, message)
13-
return not_implemented()
6+
messages(::AbstractGraph) = not_implemented()
7+
messages(bp_cache::AbstractDataGraph) = edge_data(bp_cache)
8+
messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges]
9+
10+
message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge]
11+
12+
deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented()
13+
function deletemessage!(bp_cache::AbstractDataGraph, edge)
14+
ms = messages(bp_cache)
15+
delete!(ms, edge)
16+
return bp_cache
1417
end
15-
function deletemessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge)
16-
return not_implemented()
18+
19+
function deletemessages!(bp_cache::AbstractGraph, edges = edges(bp_cache))
20+
for e in edges
21+
deletemessage!(bp_cache, e)
22+
end
23+
return bp_cache
1724
end
18-
function rescale_messages(
19-
bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}; kwargs...
20-
)
21-
return not_implemented()
25+
26+
setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented()
27+
function setmessage!(bp_cache::AbstractDataGraph, edge, message)
28+
ms = messages(bp_cache)
29+
set!(ms, edge, message)
30+
return bp_cache
2231
end
23-
function rescale_vertices(
24-
bp_cache::AbstractBeliefPropagationCache, vertices::Vector; kwargs...
25-
)
26-
return not_implemented()
32+
function setmessage!(bp_cache::QuotientView, edge, message)
33+
setmessages!(parent(bp_cache), QuotientEdge(edge), message)
34+
return bp_cache
2735
end
2836

29-
function vertex_scalar(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...)
30-
return not_implemented()
37+
function setmessages!(bp_cache::AbstractGraph, edge::QuotientEdge, message)
38+
for e in edges(bp_cache, edge)
39+
setmessage!(parent(bp_cache), e, message[e])
40+
end
41+
return bp_cache
3142
end
32-
function edge_scalar(
33-
bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...
34-
)
35-
return not_implemented()
43+
function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges)
44+
for e in edges
45+
setmessage!(bpc_dst, e, message(bpc_src, e))
46+
end
47+
return bpc_dst
3648
end
3749

38-
#Graph functionality needed
39-
Graphs.vertices(bp_cache::AbstractBeliefPropagationCache) = not_implemented()
40-
Graphs.edges(bp_cache::AbstractBeliefPropagationCache) = not_implemented()
41-
function NamedGraphs.GraphsExtensions.boundary_edges(
42-
bp_cache::AbstractBeliefPropagationCache, vertices; kwargs...
43-
)
44-
return not_implemented()
50+
factors(bpc::AbstractGraph) = vertex_data(bpc)
51+
factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices]
52+
factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex])
53+
54+
factor(bpc::AbstractGraph, vertex) = factors(bpc)[vertex]
55+
56+
setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented()
57+
function setfactor!(bpc::AbstractDataGraph, vertex, factor)
58+
fs = factors(bpc)
59+
set!(fs, vertex, factor)
60+
return bpc
4561
end
4662

47-
#Functions derived from the interface
48-
function setmessages!(bp_cache::AbstractBeliefPropagationCache, edges, messages)
49-
for (e, m) in zip(edges)
50-
setmessage!(bp_cache, e, m)
51-
end
52-
return
63+
function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge)
64+
return message(bp_cache, edge) * message(bp_cache, reverse(edge))
5365
end
5466

55-
function deletemessages!(
56-
bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge} = edges(bp_cache)
57-
)
58-
for e in edges
59-
deletemessage!(bp_cache, e)
60-
end
61-
return bp_cache
67+
function region_scalar(bp_cache::AbstractGraph, vertex)
68+
69+
messages = incoming_messages(bp_cache, vertex)
70+
state = factors(bp_cache, vertex)
71+
72+
return reduce(*, messages) * reduce(*, state)
6273
end
6374

64-
function vertex_scalars(
65-
bp_cache::AbstractBeliefPropagationCache, vertices = Graphs.vertices(bp_cache); kwargs...
66-
)
67-
return map(v -> region_scalar(bp_cache, v; kwargs...), vertices)
75+
message_type(bpc::AbstractGraph) = message_type(typeof(bpc))
76+
message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G))
77+
message_type(type::Type{<:AbstractDataGraph}) = edge_data_eltype(type)
78+
79+
function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache))
80+
return map(v -> region_scalar(bp_cache, v), vertices)
6881
end
6982

70-
function edge_scalars(
71-
bp_cache::AbstractBeliefPropagationCache, edges = Graphs.edges(bp_cache); kwargs...
72-
)
73-
return map(e -> region_scalar(bp_cache, e; kwargs...), edges)
83+
function edge_scalars(bp_cache::AbstractGraph, edges = edges(bp_cache))
84+
return map(e -> region_scalar(bp_cache, e), edges)
7485
end
7586

76-
function scalar_factors_quotient(bp_cache::AbstractBeliefPropagationCache)
87+
function scalar_factors_quotient(bp_cache::AbstractGraph)
7788
return vertex_scalars(bp_cache), edge_scalars(bp_cache)
7889
end
7990

80-
function incoming_messages(
81-
bp_cache::AbstractBeliefPropagationCache, vertices::Vector{<:Any}; ignore_edges = []
82-
)
83-
b_edges = NamedGraphs.GraphsExtensions.boundary_edges(bp_cache, vertices; dir = :in)
91+
function incoming_messages(bp_cache::AbstractGraph, vertices; ignore_edges = [])
92+
b_edges = boundary_edges(bp_cache, [vertices;]; dir = :in)
8493
b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges
8594
return messages(bp_cache, b_edges)
8695
end
8796

88-
function incoming_messages(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...)
89-
return incoming_messages(bp_cache, [vertex]; kwargs...)
90-
end
97+
default_messages(::AbstractGraph) = not_implemented()
9198

9299
#Adapt interface for changing device
93-
function map_messages(f, bp_cache::AbstractBeliefPropagationCache, es = edges(bp_cache))
94-
bp_cache = copy(bp_cache)
100+
map_messages(f, bp_cache, es = edges(bp_cache)) = map_messages!(f, copy(bp_cache), es)
101+
function map_messages!(f, bp_cache, es = edges(bp_cache))
95102
for e in es
96103
setmessage!(bp_cache, e, f(message(bp_cache, e)))
97104
end
98105
return bp_cache
99106
end
100-
function map_factors(f, bp_cache::AbstractBeliefPropagationCache, vs = vertices(bp_cache))
101-
bp_cache = copy(bp_cache)
107+
108+
map_factors(f, bp_cache, vs = vertices(bp_cache)) = map_factors!(f, copy(bp_cache), vs)
109+
function map_factors!(f, bp_cache, vs = vertices(bp_cache))
102110
for v in vs
103111
setfactor!(bp_cache, v, f(factor(bp_cache, v)))
104112
end
105113
return bp_cache
106114
end
107-
function adapt_messages(to, bp_cache::AbstractBeliefPropagationCache, args...)
108-
return map_messages(adapt(to), bp_cache, args...)
109-
end
110-
function adapt_factors(to, bp_cache::AbstractBeliefPropagationCache, args...)
111-
return map_factors(adapt(to), bp_cache, args...)
112-
end
113115

114-
function freenergy(bp_cache::AbstractBeliefPropagationCache)
116+
adapt_messages(to, bp_cache, es = edges(bp_cache)) = map_messages(adapt(to), bp_cache, es)
117+
adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp_cache, vs)
118+
119+
abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end
120+
121+
function free_energy(bp_cache::AbstractBeliefPropagationCache)
115122
numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache)
116123
if any(t -> real(t) < 0, numerator_terms)
117124
numerator_terms = complex.(numerator_terms)
@@ -123,29 +130,4 @@ function freenergy(bp_cache::AbstractBeliefPropagationCache)
123130
any(iszero, denominator_terms) && return -Inf
124131
return sum(log.(numerator_terms)) - sum(log.((denominator_terms)))
125132
end
126-
127-
function partitionfunction(bp_cache::AbstractBeliefPropagationCache)
128-
return exp(freenergy(bp_cache))
129-
end
130-
131-
function rescale_messages(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge)
132-
return rescale_messages(bp_cache, [edge])
133-
end
134-
135-
function rescale_messages(bp_cache::AbstractBeliefPropagationCache)
136-
return rescale_messages(bp_cache, edges(bp_cache))
137-
end
138-
139-
function rescale_vertices(bpc::AbstractBeliefPropagationCache; kwargs...)
140-
return rescale_vertices(bpc, collect(vertices(bpc)); kwargs...)
141-
end
142-
143-
function rescale_vertex(bpc::AbstractBeliefPropagationCache, vertex; kwargs...)
144-
return rescale_vertices(bpc, [vertex]; kwargs...)
145-
end
146-
147-
function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
148-
bpc = rescale_messages(bpc)
149-
bpc = rescale_partitions(bpc, args...; kwargs...)
150-
return bpc
151-
end
133+
partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache))

0 commit comments

Comments
 (0)