Skip to content

Commit 06a88a0

Browse files
authored
Merge pull request #18 from biaslab/dev-node-specification-aliases
Feature: aliases and keyword distribution specification in the model macro
2 parents 125d744 + d700c12 commit 06a88a0

File tree

4 files changed

+160
-8
lines changed

4 files changed

+160
-8
lines changed

src/backends/reactivemp.jl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,4 +455,87 @@ end
455455

456456
function write_meta_specification_entry(::ReactiveMPBackend, F, N, meta)
457457
return :(ReactiveMP.MetaSpecificationEntry(Val($F), Val($N), $meta))
458+
end
459+
460+
# Aliases
461+
462+
ReactiveMPNodeAliases = (
463+
(
464+
(expression) -> @capture(expression, a_ || b_) ? :(ReactiveMP.OR($a, $b)) : expression,
465+
"`a || b`: alias for `OR(a, b)` node (operator precedence between `||`, `&&`, `->` and `!` is the same as in Julia)."
466+
),
467+
(
468+
(expression) -> @capture(expression, a_ && b_) ? :(ReactiveMP.AND($a, $b)) : expression,
469+
"`a && b`: alias for `AND(a, b)` node (operator precedence `||`, `&&`, `->` and `!` is the same as in Julia)."
470+
),
471+
(
472+
(expression) -> @capture(expression, a_ -> b_) ? :(ReactiveMP.IMPLY($a, $b)) : expression,
473+
"`a -> b`: alias for `IMPLY(a, b)` node (operator precedence `||`, `&&`, `->` and `!` is the same as in Julia)."
474+
),
475+
(
476+
(expression) -> @capture(expression, (¬a_) | (!a_)) ? :(ReactiveMP.NOT($a)) : expression,
477+
"`¬a` and `!a`: alias for `NOT(a)` node (Unicode `\\neg`, operator precedence `||`, `&&`, `->` and `!` is the same as in Julia)."
478+
),
479+
(
480+
(expression) -> @capture(expression, +(args__)) ? fold_linear_operator_call(expression) : expression,
481+
"`a + b + c`: alias for `(a + b) + c`"
482+
),
483+
(
484+
(expression) -> @capture(expression, *(args__)) ? fold_linear_operator_call(expression) : expression,
485+
"`a * b * c`: alias for `(a * b) * c`"
486+
),
487+
(
488+
(expression) -> @capture(expression, (Normal | Gaussian)((μ)|(m)|(mean) = mean_, (σ²)|(τ⁻¹)|(v)|(var)|(variance) = var_)) ? :(NormalMeanVariance($mean, $var)) : expression,
489+
"`Normal(μ|m|mean = ..., σ²|τ⁻¹|v|var|variance = ...)` alias for `NormalMeanVariance(..., ...)` node. `Gaussian` could be used instead `Normal` too."
490+
),
491+
(
492+
(expression) -> @capture(expression, (Normal | Gaussian)((μ)|(m)|(mean) = mean_, (τ)|(γ)|(σ⁻²)|(w)|(p)|(prec)|(precision) = prec_)) ? :(NormalMeanPrecision($mean, $prec)) : expression,
493+
"`Normal(μ|m|mean = ..., τ|γ|σ⁻²|w|p|prec|precision = ...)` alias for `NormalMeanVariance(..., ...)` node. `Gaussian` could be used instead `Normal` too."
494+
),
495+
(
496+
(expression) -> @capture(expression, (MvNormal | MvGaussian)((μ)|(m)|(mean) = mean_, (Σ)|(V)|(Λ⁻¹)|(cov)|(covariance) = cov_)) ? :(MvNormalMeanCovariance($mean, $cov)) : expression,
497+
"`MvNormal(μ|m|mean = ..., Σ|V|Λ⁻¹|cov|covariance = ...)` alias for `MvNormalMeanCovariance(..., ...)` node. `MvGaussian` could be used instead `MvNormal` too."
498+
),
499+
(
500+
(expression) -> @capture(expression, (MvNormal | MvGaussian)((μ)|(m)|(mean) = mean_, (Λ)|(W)|(Σ⁻¹)|(prec)|(precision) = prec_)) ? :(MvNormalMeanPrecision($mean, $prec)) : expression,
501+
"`MvNormal(μ|m|mean = ..., Λ|W|Σ⁻¹|prec|precision = ...)` alias for `MvNormalMeanPrecision(..., ...)` node. `MvGaussian` could be used instead `MvNormal` too."
502+
),
503+
((expression) -> @capture(expression, (Normal | Gaussian)(args__)) ? :(error("Please use a specific version of the `Normal` (`Gaussian`) distribution (e.g. `NormalMeanVariance` or aliased version `Normal(mean = ..., variance|precision = ...)`).")) : expression, missing),
504+
((expression) -> @capture(expression, (MvNormal | MvGaussian)(args__)) ? :(error("Please use a specific version of the `MvNormal` (`MvGaussian`) distribution (e.g. `MvNormalMeanCovariance` or aliased version `MvNormal(mean = ..., covariance|precision = ...)`).")) : expression, missing),
505+
(
506+
(expression) -> @capture(expression, Gamma((α)|(a)|(shape) = shape_, (θ)|(β⁻¹)|(scale) = scale_)) ? :(GammaShapeScale($shape, $scale)) : expression,
507+
"`Gamma(α|a|shape = ..., θ|β⁻¹|scale = ...)` alias for `GammaShapeScale(..., ...) node.`"
508+
),
509+
(
510+
(expression) -> @capture(expression, Gamma((α)|(a)|(shape) = shape_, (β)|(θ⁻¹)|(rate) = rate_)) ? :(GammaShapeRate($shape, $rate)) : expression,
511+
"`Gamma(α|a|shape = ..., β|θ⁻¹|rate = ...)` alias for `GammaShapeRate(..., ...) node.`"
512+
),
513+
)
514+
515+
function show_tilderhs_alias(::ReactiveMPBackend, io = stdout)
516+
foreach(skipmissing(map(last, ReactiveMPNodeAliases))) do alias
517+
println(io, "- ", alias)
518+
end
519+
end
520+
521+
function apply_alias_transformation(notanexpression, alias)
522+
# We always short-circuit on non-expression
523+
return (notanexpression, true)
524+
end
525+
526+
function apply_alias_transformation(expression::Expr, alias)
527+
_expression = first(alias)(expression)
528+
# Returns potentially modified expression and a Boolean flag,
529+
# which indicates if expression actually has been modified
530+
return (_expression, _expression !== expression)
531+
end
532+
533+
function write_inject_tilderhs_aliases(::ReactiveMPBackend, model, tilderhs)
534+
return postwalk(tilderhs) do expression
535+
# We short-circuit if `mflag` is true
536+
_expression, _ = foldl(ReactiveMPNodeAliases; init = (expression, false)) do (expression, mflag), alias
537+
return mflag ? (expression, true) : apply_alias_transformation(expression, alias)
538+
end
539+
return _expression
540+
end
458541
end

