Skip to content

Commit 7c1ff74

Browse files
committed
Merge branch 'develop'
2 parents 6f3dbd3 + c9fe1ad commit 7c1ff74

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

src/GraphPPL.jl

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ fquote(expr::Expr) = expr
2424
ensure_type(x::Type) = x
2525
ensure_type(x) = error("Valid type object was expected but '$x' has been found")
2626

27+
is_kwargs_expression(x) = false
28+
is_kwargs_expression(x::Expr) = x.head === :parameters
29+
2730
"""
2831
parse_varexpr(varexpr)
2932
@@ -91,6 +94,10 @@ function __normalize_arg(arg)
9194
end
9295
end
9396

97+
argument_write_default_value(arg, default::Nothing) = arg
98+
argument_write_default_value(arg, default) = Expr(:kw, arg, default)
99+
100+
94101
"""
95102
write_argument_guard(backend, argument)
96103
"""
@@ -154,39 +161,42 @@ function generate_model_expression(backend, model_options, model_specification)
154161

155162
ms_options = :(NamedTuple{ ($(tuple(map(first, ms_options)...))) }((($(tuple(map(last, ms_options)...)...)),)))
156163

157-
@capture(model_specification, function ms_name_(ms_args__) ms_body_ end) ||
164+
@capture(model_specification, (function ms_name_(ms_args__; ms_kwargs__) ms_body_ end) | (function ms_name_(ms_args__) ms_body_ end)) ||
158165
error("Model specification language requires full function definition")
159166

160167
model = gensym(:model)
161168

162-
ms_args_const_ids = filter(ms_args) do ms_arg
163-
@capture(ms_arg, var_::ConstVariable)
164-
end
165-
166169
ms_args_ids = Vector{Symbol}()
167170
ms_args_guard_ids = Vector{Symbol}()
168171
ms_args_const_ids = Vector{Tuple{Symbol, Symbol}}()
169172

170-
ms_args = map(ms_args) do ms_arg
171-
if @capture(ms_arg, arg_::ConstVariable)
172-
rc_arg = gensym(:constvar)
173-
push!(ms_args_const_ids, (arg, rc_arg))
174-
push!(ms_args_guard_ids, rc_arg)
173+
ms_arg_expression_converter = (ms_arg) -> begin
174+
if @capture(ms_arg, arg_::ConstVariable = smth_) || @capture(ms_arg, arg_::ConstVariable)
175+
# rc_arg = gensym(:constvar)
176+
push!(ms_args_const_ids, (arg, arg)) # backward compatibility for old behaviour with gensym
177+
push!(ms_args_guard_ids, arg)
175178
push!(ms_args_ids, arg)
176-
return rc_arg
177-
elseif @capture(ms_arg, arg_::T_)
179+
return argument_write_default_value(arg, smth)
180+
elseif @capture(ms_arg, arg_::T_ = smth_) || @capture(ms_arg, arg_::T_)
178181
push!(ms_args_guard_ids, arg)
179182
push!(ms_args_ids, arg)
180-
return ms_arg
181-
elseif @capture(ms_arg, arg_Symbol)
183+
return argument_write_default_value(:($(arg)::$(T)), smth)
184+
elseif @capture(ms_arg, arg_Symbol = smth_) || @capture(ms_arg, arg_Symbol)
182185
push!(ms_args_guard_ids, arg)
183186
push!(ms_args_ids, arg)
184-
return ms_arg
187+
return argument_write_default_value(arg, smth)
185188
else
186189
error("Invalid argument specification: $(ms_arg)")
187190
end
188191
end
189192

193+
ms_args = ms_args === nothing ? [] : map(ms_arg_expression_converter, ms_args)
194+
ms_kwargs = ms_kwargs === nothing ? [] : map(ms_arg_expression_converter, ms_kwargs)
195+
196+
if length(Set(ms_args_ids)) !== length(ms_args_ids)
197+
error("There are duplicates in argument specification list: $(ms_args_ids)")
198+
end
199+
190200
ms_args_const_init_block = map(ms_args_const_ids) do ms_arg_const_id
191201
return write_constvar_expression(backend, model, first(ms_arg_const_id), [ last(ms_arg_const_id) ])
192202
end
@@ -289,7 +299,7 @@ function generate_model_expression(backend, model_options, model_specification)
289299

290300
res = quote
291301

292-
function $ms_name($(ms_args...); options = $(ms_options))
302+
function $ms_name($(ms_args...); $(ms_kwargs...), options = $(ms_options))
293303
$(ms_args_checks...)
294304
options = merge($(ms_options), options)
295305
$model = Model(options)

src/backends/reactivemp.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ end
5858

5959
function write_pipeline_option(fform, fpipeline)
6060
if @capture(fpipeline, +(stages__))
61-
return :(pipeline = ReactiveMP.FactorNodePipeline(+($(map(stage -> write_pipeline_stage(fform, stage), stages)...))))
61+
return :(pipeline = +($(map(stage -> write_pipeline_stage(fform, stage), stages)...)))
6262
else
63-
return :(pipeline = ReactiveMP.FactorNodePipeline($(write_pipeline_stage(fform, fpipeline))))
63+
return :(pipeline = $(write_pipeline_stage(fform, fpipeline)))
6464
end
6565
end
6666

0 commit comments

Comments
 (0)