Skip to content

Commit 0cc9a9d

Browse files
committed
feat(): marginals and messages form constraints
1 parent 9d6cd74 commit 0cc9a9d

File tree

2 files changed

+118
-13
lines changed

2 files changed

+118
-13
lines changed

src/backends/reactivemp.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ end
168168

169169
## Factorisations constraints specification language
170170

171-
function write_constraints_specification(::ReactiveMPBackend, factorisation, form)
172-
return :(ReactiveMP.ConstraintsSpecification($factorisation, $form))
171+
function write_constraints_specification(::ReactiveMPBackend, factorisation, marginalsform, messagesform)
172+
return :(ReactiveMP.ConstraintsSpecification($factorisation, $marginalsform, $messagesform))
173173
end
174174

175175
function write_factorisation_constraint(::ReactiveMPBackend, names, entries)
@@ -202,4 +202,8 @@ end
202202

203203
function write_factorisation_functional_index(::ReactiveMPBackend, repr, fn)
204204
return :(ReactiveMP.FunctionalIndex{$(QuoteNode(repr))}($fn))
205+
end
206+
207+
function write_form_constraint_specification(::ReactiveMPBackend, T, args, kwargs)
208+
return :(ReactiveMP.FormConstraintsSpecification($T, $args, $kwargs))
205209
end

src/constraints.jl

Lines changed: 112 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
export @constraints
22

33
"""
4-
write_constraints_specification(backend, factorisation, form)
4+
write_constraints_specification(backend, factorisation, marginalsform, messagesform)
55
"""
66
function write_constraints_specification end
77