src/model.jl

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,31 @@ function write_default_model_constraints end
211211
"""
212212
function write_default_model_meta end
213213

214+
"""
215+
write_inject_tilderhs_aliases(backend, model, tilderhs)
216+
"""
217+
function write_inject_tilderhs_aliases end
218+
219+
"""
220+
show_tilderhs_alias(backend, io)
221+
"""
222+
function show_tilderhs_alias end
223+
224+
"""
225+
226+
```julia
227+
@model [ model_options ] function model_name(model_arguments...; model_keyword_arguments...)
228+
# model description
229+
end
230+
```
231+
232+
`@model` macro generates a function that returns an equivalent graph-representation of the given probabilistic model description.
233+
234+
## Supported alias in the model specification
235+
$(begin io = IOBuffer(); show_tilderhs_alias(__get_current_backend(), io); String(take!(io)) end)
236+
"""
237+
macro model end
238+
214239
macro model(model_specification)
215240
return esc(:(@model [] $model_specification))
216241
end
@@ -278,7 +303,18 @@ function generate_model_expression(backend, model_options, model_specification)
278303
# Doing so can lead to undefined behaviour
279304
ms_args_checks = map((ms_arg) -> write_argument_guard(backend, ms_arg), ms_args_guard_ids)
280305

281-
# Step 1: Probabilistic arguments normalisation
306+
# Step 1: Inject node's aliases
307+
ms_body = postwalk(ms_body) do expression
308+
if @capture(expression, (lhs_ ~ rhs_ where { options__ }) | (lhs_ .~ rhs_ where { options__ }))
309+
return :($lhs ~ $(write_inject_tilderhs_aliases(backend, model, rhs)) where { $options... })
310+
elseif @capture(expression, (lhs_ ~ rhs_) | (lhs_ .~ rhs_))
311+
return :($lhs ~ $(write_inject_tilderhs_aliases(backend, model, rhs)))
312+
else
313+
return expression
314+
end
315+
end
316+
317+
# Step 2: Probabilistic arguments normalisation
282318
ms_body = prewalk(ms_body) do expression
283319
if @capture(expression,
284320
(varexpr_ ~ fform_(arguments__) where { options__ }) | (varexpr_ ~ fform_(arguments__)) |
@@ -329,9 +365,9 @@ function generate_model_expression(backend, model_options, model_specification)
329365
return expression
330366
end
331367

332-
# Step 2: Main pass
368+
# Step 3: Main pass
333369
ms_body = postwalk(ms_body) do expression
334-
# Step 2.1 Convert datavar calls
370+
# Step 3.1 Convert datavar calls
335371
if @capture(expression, varexpr_ = datavar(arguments__; options__))
336372
@assert length(arguments) >= 1 "The expression `$expression` is incorrect. datavar(::Type, [ dims... ]) requires `Type` as a first argument."
337373

@@ -340,14 +376,14 @@ function generate_model_expression(backend, model_options, model_specification)
340376
dvoptions = write_datavar_options(backend, varexpr, type_argument, options)
341377

342378
return write_datavar_expression(backend, model, varexpr, dvoptions, type_argument, tail_arguments)
343-
# Step 2.2 Convert randomvar calls
379+
# Step 3.2 Convert randomvar calls
344380
elseif @capture(expression, varexpr_ = randomvar(arguments__; options__))
345381
rvoptions = write_randomvar_options(backend, varexpr, options)
346382
return write_randomvar_expression(backend, model, varexpr, rvoptions, arguments)
347-
# Step 2.3 Convert constvar calls
383+
# Step 3.3 Convert constvar calls
348384
elseif @capture(expression, varexpr_ = constvar(arguments__))
349385
return write_constvar_expression(backend, model, varexpr, arguments)
350-
# Step 2.2 Convert tilde expressions
386+
# Step 3.2 Convert tilde expressions
351387
elseif @capture(expression, ((nodeexpr_, varexpr_) ~ fform_(arguments__; kwarguments__)) | ((nodeexpr_, varexpr_) .~ fform_(arguments__; kwarguments__)))
352388

353389
varexpr, short_id, full_id = parse_varexpr(varexpr)
@@ -382,7 +418,7 @@ function generate_model_expression(backend, model_options, model_specification)
382418
end
383419
end
384420

385-
# Step 3: Final pass
421+
# Step 4: Final pass
386422
final_pass_exceptions = (x) -> @capture(x, (some_ -> body_) | (function some_(args__) body_ end) | (some_(args__) = body_))
387423
final_pass_target = (x) -> @capture(x, return ret_)
388424

src/utils.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,20 @@ getref(expr) = isref(expr) ? (view(expr.args, 2:lastindex(expr.args))...,) : ()
8282
Checks if `x` is of type `Type`
8383
"""
8484
ensure_type(x::Type) = true
85-
ensure_type(x) = false
85+
ensure_type(x) = false
86+
87+
fold_linear_operator_call(any) = any
88+
89+
fold_linear_operator_call_first_arg(::typeof(foldl), args) = args[begin + 1]
90+
fold_linear_operator_call_tail_args(::typeof(foldl), args) = args[begin+2:end]
91+
92+
fold_linear_operator_call_first_arg(::typeof(foldr), args) = args[end]
93+
fold_linear_operator_call_tail_args(::typeof(foldr), args) = args[begin+1:end-1]
94+
95+
function fold_linear_operator_call(expr::Expr, fold = foldl)
96+
if @capture(expr, op_(args__, )) && length(args) > 2
97+
return fold((res, el) -> Expr(:call, op, res, el), fold_linear_operator_call_tail_args(fold, expr.args); init = fold_linear_operator_call_first_arg(fold, expr.args))
98+
else
99+
return expr
100+
end
101+
end

