11export @constraints
22
33"""
4- write_constraints_specification(backend, factorisation, form )
4+ write_constraints_specification(backend, factorisation, marginalsform, messagesform )
55"""
66function write_constraints_specification end
77
@@ -45,16 +45,81 @@ function write_factorisation_splitted_range end
4545"""
4646function write_factorisation_functional_index end
4747
48+ """
49+ write_form_constraint_specification(backend, T, args, kwargs)
50+ """
51+ function write_form_constraint_specification end
52+
4853macro constraints (constraints_specification)
4954 return generate_constraints_expression (__get_current_backend (), constraints_specification)
5055end
5156
52- struct LHSMeta
57+ # # Factorisation constraints
58+
59+ struct FactorisationConstraintLHSMeta
5360 name :: String
5461 hash :: UInt
5562 varname :: Symbol
5663end
5764
65+ # #
66+
67+ # # Form constraints
68+
69+ function flatten_functional_form_constraint_specification (expr)
70+ return flatten_functional_form_constraint_specification! (expr, Expr (:(call), :(:: )))
71+ end
72+
73+ function flatten_functional_form_constraint_specification! (symbol:: Symbol , toplevel:: Expr )
74+ push! (toplevel. args, symbol)
75+ return toplevel
76+ end
77+
78+ function flatten_functional_form_constraint_specification! (expr:: Expr , toplevel:: Expr )
79+ if ishead (expr, :(:: )) && ishead (expr. args[1 ], :(:: ))
80+ flatten_functional_form_constraint_specification! (expr. args[1 ], toplevel)
81+ flatten_functional_form_constraint_specification! (expr. args[2 ], toplevel)
82+ elseif ishead (expr, :(:: ))
83+ push! (toplevel. args, expr. args[1 ])
84+ push! (toplevel. args, expr. args[2 ])
85+ else
86+ push! (toplevel. args, expr)
87+ end
88+ return toplevel
89+ end
90+
91+ function parse_form_constraint (backend, expr)
92+ T, args, kwargs = if expr isa Symbol
93+ expr, :(()), :((;))
94+ else
95+ if @capture (expr, f_ (args__; kwargs__))
96+ f, :(($ (args... ), )), :((; $ (kwargs... ), ))
97+ elseif @capture (expr, f_ (args__))
98+
99+ as = []
100+ ks = []
101+
102+ for arg in args
103+ if ishead (arg, :kw )
104+ push! (ks, arg)
105+ else
106+ push! (as, arg)
107+ end
108+ end
109+
110+ f, :(($ (as... ), )), :((; $ (ks... ), ))
111+ elseif @capture (expr, f_ ())
112+ f, :(()), :((;))
113+ else
114+ error (" Unssuported form constraints call specification in the expression `$(expr) `" )
115+ end
116+ end
117+
118+ return write_form_constraint_specification (backend, T, args, kwargs)
119+ end
120+
121+ # #
122+
58123function generate_constraints_expression (backend, constraints_specification)
59124
60125 if isblock (constraints_specification)
@@ -69,16 +134,51 @@ function generate_constraints_expression(backend, constraints_specification)
69134 cs_args = cs_args === nothing ? [] : cs_args
70135 cs_kwargs = cs_kwargs === nothing ? [] : cs_kwargs
71136
72- lhs_dict = Dict {UInt, LHSMeta } ()
137+ lhs_dict = Dict {UInt, FactorisationConstraintLHSMeta } ()
73138
74- # We iteratively overwrite extend form constraint tuple, but we use different names for it to enable type-stability
75- form_constraints_symbol = gensym (:form_constraint )
76- form_constraints_symbol_init = :($ form_constraints_symbol = ())
139+ marginals_form_constraints_symbol = gensym (:marginals_form_constraint )
140+ marginals_form_constraints_symbol_init = :($ marginals_form_constraints_symbol = (;))
141+
142+ messages_form_constraints_symbol = gensym (:messages_form_constraint )
143+ messages_form_constraints_symbol_init = :($ messages_form_constraints_symbol = (;))
77144
78- # We iteratively overwrite extend factorisation constraint tuple, but we use different names for it to enable type-stability
79145 factorisation_constraints_symbol = gensym (:factorisation_constraint )
80146 factorisation_constraints_symbol_init = :($ factorisation_constraints_symbol = ())
81-
147+
148+ # First we modify form constraints related statements
149+ cs_body = prewalk (cs_body) do expression
150+ if ishead (expression, :(:: ))
151+ return flatten_functional_form_constraint_specification (expression)
152+ end
153+ return expression
154+ end
155+
156+ cs_body = prewalk (cs_body) do expression
157+ if iscall (expression, :(:: ))
158+ if @capture (expression. args[2 ], q (formsym_Symbol))
159+ specs = map ((e) -> parse_form_constraint (backend, e), view (expression. args, 3 : lastindex (expression. args)))
160+ return quote
161+ if haskey ($ marginals_form_constraints_symbol, $ (QuoteNode (formsym)))
162+ error (" Marginal form constraint q($(formsym) ) has been redefined." )
163+ end
164+ $ marginals_form_constraints_symbol = (; $ marginals_form_constraints_symbol... , $ formsym = ($ (specs... ),))
165+ end
166+ elseif @capture (expression. args[2 ], μ (formsym_Symbol))
167+ specs = map ((e) -> parse_form_constraint (backend, e), view (expression. args, 3 : lastindex (expression. args)))
168+ return quote
169+ if haskey ($ messages_form_constraints_symbol, $ (QuoteNode (formsym)))
170+ error (" Messages form constraint μ($(formsym) ) has been redefined." )
171+ end
172+ $ messages_form_constraints_symbol = (; $ messages_form_constraints_symbol... , $ formsym = ($ (specs... ),))
173+ end
174+ else
175+ error (" Invalid form factorisation constraint. $(expression. args[2 ]) has to be in the form of q(varname) for marginal form constraint or μ(varname) for messages form constraint." )
176+ end
177+ end
178+ return expression
179+ end
180+
181+ # Second we modify factorisation constraints related statements
82182 # First we record all lhs expression's hash ids and create unique variable names for them
83183 # q(x, y) = q(x)q(y) -> hash(q(x, y))
84184 # We do allow multiple definitions in case of if statements, but we do check later overwrites, which are not allowed
@@ -128,7 +228,7 @@ function generate_constraints_expression(backend, constraints_specification)
128228 else
129229 lhs_name = string (" q(" , join (names, " , " ), " )" )
130230 lhs_varname = gensym (lhs_name)
131- lhs_meta = LHSMeta (lhs_name, lhs_hash, lhs_varname)
231+ lhs_meta = FactorisationConstraintLHSMeta (lhs_name, lhs_hash, lhs_varname)
132232 lhs_dict[lhs_hash] = lhs_meta
133233 end
134234
@@ -199,11 +299,12 @@ function generate_constraints_expression(backend, constraints_specification)
199299 return expression
200300 end
201301
202- return_specification = write_constraints_specification (backend, factorisation_constraints_symbol, form_constraints_symbol )
302+ return_specification = write_constraints_specification (backend, factorisation_constraints_symbol, marginals_form_constraints_symbol, messages_form_constraints_symbol )
203303
204304 res = quote
205305 function $cs_name ($ (cs_args... ); $ (cs_kwargs... ))
206- $ (form_constraints_symbol_init)
306+ $ (marginals_form_constraints_symbol_init)
307+ $ (messages_form_constraints_symbol_init)
207308 $ (factorisation_constraints_symbol_init)
208309 $ (cs_lhs_init_block... )
209310 $ (cs_body)
0 commit comments