Skip to content

Commit 0d3ab17

Browse files
committed
feat(): where options support for datavar/constvar/datavar
1 parent c384e12 commit 0d3ab17

File tree

2 files changed

+68
-18
lines changed

2 files changed

+68
-18
lines changed

src/GraphPPL.jl

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,17 @@ argument_write_default_value(arg, default) = Expr(:kw, arg, default)
104104
function write_argument_guard end
105105

106106
"""
107-
write_randomvar_expression(backend, model, varexpr, arguments)
107+
write_randomvar_expression(backend, model, varexpr, arguments, kwarguments)
108108
"""
109109
function write_randomvar_expression end
110110

111111
"""
112-
write_datavar_expression(backend, model, varexpr, type, arguments)
112+
write_datavar_expression(backend, model, varexpr, type, arguments, kwarguments)
113113
"""
114114
function write_datavar_expression end
115115

116116
"""
117-
write_constvar_expression(backend, model, varexpr, arguments)
117+
write_constvar_expression(backend, model, varexpr, arguments, kwarguments)
118118
"""
119119
function write_constvar_expression end
120120

@@ -138,6 +138,21 @@ function write_autovar_make_node_expression end
138138
"""
139139
function write_node_options end
140140

141+
"""
142+
write_randomvar_options(backend, variable, options)
143+
"""
144+
function write_randomvar_options end
145+
146+
"""
147+
write_constvar_options(backend, variable, options)
148+
"""
149+
function write_constvar_options end
150+
151+
"""
152+
write_datavar_options(backend, variable, options)
153+
"""
154+
function write_datavar_options end
155+
141156
include("backends/reactivemp.jl")
142157

143158
__get_current_backend() = ReactiveMPBackend()
@@ -198,7 +213,7 @@ function generate_model_expression(backend, model_options, model_specification)
198213
end
199214

200215
ms_args_const_init_block = map(ms_args_const_ids) do ms_arg_const_id
201-
return write_constvar_expression(backend, model, first(ms_arg_const_id), [ last(ms_arg_const_id) ])
216+
return write_constvar_expression(backend, model, first(ms_arg_const_id), [ last(ms_arg_const_id) ], [])
202217
end
203218

204219
# Step 0: Check that all inputs are not AbstractVariables
@@ -222,6 +237,18 @@ function generate_model_expression(backend, model_options, model_specification)
222237

223238
varexpr = @capture(varexpr, (nodeid_, varid_)) ? varexpr : :(($(gensym(:nnode)), $varexpr))
224239
return :($varexpr ~ $(fform)($((normalize_tilde_arguments(arguments))...); $(options...)))
240+
elseif @capture(expression, varexpr_ = randomvar(arguments__) where { options__ })
241+
return :($varexpr = randomvar($(arguments...); $(write_randomvar_options(backend, varexpr, options)...)))
242+
elseif @capture(expression, varexpr_ = datavar(arguments__) where { options__ })
243+
return :($varexpr = datavar($(arguments...); $(write_datavar_options(backend, varexpr, options)...)))
244+
elseif @capture(expression, varexpr_ = constvar(arguments__) where { options__ })
245+
return :($varexpr = constvar($(arguments...); $(write_constvar_options(backend, varexpr, options)...)))
246+
elseif @capture(expression, varexpr_ = randomvar(arguments__))
247+
return :($varexpr = randomvar($(arguments...); ))
248+
elseif @capture(expression, varexpr_ = datavar(arguments__))
249+
return :($varexpr = datavar($(arguments...); ))
250+
elseif @capture(expression, varexpr_ = constvar(arguments__))
251+
return :($varexpr = constvar($(arguments...); ))
225252
else
226253
return expression
227254
end
@@ -244,28 +271,28 @@ function generate_model_expression(backend, model_options, model_specification)
244271
# Step 2: Main pass
245272
ms_body = postwalk(ms_body) do expression
246273
# Step 2.1 Convert datavar calls
247-
if @capture(expression, varexpr_ = datavar(arguments__))
274+
if @capture(expression, varexpr_ = datavar(arguments__; kwarguments__))
248275
@assert varexpr varids "Invalid model specification: '$varexpr' id is duplicated"
249276
@assert length(arguments) >= 1 "datavar() call requires type specification as a first argument"
250277

251278
push!(varids, varexpr)
252279

253-
type = arguments[1]
254-
tail = arguments[2:end]
280+
type_argument = arguments[1]
281+
tail_arguments = arguments[2:end]
255282

256-
return write_datavar_expression(backend, model, varexpr, type, tail)
283+
return write_datavar_expression(backend, model, varexpr, type_argument, tail_arguments, kwarguments)
257284
# Step 2.2 Convert randomvar calls
258-
elseif @capture(expression, varexpr_ = randomvar(arguments__))
285+
elseif @capture(expression, varexpr_ = randomvar(arguments__; kwarguments__))
259286
@assert varexpr varids "Invalid model specification: '$varexpr' id is duplicated"
260287
push!(varids, varexpr)
261288

262-
return write_randomvar_expression(backend, model, varexpr, arguments)
289+
return write_randomvar_expression(backend, model, varexpr, arguments, kwarguments)
263290
# Step 2.3 Conver constvar calls
264-
elseif @capture(expression, varexpr_ = constvar(arguments__))
291+
elseif @capture(expression, varexpr_ = constvar(arguments__; kwarguments__))
265292
@assert varexpr varids "Invalid model specification: '$varexpr' id is duplicated"
266293
push!(varids, varexpr)
267294

268-
return write_constvar_expression(backend, model, varexpr, arguments)
295+
return write_constvar_expression(backend, model, varexpr, arguments, kwarguments)
269296
# Step 2.2 Convert tilde expressions
270297
elseif @capture(expression, (nodeexpr_, varexpr_) ~ fform_(arguments__; kwarguments__))
271298
# println(expression)

src/backends/reactivemp.jl

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ function write_argument_guard(::ReactiveMPBackend, argument::Symbol)
88
return :(@assert !($argument isa ReactiveMP.AbstractVariable) "It is not allowed to pass AbstractVariable objects to a model definition arguments. ConstVariables should be passed as their raw values.")
99
end
1010

11-
function write_randomvar_expression(::ReactiveMPBackend, model, varexp, arguments)
12-
return :($varexp = ReactiveMP.randomvar($model, $(GraphPPL.fquote(varexp)), $(arguments...)))
11+
function write_randomvar_expression(::ReactiveMPBackend, model, varexp, arguments, kwarguments)
12+
return :($varexp = ReactiveMP.randomvar($model, $(GraphPPL.fquote(varexp)), $(arguments...); $(kwarguments...)))
1313
end
1414

15-
function write_datavar_expression(::ReactiveMPBackend, model, varexpr, type, arguments)
16-
return :($varexpr = ReactiveMP.datavar($model, $(GraphPPL.fquote(varexpr)), ReactiveMP.PointMass{ GraphPPL.ensure_type($(type)) }, $(arguments...)))
15+
function write_datavar_expression(::ReactiveMPBackend, model, varexpr, type, arguments, kwarguments)
16+
return :($varexpr = ReactiveMP.datavar($model, $(GraphPPL.fquote(varexpr)), ReactiveMP.PointMass{ GraphPPL.ensure_type($(type)) }, $(arguments...); $(kwarguments...)))
1717
end
1818

19-
function write_constvar_expression(::ReactiveMPBackend, model, varexpr, arguments)
20-
return :($varexpr = ReactiveMP.constvar($model, $(GraphPPL.fquote(varexpr)), $(arguments...)))
19+
function write_constvar_expression(::ReactiveMPBackend, model, varexpr, arguments, kwarguments)
20+
return :($varexpr = ReactiveMP.constvar($model, $(GraphPPL.fquote(varexpr)), $(arguments...); $(kwarguments...)))
2121
end
2222

2323
function write_as_variable(::ReactiveMPBackend, model, varexpr)
@@ -138,4 +138,27 @@ function write_fconstraint_option(form, variables, fconstraint)
138138
else
139139
error("Invalid factorisation constraint: $fconstraint")
140140
end
141+
end
142+
143+
##
144+
145+
function write_randomvar_options(::ReactiveMPBackend, variable, options)
146+
return map(options) do option
147+
@capture(option, name_Symbol = value_) || error("Invalid variable options specification: $option. Should be in a form of 'name = value'")
148+
return option
149+
end
150+
end
151+
152+
function write_constvar_options(::ReactiveMPBackend, variable, options)
153+
return map(options) do option
154+
@capture(option, name_Symbol = value_) || error("Invalid variable options specification: $option. Should be in a form of 'name = value'")
155+
return option
156+
end
157+
end
158+
159+
function write_datavar_options(::ReactiveMPBackend, variable, options)
160+
return map(options) do option
161+
@capture(option, name_Symbol = value_) || error("Invalid variable options specification: $option. Should be in a form of 'name = value'")
162+
return option
163+
end
141164
end

0 commit comments

Comments
 (0)