|
1 | 1 | module GraphPPL |
2 | 2 |
|
3 | | -export @model |
4 | | - |
5 | | -import MacroTools |
6 | | -import MacroTools: @capture, postwalk, prewalk, walk |
7 | | - |
8 | | -function conditioned_walk(f, condition_skip, condition_apply, x) |
9 | | - walk(x, x -> condition_skip(x) ? x : condition_apply(x) ? f(x) : conditioned_walk(f, condition_skip, condition_apply, x), identity) |
10 | | -end |
11 | | - |
12 | | -""" |
13 | | - fquote(expr) |
14 | | -
|
15 | | -This function forces `Expr` or `Symbol` to be quoted. |
16 | | -""" |
17 | | -fquote(expr::Symbol) = Expr(:quote, expr) |
18 | | -fquote(expr::Int) = expr |
19 | | -fquote(expr::Expr) = expr |
20 | | - |
21 | | -""" |
22 | | - ensure_type |
23 | | -""" |
24 | | -ensure_type(x::Type) = x |
25 | | -ensure_type(x) = error("Valid type object was expected but '$x' has been found") |
26 | | - |
27 | | -is_kwargs_expression(x) = false |
28 | | -is_kwargs_expression(x::Expr) = x.head === :parameters |
29 | | - |
30 | | -""" |
31 | | - parse_varexpr(varexpr) |
32 | | -
|
33 | | -This function parses variable id and returns a tuple of 3 different representations of the same variable |
34 | | -1. Original expression |
35 | | -2. Short variable identificator (used in variables lookup table) |
36 | | -3. Full variable identificator (used in model as a variable id) |
37 | | -""" |
38 | | -function parse_varexpr(varexpr::Symbol) |
39 | | - varexpr = varexpr |
40 | | - short_id = varexpr |
41 | | - full_id = varexpr |
42 | | - return varexpr, short_id, full_id |
43 | | -end |
44 | | - |
45 | | -function parse_varexpr(varexpr::Expr) |
46 | | - |
47 | | - # TODO: It might be handy to have this feature in the future for e.g. interacting with UnPack.jl package |
48 | | - # TODO: For now however we fallback to a more informative error message since it is not obvious how to parse such expressions yet |
49 | | - @capture(varexpr, (tupled_ids__, )) && |
50 | | - error("Multiple variable declarations, definitions and assigments are forbidden within @model macro. Try to split $(varexpr) into several independent statements.") |
51 | | - |
52 | | - @capture(varexpr, id_[idx__]) || |
53 | | - error("Variable identificator can be in form of a single symbol (x ~ ...) or indexing expression (x[i] ~ ...)") |
54 | | - |
55 | | - varexpr = varexpr |
56 | | - short_id = id |
57 | | - full_id = Expr(:call, :Symbol, fquote(id), Expr(:quote, :_), Expr(:quote, Symbol(join(idx, :_)))) |
58 | | - |
59 | | - return varexpr, short_id, full_id |
60 | | -end |
61 | | - |
62 | | -""" |
63 | | - normalize_tilde_arguments(args) |
64 | | -
|
65 | | -This function 'normalizes' every argument of a tilde expression making every inner function call to be a tilde expression as well. |
66 | | -It forces MSL to create anonymous node for any non-linear variable transformation or deterministic relationships. MSL does not check (and cannot in general) |
67 | | -if some inner function call leads to a constant expression or not (e.g. `Normal(0.0, sqrt(10.0))`). Backend API should decide whenever to create additional anonymous nodes |
68 | | -for constant non-linear transformation expressions or not by analyzing input arguments. |
69 | | -""" |
70 | | -function normalize_tilde_arguments(args) |
71 | | - return map(args) do arg |
72 | | - if @capture(arg, id_[idx_]) |
73 | | - return :($(__normalize_arg(id))[$idx]) |
74 | | - else |
75 | | - return __normalize_arg(arg) |
76 | | - end |
77 | | - end |
78 | | -end |
79 | | - |
80 | | -function __normalize_arg(arg) |
81 | | - if @capture(arg, (f_(v__) where { options__ }) | (f_(v__))) |
82 | | - if f === :(|>) |
83 | | - @assert length(v) === 2 "Unsupported pipe syntax in model specification: $(arg)" |
84 | | - f = v[2] |
85 | | - v = [ v[1] ] |
86 | | - end |
87 | | - nvarexpr = gensym(:nvar) |
88 | | - nnodeexpr = gensym(:nnode) |
89 | | - options = options !== nothing ? options : [] |
90 | | - v = normalize_tilde_arguments(v) |
91 | | - return :(($nnodeexpr, $nvarexpr) ~ $f($(v...); $(options...)); $nvarexpr) |
92 | | - else |
93 | | - return arg |
94 | | - end |
95 | | -end |
96 | | - |
97 | | -argument_write_default_value(arg, default::Nothing) = arg |
98 | | -argument_write_default_value(arg, default) = Expr(:kw, arg, default) |
99 | | - |
100 | | - |
101 | | -""" |
102 | | - write_argument_guard(backend, argument) |
103 | | -""" |
104 | | -function write_argument_guard end |
105 | | - |
106 | | -""" |
107 | | - write_randomvar_expression(backend, model, varexpr, arguments, kwarguments) |
108 | | -""" |
109 | | -function write_randomvar_expression end |
110 | | - |
111 | | -""" |
112 | | - write_datavar_expression(backend, model, varexpr, type, arguments, kwarguments) |
113 | | -""" |
114 | | -function write_datavar_expression end |
115 | | - |
116 | | -""" |
117 | | - write_constvar_expression(backend, model, varexpr, arguments, kwarguments) |
118 | | -""" |
119 | | -function write_constvar_expression end |
120 | | - |
121 | | -""" |
122 | | - write_as_variable(backend, model, varexpr) |
123 | | -""" |
124 | | -function write_as_variable end |
125 | | - |
126 | | -""" |
127 | | - write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr) |
128 | | -""" |
129 | | -function write_make_node_expression end |
130 | | - |
131 | | -""" |
132 | | - write_autovar_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr, autovarid) |
133 | | -""" |
134 | | -function write_autovar_make_node_expression end |
135 | | - |
136 | | -""" |
137 | | - write_node_options(backend, fform, variables, options) |
138 | | -""" |
139 | | -function write_node_options end |
140 | | - |
141 | | -""" |
142 | | - write_randomvar_options(backend, variable, options) |
143 | | -""" |
144 | | -function write_randomvar_options end |
145 | | - |
146 | | -""" |
147 | | - write_constvar_options(backend, variable, options) |
148 | | -""" |
149 | | -function write_constvar_options end |
150 | | - |
151 | | -""" |
152 | | - write_datavar_options(backend, variable, options) |
153 | | -""" |
154 | | -function write_datavar_options end |
| 3 | +using MacroTools |
155 | 4 |
|
156 | 5 | include("backends/reactivemp.jl") |
157 | 6 |
|
158 | 7 | __get_current_backend() = ReactiveMPBackend() |
159 | 8 |
|
160 | | -macro model(model_specification) |
161 | | - return esc(:(@model [] $model_specification)) |
162 | | -end |
163 | | - |
164 | | -macro model(model_options, model_specification) |
165 | | - return GraphPPL.generate_model_expression(__get_current_backend(), model_options, model_specification) |
166 | | -end |
167 | | - |
168 | | -function generate_model_expression(backend, model_options, model_specification) |
169 | | - @capture(model_options, [ ms_options__ ]) || |
170 | | - error("Model specification options should be in a form of [ option1 = ..., option2 = ... ]") |
171 | | - |
172 | | - ms_options = map(ms_options) do option |
173 | | - (@capture(option, name_ = value_) && name isa Symbol) || error("Invalid option specification: $(option). Expected: 'option_name = option_value'.") |
174 | | - return (name, value) |
175 | | - end |
176 | | - |
177 | | - ms_options = :(NamedTuple{ ($(tuple(map(first, ms_options)...))) }((($(tuple(map(last, ms_options)...)...)),))) |
178 | | - |
179 | | - @capture(model_specification, (function ms_name_(ms_args__; ms_kwargs__) ms_body_ end) | (function ms_name_(ms_args__) ms_body_ end)) || |
180 | | - error("Model specification language requires full function definition") |
181 | | - |
182 | | - model = gensym(:model) |
183 | | - |
184 | | - ms_args_ids = Vector{Symbol}() |
185 | | - ms_args_guard_ids = Vector{Symbol}() |
186 | | - ms_args_const_ids = Vector{Tuple{Symbol, Symbol}}() |
187 | | - |
188 | | - ms_arg_expression_converter = (ms_arg) -> begin |
189 | | - if @capture(ms_arg, arg_::ConstVariable = smth_) || @capture(ms_arg, arg_::ConstVariable) |
190 | | - # rc_arg = gensym(:constvar) |
191 | | - push!(ms_args_const_ids, (arg, arg)) # backward compatibility for old behaviour with gensym |
192 | | - push!(ms_args_guard_ids, arg) |
193 | | - push!(ms_args_ids, arg) |
194 | | - return argument_write_default_value(arg, smth) |
195 | | - elseif @capture(ms_arg, arg_::T_ = smth_) || @capture(ms_arg, arg_::T_) |
196 | | - push!(ms_args_guard_ids, arg) |
197 | | - push!(ms_args_ids, arg) |
198 | | - return argument_write_default_value(:($(arg)::$(T)), smth) |
199 | | - elseif @capture(ms_arg, arg_Symbol = smth_) || @capture(ms_arg, arg_Symbol) |
200 | | - push!(ms_args_guard_ids, arg) |
201 | | - push!(ms_args_ids, arg) |
202 | | - return argument_write_default_value(arg, smth) |
203 | | - else |
204 | | - error("Invalid argument specification: $(ms_arg)") |
205 | | - end |
206 | | - end |
207 | | - |
208 | | - ms_args = ms_args === nothing ? [] : map(ms_arg_expression_converter, ms_args) |
209 | | - ms_kwargs = ms_kwargs === nothing ? [] : map(ms_arg_expression_converter, ms_kwargs) |
210 | | - |
211 | | - if length(Set(ms_args_ids)) !== length(ms_args_ids) |
212 | | - error("There are duplicates in argument specification list: $(ms_args_ids)") |
213 | | - end |
214 | | - |
215 | | - ms_args_const_init_block = map(ms_args_const_ids) do ms_arg_const_id |
216 | | - return write_constvar_expression(backend, model, first(ms_arg_const_id), [ last(ms_arg_const_id) ], []) |
217 | | - end |
218 | | - |
219 | | - # Step 0: Check that all inputs are not AbstractVariables |
220 | | - # It is highly recommended not to create AbstractVariables outside of the model creation macro |
221 | | - # Doing so can lead to undefined behaviour |
222 | | - ms_args_checks = map((ms_arg) -> write_argument_guard(backend, ms_arg), ms_args_guard_ids) |
223 | | - |
224 | | - # Step 1: Probabilistic arguments normalisation |
225 | | - ms_body = prewalk(ms_body) do expression |
226 | | - if @capture(expression, (varexpr_ ~ fform_(arguments__) where { options__ }) | (varexpr_ ~ fform_(arguments__))) |
227 | | - options = options === nothing ? [] : options |
228 | | - |
229 | | - # Filter out keywords arguments to options array |
230 | | - arguments = filter(arguments) do arg |
231 | | - ifparameters = arg isa Expr && arg.head === :parameters |
232 | | - if ifparameters |
233 | | - foreach(a -> push!(options, a), arg.args) |
234 | | - end |
235 | | - return !ifparameters |
236 | | - end |
237 | | - |
238 | | - varexpr = @capture(varexpr, (nodeid_, varid_)) ? varexpr : :(($(gensym(:nnode)), $varexpr)) |
239 | | - return :($varexpr ~ $(fform)($((normalize_tilde_arguments(arguments))...); $(options...))) |
240 | | - elseif @capture(expression, varexpr_ = randomvar(arguments__) where { options__ }) |
241 | | - return :($varexpr = randomvar($(arguments...); $(write_randomvar_options(backend, varexpr, options)...))) |
242 | | - elseif @capture(expression, varexpr_ = datavar(arguments__) where { options__ }) |
243 | | - return :($varexpr = datavar($(arguments...); $(write_datavar_options(backend, varexpr, options)...))) |
244 | | - elseif @capture(expression, varexpr_ = constvar(arguments__) where { options__ }) |
245 | | - return :($varexpr = constvar($(arguments...); $(write_constvar_options(backend, varexpr, options)...))) |
246 | | - elseif @capture(expression, varexpr_ = randomvar(arguments__)) |
247 | | - return :($varexpr = randomvar($(arguments...); )) |
248 | | - elseif @capture(expression, varexpr_ = datavar(arguments__)) |
249 | | - return :($varexpr = datavar($(arguments...); )) |
250 | | - elseif @capture(expression, varexpr_ = constvar(arguments__)) |
251 | | - return :($varexpr = constvar($(arguments...); )) |
252 | | - else |
253 | | - return expression |
254 | | - end |
255 | | - end |
256 | | - |
257 | | - bannedids = Set{Symbol}() |
258 | | - |
259 | | - ms_body = postwalk(ms_body) do expression |
260 | | - if @capture(expression, lhs_ = rhs_) |
261 | | - if !(@capture(rhs, datavar(args__))) && !(@capture(rhs, randomvar(args__))) && !(@capture(rhs, constvar(args__))) |
262 | | - varexpr, short_id, full_id = parse_varexpr(lhs) |
263 | | - push!(bannedids, short_id) |
264 | | - end |
265 | | - end |
266 | | - return expression |
267 | | - end |
268 | | - |
269 | | - varids = Set{Symbol}(ms_args_ids) |
270 | | - |
271 | | - # Step 2: Main pass |
272 | | - ms_body = postwalk(ms_body) do expression |
273 | | - # Step 2.1 Convert datavar calls |
274 | | - if @capture(expression, varexpr_ = datavar(arguments__; kwarguments__)) |
275 | | - @assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated" |
276 | | - @assert length(arguments) >= 1 "datavar() call requires type specification as a first argument" |
277 | | - |
278 | | - push!(varids, varexpr) |
279 | | - |
280 | | - type_argument = arguments[1] |
281 | | - tail_arguments = arguments[2:end] |
282 | | - |
283 | | - return write_datavar_expression(backend, model, varexpr, type_argument, tail_arguments, kwarguments) |
284 | | - # Step 2.2 Convert randomvar calls |
285 | | - elseif @capture(expression, varexpr_ = randomvar(arguments__; kwarguments__)) |
286 | | - @assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated" |
287 | | - push!(varids, varexpr) |
288 | | - |
289 | | - return write_randomvar_expression(backend, model, varexpr, arguments, kwarguments) |
290 | | - # Step 2.3 Conver constvar calls |
291 | | - elseif @capture(expression, varexpr_ = constvar(arguments__; kwarguments__)) |
292 | | - @assert varexpr ∉ varids "Invalid model specification: '$varexpr' id is duplicated" |
293 | | - push!(varids, varexpr) |
294 | | - |
295 | | - return write_constvar_expression(backend, model, varexpr, arguments, kwarguments) |
296 | | - # Step 2.2 Convert tilde expressions |
297 | | - elseif @capture(expression, (nodeexpr_, varexpr_) ~ fform_(arguments__; kwarguments__)) |
298 | | - # println(expression) |
299 | | - varexpr, short_id, full_id = parse_varexpr(varexpr) |
300 | | - |
301 | | - if short_id ∈ bannedids |
302 | | - error("Invalid name '$(short_id)' for new random variable. '$(short_id)' was already initialized with '=' operator before.") |
303 | | - end |
304 | | - |
305 | | - variables = map((argexpr) -> write_as_variable(backend, model, argexpr), arguments) |
306 | | - options = write_node_options(backend, fform, [ varexpr, arguments... ], kwarguments) |
307 | | - |
308 | | - if short_id ∈ varids |
309 | | - return write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr) |
310 | | - else |
311 | | - push!(varids, short_id) |
312 | | - return write_autovar_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr, full_id) |
313 | | - end |
314 | | - else |
315 | | - return expression |
316 | | - end |
317 | | - end |
318 | | - |
319 | | - # Step 3: Final pass |
320 | | - final_pass_exceptions = (x) -> @capture(x, (some_ -> body_) | (function some_(args__) body_ end) | (some_(args__) = body_)) |
321 | | - final_pass_target = (x) -> @capture(x, return ret_) |
322 | | - |
323 | | - ms_body = conditioned_walk(final_pass_exceptions, final_pass_target, ms_body) do expression |
324 | | - @capture(expression, return ret_) ? quote activate!($model); return $model, ($ret) end : expression |
325 | | - end |
326 | | - |
327 | | - res = quote |
328 | | - |
329 | | - function $ms_name($(ms_args...); $(ms_kwargs...), options = $(ms_options)) |
330 | | - $(ms_args_checks...) |
331 | | - options = merge($(ms_options), options) |
332 | | - $model = Model(options) |
333 | | - $(ms_args_const_init_block...) |
334 | | - $ms_body |
335 | | - error("'return' statement is missing") |
336 | | - end |
337 | | - end |
338 | | - |
339 | | - return esc(res) |
340 | | -end |
| 9 | +include("utils.jl") |
| 10 | +include("model.jl") |
| 11 | +include("constraints.jl") |
| 12 | +include("meta.jl") |
341 | 13 |
|
342 | 14 | end # module |
0 commit comments