Skip to content

Commit 74aefff

Browse files
committed
feat(): set anonymous randomvar
1 parent 5c70ce8 commit 74aefff

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

src/backends/reactivemp.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ function write_as_variable(::ReactiveMPBackend, model, varexpr)
9898
return :(ReactiveMP.as_variable($model, $varexpr))
9999
end
100100

101+
function write_anonymous_randomvar(::ReactiveMPBackend, model, varexpr)
102+
return :(ReactiveMP.setanonymous!($varexpr, true))
103+
end
104+
101105
function write_make_node_expression(::ReactiveMPBackend, model, fform, variables, options, nodeexpr, varexpr)
102106
return :($nodeexpr = ReactiveMP.make_node($model, $options, $fform, $varexpr, $(variables...)))
103107
end

src/model.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,24 +57,24 @@ function parse_varexpr(varexpr::Expr)
5757
end
5858

5959
"""
60-
normalize_tilde_arguments(args)
60+
normalize_tilde_arguments(backend, model, args)
6161
6262
This function 'normalizes' every argument of a tilde expression making every inner function call to be a tilde expression as well.
6363
It forces MSL to create anonymous node for any non-linear variable transformation or deterministic relationships. MSL does not check (and cannot in general)
6464
if some inner function call leads to a constant expression or not (e.g. `Normal(0.0, sqrt(10.0))`). Backend API should decide whenever to create additional anonymous nodes
6565
for constant non-linear transformation expressions or not by analyzing input arguments.
6666
"""
67-
function normalize_tilde_arguments(args)
67+
function normalize_tilde_arguments(backend, model, args)
6868
return map(args) do arg
6969
if @capture(arg, id_[idx_])
70-
return :($(__normalize_arg(id))[$idx])
70+
return :($(__normalize_arg(backend, model, id))[$idx])
7171
else
72-
return __normalize_arg(arg)
72+
return __normalize_arg(backend, model, arg)
7373
end
7474
end
7575
end
7676

77-
function __normalize_arg(arg)
77+
function __normalize_arg(backend, model, arg)
7878
if @capture(arg, (f_(v__) where { options__ }) | (f_(v__)))
7979
if f === :(|>)
8080
@assert length(v) === 2 "Unsupported pipe syntax in model specification: $(arg)"
@@ -84,8 +84,8 @@ function __normalize_arg(arg)
8484
nvarexpr = gensym(:nvar)
8585
nnodeexpr = gensym(:nnode)
8686
options = options !== nothing ? options : []
87-
v = normalize_tilde_arguments(v)
88-
return :(($nnodeexpr, $nvarexpr) ~ $f($(v...); $(options...)); $nvarexpr)
87+
v = normalize_tilde_arguments(backend, model, v)
88+
return :(($nnodeexpr, $nvarexpr) ~ $f($(v...); $(options...)); $(write_anonymous_randomvar(backend, model, nvarexpr)); $nvarexpr)
8989
else
9090
return arg
9191
end
@@ -134,6 +134,11 @@ function write_constvar_expression end
134134
"""
135135
function write_as_variable end
136136

137+
"""
138+
write_anonymous_randomvar(backend, model, varexpr)
139+
"""
140+
function write_anonymous_randomvar end
141+
137142
"""
138143
write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr)
139144
"""
@@ -251,7 +256,7 @@ function generate_model_expression(backend, model_options, model_specification)
251256
end
252257

253258
varexpr = @capture(varexpr, (nodeid_, varid_)) ? varexpr : :(($(gensym(:nnode)), $varexpr))
254-
return :($varexpr ~ $(fform)($((normalize_tilde_arguments(arguments))...); $(options...)))
259+
return :($varexpr ~ $(fform)($((normalize_tilde_arguments(backend, model, arguments))...); $(options...)))
255260
elseif @capture(expression, varexpr_ = randomvar(arguments__) where { options__ })
256261
return :($varexpr = randomvar($(arguments...); $(options...)))
257262
elseif @capture(expression, varexpr_ = datavar(arguments__) where { options__ })

0 commit comments

Comments
 (0)