1- abstract type AbstractBeliefPropagationCache{V} <: AbstractGraph{V} end
1+ using Graphs: AbstractGraph, AbstractEdge
2+ using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype
3+ using NamedGraphs. GraphsExtensions: boundary_edges
4+ using NamedGraphs. PartitionedGraphs: QuotientView, QuotientEdge, parent
25
3- # Interface
4- factor (bp_cache:: AbstractBeliefPropagationCache , vertex ) = not_implemented ( )
5- setfactor! (bp_cache:: AbstractBeliefPropagationCache , vertex, factor ) = not_implemented ()
6- messages (bp_cache :: AbstractBeliefPropagationCache ) = not_implemented ()
7- message (bp_cache:: AbstractBeliefPropagationCache , edge:: AbstractEdge ) = not_implemented ()
8- function default_message (bp_cache :: AbstractBeliefPropagationCache , edge :: AbstractEdge )
9- return not_implemented ()
10- end
11- default_messages (bp_cache :: AbstractBeliefPropagationCache ) = not_implemented ( )
12- function setmessage! (bp_cache :: AbstractBeliefPropagationCache , edge:: AbstractEdge , message )
13- return not_implemented ()
6+ messages ( :: AbstractGraph ) = not_implemented ()
7+ messages (bp_cache:: AbstractDataGraph ) = edge_data (bp_cache )
8+ messages (bp_cache:: AbstractGraph , edges ) = [ message (bp_cache, e) for e in edges]
9+
10+ message (bp_cache:: AbstractGraph , edge:: AbstractEdge ) = messages (bp_cache)[edge]
11+
12+ deletemessage! (bp_cache :: AbstractGraph , edge) = not_implemented ()
13+ function deletemessage! (bp_cache :: AbstractDataGraph , edge)
14+ ms = messages (bp_cache )
15+ delete! (ms , edge)
16+ return bp_cache
1417end
15- function deletemessage! (bp_cache:: AbstractBeliefPropagationCache , edge:: AbstractEdge )
16- return not_implemented ()
18+
19+ function deletemessages! (bp_cache:: AbstractGraph , edges = edges (bp_cache))
20+ for e in edges
21+ deletemessage! (bp_cache, e)
22+ end
23+ return bp_cache
1724end
18- function rescale_messages (
19- bp_cache:: AbstractBeliefPropagationCache , edges:: Vector{<:AbstractEdge} ; kwargs...
20- )
21- return not_implemented ()
25+
26+ setmessage! (bp_cache:: AbstractGraph , edge, message) = not_implemented ()
27+ function setmessage! (bp_cache:: AbstractDataGraph , edge, message)
28+ ms = messages (bp_cache)
29+ set! (ms, edge, message)
30+ return bp_cache
2231end
23- function rescale_vertices (
24- bp_cache:: AbstractBeliefPropagationCache , vertices:: Vector ; kwargs...
25- )
26- return not_implemented ()
32+ function setmessage! (bp_cache:: QuotientView , edge, message)
33+ setmessages! (parent (bp_cache), QuotientEdge (edge), message)
34+ return bp_cache
2735end
2836
29- function vertex_scalar (bp_cache:: AbstractBeliefPropagationCache , vertex; kwargs... )
30- return not_implemented ()
37+ function setmessages! (bp_cache:: AbstractGraph , edge:: QuotientEdge , message)
38+ for e in edges (bp_cache, edge)
39+ setmessage! (parent (bp_cache), e, message[e])
40+ end
41+ return bp_cache
3142end
32- function edge_scalar (
33- bp_cache:: AbstractBeliefPropagationCache , edge:: AbstractEdge ; kwargs...
34- )
35- return not_implemented ()
43+ function setmessages! (bpc_dst:: AbstractGraph , bpc_src:: AbstractGraph , edges)
44+ for e in edges
45+ setmessage! (bpc_dst, e, message (bpc_src, e))
46+ end
47+ return bpc_dst
3648end
3749
38- # Graph functionality needed
39- Graphs. vertices (bp_cache:: AbstractBeliefPropagationCache ) = not_implemented ()
40- Graphs. edges (bp_cache:: AbstractBeliefPropagationCache ) = not_implemented ()
41- function NamedGraphs. GraphsExtensions. boundary_edges (
42- bp_cache:: AbstractBeliefPropagationCache , vertices; kwargs...
43- )
44- return not_implemented ()
50+ factors (bpc:: AbstractGraph ) = vertex_data (bpc)
51+ factors (bpc:: AbstractGraph , vertices:: Vector ) = [factor (bpc, v) for v in vertices]
52+ factors (bpc:: AbstractGraph{V} , vertex:: V ) where {V} = factors (bpc, V[vertex])
53+
54+ factor (bpc:: AbstractGraph , vertex) = factors (bpc)[vertex]
55+
56+ setfactor! (bpc:: AbstractGraph , vertex, factor) = not_implemented ()
57+ function setfactor! (bpc:: AbstractDataGraph , vertex, factor)
58+ fs = factors (bpc)
59+ set! (fs, vertex, factor)
60+ return bpc
4561end
4662
47- # Functions derived from the interface
48- function setmessages! (bp_cache:: AbstractBeliefPropagationCache , edges, messages)
49- for (e, m) in zip (edges)
50- setmessage! (bp_cache, e, m)
51- end
52- return
63+ function region_scalar (bp_cache:: AbstractGraph , edge:: AbstractEdge )
64+ return message (bp_cache, edge) * message (bp_cache, reverse (edge))
5365end
5466
55- function deletemessages! (
56- bp_cache:: AbstractBeliefPropagationCache , edges:: Vector{<:AbstractEdge} = edges (bp_cache)
57- )
58- for e in edges
59- deletemessage! (bp_cache, e)
60- end
61- return bp_cache
67+ function region_scalar (bp_cache:: AbstractGraph , vertex)
68+
69+ messages = incoming_messages (bp_cache, vertex)
70+ state = factors (bp_cache, vertex)
71+
72+ return reduce (* , messages) * reduce (* , state)
6273end
6374
64- function vertex_scalars (
65- bp_cache:: AbstractBeliefPropagationCache , vertices = Graphs. vertices (bp_cache); kwargs...
66- )
67- return map (v -> region_scalar (bp_cache, v; kwargs... ), vertices)
75+ message_type (bpc:: AbstractGraph ) = message_type (typeof (bpc))
76+ message_type (G:: Type{<:AbstractGraph} ) = eltype (Base. promote_op (messages, G))
77+ message_type (type:: Type{<:AbstractDataGraph} ) = edge_data_eltype (type)
78+
79+ function vertex_scalars (bp_cache:: AbstractGraph , vertices = vertices (bp_cache))
80+ return map (v -> region_scalar (bp_cache, v), vertices)
6881end
6982
70- function edge_scalars (
71- bp_cache:: AbstractBeliefPropagationCache , edges = Graphs. edges (bp_cache); kwargs...
72- )
73- return map (e -> region_scalar (bp_cache, e; kwargs... ), edges)
83+ function edge_scalars (bp_cache:: AbstractGraph , edges = edges (bp_cache))
84+ return map (e -> region_scalar (bp_cache, e), edges)
7485end
7586
76- function scalar_factors_quotient (bp_cache:: AbstractBeliefPropagationCache )
87+ function scalar_factors_quotient (bp_cache:: AbstractGraph )
7788 return vertex_scalars (bp_cache), edge_scalars (bp_cache)
7889end
7990
80- function incoming_messages (
81- bp_cache:: AbstractBeliefPropagationCache , vertices:: Vector{<:Any} ; ignore_edges = []
82- )
83- b_edges = NamedGraphs. GraphsExtensions. boundary_edges (bp_cache, vertices; dir = :in )
91+ function incoming_messages (bp_cache:: AbstractGraph , vertices; ignore_edges = [])
92+ b_edges = boundary_edges (bp_cache, [vertices;]; dir = :in )
8493 b_edges = ! isempty (ignore_edges) ? setdiff (b_edges, ignore_edges) : b_edges
8594 return messages (bp_cache, b_edges)
8695end
8796
88- function incoming_messages (bp_cache:: AbstractBeliefPropagationCache , vertex; kwargs... )
89- return incoming_messages (bp_cache, [vertex]; kwargs... )
90- end
97+ default_messages (:: AbstractGraph ) = not_implemented ()
9198
9299# Adapt interface for changing device
93- function map_messages (f, bp_cache:: AbstractBeliefPropagationCache , es = edges (bp_cache))
94- bp_cache = copy (bp_cache)
100+ map_messages (f, bp_cache, es = edges (bp_cache)) = map_messages! (f, copy (bp_cache), es )
101+ function map_messages! (f, bp_cache, es = edges (bp_cache) )
95102 for e in es
96103 setmessage! (bp_cache, e, f (message (bp_cache, e)))
97104 end
98105 return bp_cache
99106end
100- function map_factors (f, bp_cache:: AbstractBeliefPropagationCache , vs = vertices (bp_cache))
101- bp_cache = copy (bp_cache)
107+
108+ map_factors (f, bp_cache, vs = vertices (bp_cache)) = map_factors! (f, copy (bp_cache), vs)
109+ function map_factors! (f, bp_cache, vs = vertices (bp_cache))
102110 for v in vs
103111 setfactor! (bp_cache, v, f (factor (bp_cache, v)))
104112 end
105113 return bp_cache
106114end
107- function adapt_messages (to, bp_cache:: AbstractBeliefPropagationCache , args... )
108- return map_messages (adapt (to), bp_cache, args... )
109- end
110- function adapt_factors (to, bp_cache:: AbstractBeliefPropagationCache , args... )
111- return map_factors (adapt (to), bp_cache, args... )
112- end
113115
114- function freenergy (bp_cache:: AbstractBeliefPropagationCache )
116+ adapt_messages (to, bp_cache, es = edges (bp_cache)) = map_messages (adapt (to), bp_cache, es)
117+ adapt_factors (to, bp_cache, vs = vertices (bp_cache)) = map_factors (adapt (to), bp_cache, vs)
118+
119+ abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end
120+
121+ function free_energy (bp_cache:: AbstractBeliefPropagationCache )
115122 numerator_terms, denominator_terms = scalar_factors_quotient (bp_cache)
116123 if any (t -> real (t) < 0 , numerator_terms)
117124 numerator_terms = complex .(numerator_terms)
@@ -123,29 +130,4 @@ function freenergy(bp_cache::AbstractBeliefPropagationCache)
123130 any (iszero, denominator_terms) && return - Inf
124131 return sum (log .(numerator_terms)) - sum (log .((denominator_terms)))
125132end
126-
127- function partitionfunction (bp_cache:: AbstractBeliefPropagationCache )
128- return exp (freenergy (bp_cache))
129- end
130-
131- function rescale_messages (bp_cache:: AbstractBeliefPropagationCache , edge:: AbstractEdge )
132- return rescale_messages (bp_cache, [edge])
133- end
134-
135- function rescale_messages (bp_cache:: AbstractBeliefPropagationCache )
136- return rescale_messages (bp_cache, edges (bp_cache))
137- end
138-
139- function rescale_vertices (bpc:: AbstractBeliefPropagationCache ; kwargs... )
140- return rescale_vertices (bpc, collect (vertices (bpc)); kwargs... )
141- end
142-
143- function rescale_vertex (bpc:: AbstractBeliefPropagationCache , vertex; kwargs... )
144- return rescale_vertices (bpc, [vertex]; kwargs... )
145- end
146-
147- function rescale (bpc:: AbstractBeliefPropagationCache , args... ; kwargs... )
148- bpc = rescale_messages (bpc)
149- bpc = rescale_partitions (bpc, args... ; kwargs... )
150- return bpc
151- end
133+ partitionfunction (bp_cache:: AbstractBeliefPropagationCache ) = exp (free_energy (bp_cache))
0 commit comments