@@ -45,16 +45,81 @@ function write_factorisation_splitted_range end
4545
"""
4646
function write_factorisation_functional_index end
4747

48+
"""
49+
write_form_constraint_specification(backend, T, args, kwargs)
50+
"""
51+
function write_form_constraint_specification end
52+
4853
macro constraints(constraints_specification)
4954
return generate_constraints_expression(__get_current_backend(), constraints_specification)
5055
end
5156

52-
struct LHSMeta
57+
## Factorisation constraints
58+
59+
struct FactorisationConstraintLHSMeta
5360
name :: String
5461
hash :: UInt
5562
varname :: Symbol
5663
end
5764

65+
##
66+
67+
## Form constraints
68+
69+
function flatten_functional_form_constraint_specification(expr)
70+
return flatten_functional_form_constraint_specification!(expr, Expr(:(call), :(::)))
71+
end
72+
73+
function flatten_functional_form_constraint_specification!(symbol::Symbol, toplevel::Expr)
74+
push!(toplevel.args, symbol)
75+
return toplevel
76+
end
77+
78+
function flatten_functional_form_constraint_specification!(expr::Expr, toplevel::Expr)
79+
if ishead(expr, :(::)) && ishead(expr.args[1], :(::))
80+
flatten_functional_form_constraint_specification!(expr.args[1], toplevel)
81+
flatten_functional_form_constraint_specification!(expr.args[2], toplevel)
82+
elseif ishead(expr, :(::))
83+
push!(toplevel.args, expr.args[1])
84+
push!(toplevel.args, expr.args[2])
85+
else
86+
push!(toplevel.args, expr)
87+
end
88+
return toplevel
89+
end
90+
91+
function parse_form_constraint(backend, expr)
92+
T, args, kwargs = if expr isa Symbol
93+
expr, :(()), :((;))
94+
else
95+
if @capture(expr, f_(args__; kwargs__))
96+
f, :(($(args...), )), :((; $(kwargs...), ))
97+
elseif @capture(expr, f_(args__))
98+
99+
as = []
100+
ks = []
101+
102+
for arg in args
103+
if ishead(arg, :kw)
104+
push!(ks, arg)
105+
else
106+
push!(as, arg)
107+
end
108+
end
109+
110+
f, :(($(as...), )), :((; $(ks...), ))
111+
elseif @capture(expr, f_())
112+
f, :(()), :((;))
113+
else
114+
error("Unssuported form constraints call specification in the expression `$(expr)`")
115+
end
116+
end
117+
118+
return write_form_constraint_specification(backend, T, args, kwargs)
119+
end
120+
121+
##
122+
58123
function generate_constraints_expression(backend, constraints_specification)
59124

60125
if isblock(constraints_specification)
@@ -69,16 +134,51 @@ function generate_constraints_expression(backend, constraints_specification)
69134
cs_args = cs_args === nothing ? [] : cs_args
70135
cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs
71136

72-
lhs_dict = Dict{UInt, LHSMeta}()
137+
lhs_dict = Dict{UInt, FactorisationConstraintLHSMeta}()
73138

74-
# We iteratively overwrite extend form constraint tuple, but we use different names for it to enable type-stability
75-
form_constraints_symbol = gensym(:form_constraint)
76-
form_constraints_symbol_init = :($form_constraints_symbol = ())
139+
marginals_form_constraints_symbol = gensym(:marginals_form_constraint)
140+
marginals_form_constraints_symbol_init = :($marginals_form_constraints_symbol = (;))
141+
142+
messages_form_constraints_symbol = gensym(:messages_form_constraint)
143+
messages_form_constraints_symbol_init = :($messages_form_constraints_symbol = (;))
77144

78-
# We iteratively overwrite extend factorisation constraint tuple, but we use different names for it to enable type-stability
79145
factorisation_constraints_symbol = gensym(:factorisation_constraint)
80146
factorisation_constraints_symbol_init = :($factorisation_constraints_symbol = ())
81-
147+
148+
# First we modify form constraints related statements
149+
cs_body = prewalk(cs_body) do expression
150+
if ishead(expression, :(::))
151+
return flatten_functional_form_constraint_specification(expression)
152+
end
153+
return expression
154+
end
155+
156+
cs_body = prewalk(cs_body) do expression
157+
if iscall(expression, :(::))
158+
if @capture(expression.args[2], q(formsym_Symbol))
159+
specs = map((e) -> parse_form_constraint(backend, e), view(expression.args, 3:lastindex(expression.args)))
160+
return quote
161+
if haskey($marginals_form_constraints_symbol, $(QuoteNode(formsym)))
162+
error("Marginal form constraint q($(formsym)) has been redefined.")
163+
end
164+
$marginals_form_constraints_symbol = (; $marginals_form_constraints_symbol..., $formsym = ($(specs... ),))
165+
end
166+
elseif @capture(expression.args[2], μ(formsym_Symbol))
167+
specs = map((e) -> parse_form_constraint(backend, e), view(expression.args, 3:lastindex(expression.args)))
168+
return quote
169+
if haskey($messages_form_constraints_symbol, $(QuoteNode(formsym)))
170+
error("Messages form constraint μ($(formsym)) has been redefined.")
171+
end
172+
$messages_form_constraints_symbol = (; $messages_form_constraints_symbol..., $formsym = ($(specs... ),))
173+
end
174+
else
175+
error("Invalid form factorisation constraint. $(expression.args[2]) has to be in the form of q(varname) for marginal form constraint or μ(varname) for messages form constraint.")
176+
end
177+
end
178+
return expression
179+
end
180+
181+
# Second we modify factorisation constraints related statements
82182
# First we record all lhs expression's hash ids and create unique variable names for them
83183
# q(x, y) = q(x)q(y) -> hash(q(x, y))
84184
# We do allow multiple definitions in case of if statements, but we do check later overwrites, which are not allowed
@@ -128,7 +228,7 @@ function generate_constraints_expression(backend, constraints_specification)
128228
else
129229
lhs_name = string("q(", join(names, ", "), ")")
130230
lhs_varname = gensym(lhs_name)
131-
lhs_meta = LHSMeta(lhs_name, lhs_hash, lhs_varname)
231+
lhs_meta = FactorisationConstraintLHSMeta(lhs_name, lhs_hash, lhs_varname)
132232
lhs_dict[lhs_hash] = lhs_meta
133233
end
134234

@@ -199,11 +299,12 @@ function generate_constraints_expression(backend, constraints_specification)
199299
return expression
200300
end
201301

202-
return_specification = write_constraints_specification(backend, factorisation_constraints_symbol, form_constraints_symbol)
302+
return_specification = write_constraints_specification(backend, factorisation_constraints_symbol, marginals_form_constraints_symbol, messages_form_constraints_symbol)
203303

204304
res = quote
205305
function $cs_name($(cs_args...); $(cs_kwargs...))
206-
$(form_constraints_symbol_init)
306+
$(marginals_form_constraints_symbol_init)
307+
$(messages_form_constraints_symbol_init)
207308
$(factorisation_constraints_symbol_init)
208309
$(cs_lhs_init_block...)
209310
$(cs_body)

0 commit comments

Comments
 (0)