@@ -104,17 +104,17 @@ argument_write_default_value(arg, default) = Expr(:kw, arg, default)
104104function write_argument_guard end
105105
106106"""
107- write_randomvar_expression(backend, model, varexpr, arguments)
107+ write_randomvar_expression(backend, model, varexpr, arguments, kwarguments )
108108"""
109109function write_randomvar_expression end
110110
111111"""
112- write_datavar_expression(backend, model, varexpr, type, arguments)
112+ write_datavar_expression(backend, model, varexpr, type, arguments, kwarguments )
113113"""
114114function write_datavar_expression end
115115
116116"""
117- write_constvar_expression(backend, model, varexpr, arguments)
117+ write_constvar_expression(backend, model, varexpr, arguments, kwarguments )
118118"""
119119function write_constvar_expression end
120120
@@ -138,6 +138,21 @@ function write_autovar_make_node_expression end
138138"""
139139function write_node_options end
140140
141+ """
142+ write_randomvar_options(backend, variable, options)
143+ """
144+ function write_randomvar_options end
145+
146+ """
147+ write_constvar_options(backend, variable, options)
148+ """
149+ function write_constvar_options end
150+
151+ """
152+ write_datavar_options(backend, variable, options)
153+ """
154+ function write_datavar_options end
155+
141156include (" backends/reactivemp.jl" )
142157
143158__get_current_backend () = ReactiveMPBackend ()
@@ -198,7 +213,7 @@ function generate_model_expression(backend, model_options, model_specification)
198213 end
199214
200215 ms_args_const_init_block = map (ms_args_const_ids) do ms_arg_const_id
201- return write_constvar_expression (backend, model, first (ms_arg_const_id), [ last (ms_arg_const_id) ])
216+ return write_constvar_expression (backend, model, first (ms_arg_const_id), [ last (ms_arg_const_id) ], [] )
202217 end
203218
204219 # Step 0: Check that all inputs are not AbstractVariables
@@ -222,6 +237,18 @@ function generate_model_expression(backend, model_options, model_specification)
222237
223238 varexpr = @capture (varexpr, (nodeid_, varid_)) ? varexpr : :(($ (gensym (:nnode )), $ varexpr))
224239 return :($ varexpr ~ $ (fform)($ ((normalize_tilde_arguments (arguments)). .. ); $ (options... )))
240+ elseif @capture (expression, varexpr_ = randomvar (arguments__) where { options__ })
241+ return :($ varexpr = randomvar ($ (arguments... ); $ (write_randomvar_options (backend, varexpr, options)... )))
242+ elseif @capture (expression, varexpr_ = datavar (arguments__) where { options__ })
243+ return :($ varexpr = datavar ($ (arguments... ); $ (write_datavar_options (backend, varexpr, options)... )))
244+ elseif @capture (expression, varexpr_ = constvar (arguments__) where { options__ })
245+ return :($ varexpr = constvar ($ (arguments... ); $ (write_constvar_options (backend, varexpr, options)... )))
246+ elseif @capture (expression, varexpr_ = randomvar (arguments__))
247+ return :($ varexpr = randomvar ($ (arguments... ); ))
248+ elseif @capture (expression, varexpr_ = datavar (arguments__))
249+ return :($ varexpr = datavar ($ (arguments... ); ))
250+ elseif @capture (expression, varexpr_ = constvar (arguments__))
251+ return :($ varexpr = constvar ($ (arguments... ); ))
225252 else
226253 return expression
227254 end
@@ -244,28 +271,28 @@ function generate_model_expression(backend, model_options, model_specification)
244271 # Step 2: Main pass
245272 ms_body = postwalk (ms_body) do expression
246273 # Step 2.1 Convert datavar calls
247- if @capture (expression, varexpr_ = datavar (arguments__))
274+ if @capture (expression, varexpr_ = datavar (arguments__; kwarguments__ ))
248275 @assert varexpr ∉ varids " Invalid model specification: '$varexpr ' id is duplicated"
249276 @assert length (arguments) >= 1 " datavar() call requires type specification as a first argument"
250277
251278 push! (varids, varexpr)
252279
253- type = arguments[1 ]
254- tail = arguments[2 : end ]
280+ type_argument = arguments[1 ]
281+ tail_arguments = arguments[2 : end ]
255282
256- return write_datavar_expression (backend, model, varexpr, type, tail )
283+ return write_datavar_expression (backend, model, varexpr, type_argument, tail_arguments, kwarguments )
257284 # Step 2.2 Convert randomvar calls
258- elseif @capture (expression, varexpr_ = randomvar (arguments__))
285+ elseif @capture (expression, varexpr_ = randomvar (arguments__; kwarguments__ ))
259286 @assert varexpr ∉ varids " Invalid model specification: '$varexpr ' id is duplicated"
260287 push! (varids, varexpr)
261288
262- return write_randomvar_expression (backend, model, varexpr, arguments)
289+ return write_randomvar_expression (backend, model, varexpr, arguments, kwarguments )
263290 # Step 2.3 Conver constvar calls
264- elseif @capture (expression, varexpr_ = constvar (arguments__))
291+ elseif @capture (expression, varexpr_ = constvar (arguments__; kwarguments__ ))
265292 @assert varexpr ∉ varids " Invalid model specification: '$varexpr ' id is duplicated"
266293 push! (varids, varexpr)
267294
268- return write_constvar_expression (backend, model, varexpr, arguments)
295+ return write_constvar_expression (backend, model, varexpr, arguments, kwarguments )
269296 # Step 2.2 Convert tilde expressions
270297 elseif @capture (expression, (nodeexpr_, varexpr_) ~ fform_ (arguments__; kwarguments__))
271298 # println(expression)
0 commit comments