test/utils.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module UtilsTests
22

33
using Test
44
using GraphPPL
5+
using MacroTools
56

67
@testset "issymbol tests" begin
78
import GraphPPL: issymbol
@@ -85,4 +86,20 @@ end
8586
@test ensure_type(1.0) === false
8687
end
8788

89+
@testset "fold_linear_operator_call" begin
90+
import GraphPPL: fold_linear_operator_call
91+
92+
@test @capture(fold_linear_operator_call(:(+a)), +a)
93+
@test @capture(fold_linear_operator_call(:(a + b)), a + b)
94+
@test @capture(fold_linear_operator_call(:(a + b + c)), (a + b) + c)
95+
@test @capture(fold_linear_operator_call(:(a + b + c + d)), ((a + b) + c) + d)
96+
@test @capture(fold_linear_operator_call(:(a + b + c + d), foldr), (a + (b + (c + d))))
97+
98+
@test @capture(fold_linear_operator_call(:(a * b)), a * b)
99+
@test @capture(fold_linear_operator_call(:(a * b * c)), (a * b) * c)
100+
@test @capture(fold_linear_operator_call(:(a * b * c * d)), ((a * b) * c) * d)
101+
@test @capture(fold_linear_operator_call(:(a * b * c * d), foldr), (a * (b * (c * d))))
102+
103+
end
104+
88105
end

0 commit comments

Comments
 (0)