Skip to content

Commit 8a1e441

Browse files
authored
Merge pull request #14 from biaslab/dev-2.1.0
Merge 2.1.0 branch into master branch
2 parents e4dbb7e + 7283085 commit 8a1e441

File tree

7 files changed

+207
-30
lines changed

7 files changed

+207
-30
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GraphPPL"
22
uuid = "b3f8163a-e979-4e85-b43e-1f63d8c8b42c"
33
authors = ["Dmitry Bagaev <[email protected]>"]
4-
version = "2.0.1"
4+
version = "2.1.0"
55

66
[deps]
77
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"

docs/src/user-guide.md

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,64 @@ node, y ~ NormalMeanVariance(mean, var)
177177

178178
Having a node reference can be useful in case the user wants to return it from a model and to use it later on to specify initial joint marginal distributions.
179179

180+
### Broadcasting syntax
181+
182+
!!! note
183+
Broadcasting syntax requires at least v2.1.0 of `GraphPPL.jl`
184+
185+
GraphPPL supports broadcasting for the `~` operator in exactly the same way as Julia itself. A user is free to write an expression of the following form:
186+
187+
```julia
188+
y = datavar(Float64, n)
189+
y .~ NormalMeanVariance(0.0, 1.0) # <- i.i.d observations
190+
```
191+
192+
More complex expressions are also allowed:
193+
194+
```julia
195+
m ~ NormalMeanPrecision(0.0, 0.0001)
196+
t ~ Gamma(1.0, 1.0)
197+
198+
y = randomvar(Float64, n)
199+
y .~ NormalMeanPrecision(m, t)
200+
```
201+
202+
```julia
203+
A = constvar(...)
204+
x = randomvar(n)
205+
y = datavar(Vector{Float64}, n)
206+
207+
w ~ Wishart(3, diageye(2))
208+
x[1] ~ MvNormalMeanPrecision(zeros(2), diageye(2))
209+
x[2:end] .~ A .* x[1:end-1] # <- State-space model with transition matrix A
210+
y .~ MvNormalMeanPrecision(x, w) # <- Observations with unknown precision matrix
211+
```
212+
213+
Note, however, that all variables that take part in the broadcasting operation must be defined beforehand with either `randomvar` or `datavar`. The exception is constants, which are automatically converted to their `constvar` equivalent. If you want to prevent broadcasting over some constant (e.g. if you want to add a vector to a multivariate Gaussian distribution), use an explicit `constvar` call:
214+
215+
```julia
216+
# Suppose `x` is a 2-dimensional Gaussian distribution
217+
z .~ x .+ constvar([ 1, 1 ])
218+
# Which is equivalent to
219+
for i in 1:n
220+
z[i] ~ x[i] + constvar([ 1, 1 ])
221+
end
222+
```
223+
224+
Without an explicit `constvar`, Julia's broadcasting machinery would instead attempt to unroll the for loop in the following way:
225+
226+
```julia
227+
# Without explicit `constvar`
228+
z .~ x .+ [ 1, 1 ]
229+
# Which is equivalent to
230+
array = [1, 1]
231+
for i in 1:n
232+
z[i] ~ x[i] + array[i] # This is wrong if `x[i]` is supposed to be a multivariate Gaussian
233+
end
234+
```
235+
236+
Read more about how the broadcasting machinery works in Julia in [the official documentation](https://docs.julialang.org/en/v1/manual/arrays/#Broadcasting).
237+
180238
### Node creation options
181239

182240
To pass optional arguments to the node creation constructor the user can use the `where { options... }` options specification syntax.

docs/src/utils.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44
GraphPPL.ishead
55
GraphPPL.isblock
66
GraphPPL.iscall
7+
GraphPPL.isbroadcastedcall
78
```

src/backends/reactivemp.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,23 +83,27 @@ function write_argument_guard(::ReactiveMPBackend, argument::Symbol)
8383
end
8484

8585
function write_randomvar_expression(::ReactiveMPBackend, model, varexp, options, arguments)
86-
return :($varexp = ReactiveMP.randomvar($model, $options, $(GraphPPL.fquote(varexp)), $(arguments...)))
86+
return :($varexp = ReactiveMP.randomvar($model, $options, $(GraphPPL.fquote(varexp)), $(arguments...)); $varexp)
8787
end
8888

8989
function write_datavar_expression(::ReactiveMPBackend, model, varexpr, options, type, arguments)
9090
errstr = "The expression `$varexpr = datavar($(type))` is incorrect. datavar(::Type, [ dims... ]) requires `Type` as a first argument, but `$(type)` is not a `Type`."
9191
checktype = :(GraphPPL.ensure_type($(type)) || error($errstr))
92-
return :($checktype; $varexpr = ReactiveMP.datavar($model, $options, $(GraphPPL.fquote(varexpr)), ReactiveMP.PointMass{ $type }, $(arguments...)))
92+
return :($checktype; $varexpr = ReactiveMP.datavar($model, $options, $(GraphPPL.fquote(varexpr)), ReactiveMP.PointMass{ $type }, $(arguments...)); $varexpr)
9393
end
9494

9595
function write_constvar_expression(::ReactiveMPBackend, model, varexpr, arguments)
96-
return :($varexpr = ReactiveMP.constvar($model, $(GraphPPL.fquote(varexpr)), $(arguments...)))
96+
return :($varexpr = ReactiveMP.constvar($model, $(GraphPPL.fquote(varexpr)), $(arguments...)); $varexpr)
9797
end
9898

9999
function write_as_variable(::ReactiveMPBackend, model, varexpr)
100100
return :(ReactiveMP.as_variable($model, $varexpr))
101101
end
102102

103+
function write_undo_as_variable(::ReactiveMPBackend, varexpr)
104+
return :(ReactiveMP.undo_as_variable($varexpr))
105+
end
106+
103107
function write_anonymous_variable(::ReactiveMPBackend, model, varexpr)
104108
return :(ReactiveMP.setanonymous!($varexpr, true))
105109
end
@@ -108,10 +112,18 @@ function write_make_node_expression(::ReactiveMPBackend, model, fform, variables
108112
return :($nodeexpr = ReactiveMP.make_node($model, $options, $fform, $varexpr, $(variables...)))
109113
end
110114

115+
function write_broadcasted_make_node_expression(::ReactiveMPBackend, model, fform, variables, options, nodeexpr, varexpr)
116+
return :($nodeexpr = ReactiveMP.make_node.($model, $options, $fform, $varexpr, $(variables...)))
117+
end
118+
111119
function write_autovar_make_node_expression(::ReactiveMPBackend, model, fform, variables, options, nodeexpr, varexpr, autovarid)
112120
return :(($nodeexpr, $varexpr) = ReactiveMP.make_node($model, $options, $fform, ReactiveMP.AutoVar($(GraphPPL.fquote(autovarid))), $(variables...)))
113121
end
114122

123+
function write_check_variable_existence(::ReactiveMPBackend, model, varid, errormsg)
124+
return :(ReactiveMP.haskey($model, $(QuoteNode(varid))) || Base.error($errormsg))
125+
end
126+
115127
function write_node_options(::ReactiveMPBackend, model, fform, variables, options)
116128
is_factorisation_option_present = false
117129
is_meta_option_present = false

src/model.jl

Lines changed: 75 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,11 @@ function normalize_tilde_arguments(backend, model, args)
6969
end
7070

7171
function __normalize_arg(backend, model, arg)
72-
if @capture(arg, (f_(v__) where { options__ }) | (f_(v__)))
72+
if @capture(arg, constvar(arguments__))
73+
return write_constvar_expression(backend, model, gensym(:anonymous_constvar), arguments)
74+
elseif @capture(arg, constvar.(arguments__))
75+
return error("Broadcasting of `constvar` in the constvar.(...) expression is dissalowed. Use `constvar((i) -> ..., dims...)` form instead.")
76+
elseif @capture(arg, (f_(v__) where { options__ }) | (f_(v__)) | (f_.(v__) where { options__ }) | (f_.(v__) ))
7377
if f === :(|>)
7478
@assert length(v) === 2 "Unsupported pipe syntax in model specification: $(arg)"
7579
f = v[2]
@@ -79,7 +83,31 @@ function __normalize_arg(backend, model, arg)
7983
nnodeexpr = gensym(:nnode)
8084
options = options !== nothing ? options : []
8185
v = normalize_tilde_arguments(backend, model, v)
82-
return :(($nnodeexpr, $nvarexpr) ~ $f($(v...); $(options...)); $(write_anonymous_variable(backend, model, nvarexpr)); $nvarexpr)
86+
if isbroadcastedcall(arg)
87+
# Strip dot call from broadcasting dot operators, like `.+` and define `BroadcastFunction` explicitly to avoid UndefVarError
88+
f = first(string(f)) === '.' ? Symbol(string(f)[2:end]) : f
89+
# broadcasting variables
90+
broadcasting_locals = map((_) -> gensym(:bv), v)
91+
return quote
92+
# Here we manually unroll anonymous broadcasting calls
93+
# Later on GraphPPL does not distinguish between local broadcasting `~` expression and a regular `~` expression
94+
begin
95+
Base.broadcast($(v...)) do $(broadcasting_locals...)
96+
# $initf
97+
($nnodeexpr, $nvarexpr) ~ $f($(broadcasting_locals...); $(options...));
98+
$(write_anonymous_variable(backend, model, nvarexpr));
99+
$(write_undo_as_variable(backend, nvarexpr));
100+
end
101+
end
102+
end
103+
104+
else
105+
return quote
106+
($nnodeexpr, $nvarexpr) ~ $f($(v...); $(options...));
107+
$(write_anonymous_variable(backend, model, nvarexpr));
108+
$(write_undo_as_variable(backend, nvarexpr));
109+
end
110+
end
83111
else
84112
return arg
85113
end
@@ -128,6 +156,11 @@ function write_constvar_expression end
128156
"""
129157
function write_as_variable end
130158

159+
"""
160+
write_undo_as_variable(backend, varexpr)
161+
"""
162+
function write_undo_as_variable end
163+
131164
"""
132165
write_anonymous_variable(backend, model, varexpr)
133166
"""
@@ -138,11 +171,21 @@ function write_anonymous_variable end
138171
"""
139172
function write_make_node_expression end
140173

174+
"""
175+
write_broadcasted_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr)
176+
"""
177+
function write_broadcasted_make_node_expression end
178+
141179
"""
142180
write_autovar_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr, autovarid)
143181
"""
144182
function write_autovar_make_node_expression end
145183

184+
"""
185+
write_check_variable_existence(backend, model, varid, errormsg)
186+
"""
187+
function write_check_variable_existence end
188+
146189
"""
147190
write_node_options(backend, model, fform, variables, options)
148191
"""
@@ -237,7 +280,10 @@ function generate_model_expression(backend, model_options, model_specification)
237280

238281
# Step 1: Probabilistic arguments normalisation
239282
ms_body = prewalk(ms_body) do expression
240-
if @capture(expression, (varexpr_ ~ fform_(arguments__) where { options__ }) | (varexpr_ ~ fform_(arguments__)))
283+
if @capture(expression,
284+
(varexpr_ ~ fform_(arguments__) where { options__ }) | (varexpr_ ~ fform_(arguments__)) |
285+
(varexpr_ .~ fform_(arguments__) where { options__ }) | (varexpr_ .~ fform_(arguments__))
286+
)
241287
options = options === nothing ? [] : options
242288

243289
# Filter out keywords arguments to options array
@@ -249,8 +295,9 @@ function generate_model_expression(backend, model_options, model_specification)
249295
return !ifparameters
250296
end
251297

252-
varexpr = @capture(varexpr, (nodeid_, varid_)) ? varexpr : :(($(gensym(:nnode)), $varexpr))
253-
return :($varexpr ~ $(fform)($((normalize_tilde_arguments(backend, model, arguments))...); $(options...)))
298+
varexpr = @capture(varexpr, (nodeid_, varid_)) ? varexpr : :(($(gensym(:nnode)), $varexpr))
299+
operator = isbroadcastedcall(expression) ? Symbol(".~") : :(~)
300+
return :($operator($varexpr, $(fform)($((normalize_tilde_arguments(backend, model, arguments))...); $(options...))))
254301
elseif @capture(expression, varexpr_ = randomvar(arguments__) where { options__ })
255302
return :($varexpr = randomvar($(arguments...); $(options...)))
256303
elseif @capture(expression, varexpr_ = datavar(arguments__) where { options__ })
@@ -263,6 +310,8 @@ function generate_model_expression(backend, model_options, model_specification)
263310
return :($varexpr = datavar($(arguments...); ))
264311
elseif @capture(expression, varexpr_ = constvar(arguments__))
265312
return :($varexpr = constvar($(arguments...)))
313+
elseif @capture(expression, constvar.(arguments__))
314+
error("Broadcasting of `constvar` in the constvar.(...) expression is dissalowed. Use `constvar((i) -> ..., dims...)` form instead.")
266315
else
267316
return expression
268317
end
@@ -279,17 +328,12 @@ function generate_model_expression(backend, model_options, model_specification)
279328
end
280329
return expression
281330
end
282-
283-
varids = Set{Symbol}(ms_args_ids)
284331

285332
# Step 2: Main pass
286333
ms_body = postwalk(ms_body) do expression
287334
# Step 2.1 Convert datavar calls
288335
if @capture(expression, varexpr_ = datavar(arguments__; options__))
289-
@assert varexpr varids "Invalid model specification: '$varexpr' id is duplicated"
290336
@assert length(arguments) >= 1 "The expression `$expression` is incorrect. datavar(::Type, [ dims... ]) requires `Type` as a first argument."
291-
292-
push!(varids, varexpr)
293337

294338
type_argument = arguments[1]
295339
tail_arguments = arguments[2:end]
@@ -298,35 +342,40 @@ function generate_model_expression(backend, model_options, model_specification)
298342
return write_datavar_expression(backend, model, varexpr, dvoptions, type_argument, tail_arguments)
299343
# Step 2.2 Convert randomvar calls
300344
elseif @capture(expression, varexpr_ = randomvar(arguments__; options__))
301-
@assert varexpr varids "Invalid model specification: '$varexpr' id is duplicated"
302-
push!(varids, varexpr)
303-
304345
rvoptions = write_randomvar_options(backend, varexpr, options)
305-
306346
return write_randomvar_expression(backend, model, varexpr, rvoptions, arguments)
307-
# Step 2.3 Conver constvar calls
347+
# Step 2.3 Convert constvar calls
308348
elseif @capture(expression, varexpr_ = constvar(arguments__))
309-
@assert varexpr varids "Invalid model specification: '$varexpr' id is duplicated"
310-
push!(varids, varexpr)
311-
312349
return write_constvar_expression(backend, model, varexpr, arguments)
313350
# Step 2.2 Convert tilde expressions
314-
elseif @capture(expression, (nodeexpr_, varexpr_) ~ fform_(arguments__; kwarguments__))
315-
# println(expression)
351+
elseif @capture(expression, ((nodeexpr_, varexpr_) ~ fform_(arguments__; kwarguments__)) | ((nodeexpr_, varexpr_) .~ fform_(arguments__; kwarguments__)))
352+
316353
varexpr, short_id, full_id = parse_varexpr(varexpr)
317354

318355
if short_id bannedids
319-
error("Invalid name '$(short_id)' for new random variable. '$(short_id)' was already initialized with '=' operator before.")
356+
error("Invalid name '$(short_id)' for new random variable. '$(short_id)' has been already initialized with '=' operator.")
320357
end
321358

322359
variables = map((argexpr) -> write_as_variable(backend, model, argexpr), arguments)
323360
options = write_node_options(backend, model, fform, [ varexpr, arguments... ], kwarguments)
324-
325-
if short_id varids
326-
return write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr)
361+
362+
if isbroadcastedcall(expression)
363+
# Strip dot call from broadcasting dot operators, like `.+`
364+
fform = first(string(fform)) === '.' ? Symbol(string(fform)[2:end]) : fform
365+
return quote
366+
# In case of broadcasted call we assume that variable has been created before otherwise it should throw an error
367+
$(write_check_variable_existence(backend, model, short_id, "Cannot use variables named `$(short_id)` in the broadcasting call. `$(short_id)` sequence of variables must be created in advance."))
368+
$(write_broadcasted_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr))
369+
end
327370
else
328-
push!(varids, short_id)
329-
return write_autovar_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr, full_id)
371+
# Indexed variables like `y[1]` cannot be created on the fly and should be pre-initialised with `y = randomvar(n)`
372+
# Single variables like `y` can be created on the fly with the `AutoVar` marker
373+
# In the second case if variable `y` has been initialised before `AutoVar` should simply return it
374+
if isref(varexpr)
375+
return write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr)
376+
else
377+
return write_autovar_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr, full_id)
378+
end
330379
end
331380
else
332381
return expression

src/utils.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,32 @@ See also: [`ishead`](@ref)
3232
iscall(expr) = ishead(expr, :call) && length(expr.args) >= 1
3333
iscall(expr, fsym) = iscall(expr) && first(expr.args) === fsym
3434

35+
"""
36+
isbroadcastedcall(expr)
37+
isbroadcastedcall(expr, fsym)
38+
39+
Checks if expression represents a broadcast call to some function. Optionally accepts `fsym` to check for exact function name match.
40+
41+
See also: [`iscall`](@ref)
42+
"""
43+
function isbroadcastedcall(expr)
44+
if isblock(expr) # TODO add for other functions?
45+
nextexpr = findnext(isexpr, expr.args, 1)
46+
return nextexpr !== nothing ? isbroadcastedcall(expr.args[nextexpr]) : false
47+
end
48+
(iscall(expr) && length(expr.args) >= 1 && first(string(first(expr.args))) === '.') || # Checks for `:(a .+ b)` syntax
49+
(ishead(expr, :(.))) # Checks for `:(f.(x))` syntax
50+
end
51+
52+
function isbroadcastedcall(expr, fsym)
53+
if isblock(expr) # TODO add for other functions?
54+
nextexpr = findnext(isexpr, expr.args, 1)
55+
return nextexpr !== nothing ? isbroadcastedcall(expr.args[nextexpr], fsym) : false
56+
end
57+
(iscall(expr) && length(expr.args) >= 1 && first(string(first(expr.args))) === '.' && Symbol(string(first(expr.args))[2:end]) === fsym) || # Checks for `:(a .+ b)` syntax
58+
(ishead(expr, :(.)) && first(expr.args) === fsym) # Checks for `:(f.(x))` syntax
59+
end
60+
3561
"""
3662
isref(expr)
3763
@@ -41,6 +67,15 @@ See also: [`ishead`](@ref)
4167
"""
4268
isref(expr) = ishead(expr, :ref)
4369

70+
"""
71+
getref(expr)
72+
73+
Returns the ref indices from `expr` in the form of a tuple.
74+
75+
See also: [`isref`](@ref)
76+
"""
77+
getref(expr) = isref(expr) ? Tuple(expr.args[2:end]) : () # indices are every argument after the indexed object
78+
4479
"""
4580
ensure_type(x)
4681

test/utils.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,28 @@ end
5454

5555
end
5656

57+
@testset "isbroadcastedcall tests" begin
58+
import GraphPPL: isbroadcastedcall
59+
60+
@test isbroadcastedcall(:(f(1))) === false
61+
@test isbroadcastedcall(:(f(1)), :f) === false
62+
@test isbroadcastedcall(:(f(1)), :g) === false
63+
@test isbroadcastedcall(:(if true 1 else 2 end)) === false
64+
@test isbroadcastedcall(:(begin end)) === false
65+
66+
@test isbroadcastedcall(:(a .+ b)) === true
67+
@test isbroadcastedcall(:(a .+ b), :(+)) === true
68+
@test isbroadcastedcall(:(a .+ b), :(-)) === false
69+
70+
@test isbroadcastedcall(:(f.(a))) === true
71+
@test isbroadcastedcall(:(f.(a, b))) === true
72+
@test isbroadcastedcall(:(f.(a)), :f) === true
73+
@test isbroadcastedcall(:(f.(a, b)), :f) === true
74+
@test isbroadcastedcall(:(f.(a)), :g) === false
75+
@test isbroadcastedcall(:(f.(a, b)), :g) === false
76+
77+
end
78+
5779
@testset "ensure_type tests" begin
5880
import GraphPPL: ensure_type
5981

0 commit comments

Comments
 (0)