Skip to content

Commit 9262e9c

Browse files
authored
Merge pull request #8 from biaslab/develop-constraints
Constraints and meta specification language
2 parents c224500 + 683a8b5 commit 9262e9c

File tree

12 files changed

+915
-335
lines changed

12 files changed

+915
-335
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,4 @@ Coverage.ipynb
4646
**/.DS_Store
4747

4848
examples/*Compiled
49+
statprof

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GraphPPL"
22
uuid = "b3f8163a-e979-4e85-b43e-1f63d8c8b42c"
33
authors = ["Dmitry Bagaev <[email protected]>"]
4-
version = "1.0.5"
4+
version = "1.1.0"
55

66
[deps]
77
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"

docs/make.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ makedocs(
66
sitename = "GraphPPL.jl",
77
pages = [
88
"Home" => "index.md",
9-
"User guide" => "user-guide.md"
9+
"User guide" => "user-guide.md",
10+
"Utils" => "utils.md"
1011
],
1112
format = Documenter.HTML(
1213
prettyurls = get(ENV, "CI", nothing) == "true"

docs/src/utils.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Utils
2+
3+
```@docs
4+
GraphPPL.ishead
5+
GraphPPL.isblock
6+
GraphPPL.iscall
7+
```

src/GraphPPL.jl

Lines changed: 5 additions & 333 deletions
Original file line numberDiff line numberDiff line change
@@ -1,342 +1,14 @@
11
module GraphPPL
22

3-
export @model
4-
5-
import MacroTools
6-
import MacroTools: @capture, postwalk, prewalk, walk
7-
8-
function conditioned_walk(f, condition_skip, condition_apply, x)
9-
walk(x, x -> condition_skip(x) ? x : condition_apply(x) ? f(x) : conditioned_walk(f, condition_skip, condition_apply, x), identity)
10-
end
11-
12-
"""
13-
fquote(expr)
14-
15-
This function forces `Expr` or `Symbol` to be quoted.
16-
"""
17-
fquote(expr::Symbol) = Expr(:quote, expr)
18-
fquote(expr::Int) = expr
19-
fquote(expr::Expr) = expr
20-
21-
"""
22-
ensure_type
23-
"""
24-
ensure_type(x::Type) = x
25-
ensure_type(x) = error("Valid type object was expected but '$x' has been found")
26-
27-
is_kwargs_expression(x) = false
28-
is_kwargs_expression(x::Expr) = x.head === :parameters
29-
30-
"""
31-
parse_varexpr(varexpr)
32-
33-
This function parses variable id and returns a tuple of 3 different representations of the same variable
34-
1. Original expression
35-
2. Short variable identificator (used in variables lookup table)
36-
3. Full variable identificator (used in model as a variable id)
37-
"""
38-
function parse_varexpr(varexpr::Symbol)
39-
varexpr = varexpr
40-
short_id = varexpr
41-
full_id = varexpr
42-
return varexpr, short_id, full_id
43-
end
44-
45-
function parse_varexpr(varexpr::Expr)
46-
47-
# TODO: It might be handy to have this feature in the future for e.g. interacting with UnPack.jl package
48-
# TODO: For now however we fallback to a more informative error message since it is not obvious how to parse such expressions yet
49-
@capture(varexpr, (tupled_ids__, )) &&
50-
error("Multiple variable declarations, definitions and assigments are forbidden within @model macro. Try to split $(varexpr) into several independent statements.")
51-
52-
@capture(varexpr, id_[idx__]) ||
53-
error("Variable identificator can be in form of a single symbol (x ~ ...) or indexing expression (x[i] ~ ...)")
54-
55-
varexpr = varexpr
56-
short_id = id
57-
full_id = Expr(:call, :Symbol, fquote(id), Expr(:quote, :_), Expr(:quote, Symbol(join(idx, :_))))
58-
59-
return varexpr, short_id, full_id
60-
end
61-
62-
"""
63-
normalize_tilde_arguments(args)
64-
65-
This function 'normalizes' every argument of a tilde expression making every inner function call to be a tilde expression as well.
66-
It forces MSL to create anonymous node for any non-linear variable transformation or deterministic relationships. MSL does not check (and cannot in general)
67-
if some inner function call leads to a constant expression or not (e.g. `Normal(0.0, sqrt(10.0))`). Backend API should decide whenever to create additional anonymous nodes
68-
for constant non-linear transformation expressions or not by analyzing input arguments.
69-
"""
70-
function normalize_tilde_arguments(args)
71-
return map(args) do arg
72-
if @capture(arg, id_[idx_])
73-
return :($(__normalize_arg(id))[$idx])
74-
else
75-
return __normalize_arg(arg)
76-
end
77-
end
78-
end
79-
80-
function __normalize_arg(arg)
81-
if @capture(arg, (f_(v__) where { options__ }) | (f_(v__)))
82-
if f === :(|>)
83-
@assert length(v) === 2 "Unsupported pipe syntax in model specification: $(arg)"
84-
f = v[2]
85-
v = [ v[1] ]
86-
end
87-
nvarexpr = gensym(:nvar)
88-
nnodeexpr = gensym(:nnode)
89-
options = options !== nothing ? options : []
90-
v = normalize_tilde_arguments(v)
91-
return :(($nnodeexpr, $nvarexpr) ~ $f($(v...); $(options...)); $nvarexpr)
92-
else
93-
return arg
94-
end
95-
end
96-
97-
argument_write_default_value(arg, default::Nothing) = arg
98-
argument_write_default_value(arg, default) = Expr(:kw, arg, default)
99-
100-
101-
"""
102-
write_argument_guard(backend, argument)
103-
"""
104-
function write_argument_guard end
105-
106-
"""
107-
write_randomvar_expression(backend, model, varexpr, arguments, kwarguments)
108-
"""
109-
function write_randomvar_expression end
110-
111-
"""
112-
write_datavar_expression(backend, model, varexpr, type, arguments, kwarguments)
113-
"""
114-
function write_datavar_expression end
115-
116-
"""
117-
write_constvar_expression(backend, model, varexpr, arguments, kwarguments)
118-
"""
119-
function write_constvar_expression end
120-
121-
"""
122-
write_as_variable(backend, model, varexpr)
123-
"""
124-
function write_as_variable end
125-
126-
"""
127-
write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr)
128-
"""
129-
function write_make_node_expression end
130-
131-
"""
132-
write_autovar_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr, autovarid)
133-
"""
134-
function write_autovar_make_node_expression end
135-
136-
"""
137-
write_node_options(backend, fform, variables, options)
138-
"""
139-
function write_node_options end
140-
141-
"""
142-
write_randomvar_options(backend, variable, options)
143-
"""
144-
function write_randomvar_options end
145-
146-
"""
147-
write_constvar_options(backend, variable, options)
148-
"""
149-
function write_constvar_options end
150-
151-
"""
152-
write_datavar_options(backend, variable, options)
153-
"""
154-
function write_datavar_options end
3+
using MacroTools
1554

1565
include("backends/reactivemp.jl")
1576

1587
__get_current_backend() = ReactiveMPBackend()
1598

160-
macro model(model_specification)
161-
return esc(:(@model [] $model_specification))
162-
end
163-
164-
macro model(model_options, model_specification)
165-
return GraphPPL.generate_model_expression(__get_current_backend(), model_options, model_specification)
166-
end
167-
168-
function generate_model_expression(backend, model_options, model_specification)
169-
@capture(model_options, [ ms_options__ ]) ||
170-
error("Model specification options should be in a form of [ option1 = ..., option2 = ... ]")
171-
172-
ms_options = map(ms_options) do option
173-
(@capture(option, name_ = value_) && name isa Symbol) || error("Invalid option specification: $(option). Expected: 'option_name = option_value'.")
174-
return (name, value)
175-
end
176-
177-
ms_options = :(NamedTuple{ ($(tuple(map(first, ms_options)...))) }((($(tuple(map(last, ms_options)...)...)),)))
178-
179-
@capture(model_specification, (function ms_name_(ms_args__; ms_kwargs__) ms_body_ end) | (function ms_name_(ms_args__) ms_body_ end)) ||
180-
error("Model specification language requires full function definition")
181-
182-
model = gensym(:model)
183-
184-
ms_args_ids = Vector{Symbol}()
185-
ms_args_guard_ids = Vector{Symbol}()
186-
ms_args_const_ids = Vector{Tuple{Symbol, Symbol}}()
187-
188-
ms_arg_expression_converter = (ms_arg) -> begin
189-
if @capture(ms_arg, arg_::ConstVariable = smth_) || @capture(ms_arg, arg_::ConstVariable)
190-
# rc_arg = gensym(:constvar)
191-
push!(ms_args_const_ids, (arg, arg)) # backward compatibility for old behaviour with gensym
192-
push!(ms_args_guard_ids, arg)
193-
push!(ms_args_ids, arg)
194-
return argument_write_default_value(arg, smth)
195-
elseif @capture(ms_arg, arg_::T_ = smth_) || @capture(ms_arg, arg_::T_)
196-
push!(ms_args_guard_ids, arg)
197-
push!(ms_args_ids, arg)
198-
return argument_write_default_value(:($(arg)::$(T)), smth)
199-
elseif @capture(ms_arg, arg_Symbol = smth_) || @capture(ms_arg, arg_Symbol)
200-
push!(ms_args_guard_ids, arg)
201-
push!(ms_args_ids, arg)
202-
return argument_write_default_value(arg, smth)
203-
else
204-
error("Invalid argument specification: $(ms_arg)")
205-
end
206-
end
207-
208-
ms_args = ms_args === nothing ? [] : map(ms_arg_expression_converter, ms_args)
209-
ms_kwargs = ms_kwargs === nothing ? [] : map(ms_arg_expression_converter, ms_kwargs)
210-
211-
if length(Set(ms_args_ids)) !== length(ms_args_ids)
212-
error("There are duplicates in argument specification list: $(ms_args_ids)")
213-
end
214-
215-
ms_args_const_init_block = map(ms_args_const_ids) do ms_arg_const_id
216-
return write_constvar_expression(backend, model, first(ms_arg_const_id), [ last(ms_arg_const_id) ], [])
217-
end
218-
219-
# Step 0: Check that all inputs are not AbstractVariables
220-
# It is highly recommended not to create AbstractVariables outside of the model creation macro
221-
# Doing so can lead to undefined behaviour
222-
ms_args_checks = map((ms_arg) -> write_argument_guard(backend, ms_arg), ms_args_guard_ids)
223-
224-
# Step 1: Probabilistic arguments normalisation
225-
ms_body = prewalk(ms_body) do expression
226-
if @capture(expression, (varexpr_ ~ fform_(arguments__) where { options__ }) | (varexpr_ ~ fform_(arguments__)))
227-
options = options === nothing ? [] : options
228-
229-
# Filter out keywords arguments to options array
230-
arguments = filter(arguments) do arg
231-
ifparameters = arg isa Expr && arg.head === :parameters
232-
if ifparameters
233-
foreach(a -> push!(options, a), arg.args)
234-
end
235-
return !ifparameters
236-
end
237-
238-
varexpr = @capture(varexpr, (nodeid_, varid_)) ? varexpr : :(($(gensym(:nnode)), $varexpr))
239-
return :($varexpr ~ $(fform)($((normalize_tilde_arguments(arguments))...); $(options...)))
240-
elseif @capture(expression, varexpr_ = randomvar(arguments__) where { options__ })
241-
return :($varexpr = randomvar($(arguments...); $(write_randomvar_options(backend, varexpr, options)...)))
242-
elseif @capture(expression, varexpr_ = datavar(arguments__) where { options__ })
243-
return :($varexpr = datavar($(arguments...); $(write_datavar_options(backend, varexpr, options)...)))
244-
elseif @capture(expression, varexpr_ = constvar(arguments__) where { options__ })
245-
return :($varexpr = constvar($(arguments...); $(write_constvar_options(backend, varexpr, options)...)))
246-
elseif @capture(expression, varexpr_ = randomvar(arguments__))
247-
return :($varexpr = randomvar($(arguments...); ))
248-
elseif @capture(expression, varexpr_ = datavar(arguments__))
249-
return :($varexpr = datavar($(arguments...); ))
250-
elseif @capture(expression, varexpr_ = constvar(arguments__))
251-
return :($varexpr = constvar($(arguments...); ))
252-
else
253-
return expression
254-
end
255-
end
256-
257-
bannedids = Set{Symbol}()
258-
259-
ms_body = postwalk(ms_body) do expression
260-
if @capture(expression, lhs_ = rhs_)
261-
if !(@capture(rhs, datavar(args__))) && !(@capture(rhs, randomvar(args__))) && !(@capture(rhs, constvar(args__)))
262-
varexpr, short_id, full_id = parse_varexpr(lhs)
263-
push!(bannedids, short_id)
264-
end
265-
end
266-
return expression
267-
end
268-
269-
varids = Set{Symbol}(ms_args_ids)
270-
271-
# Step 2: Main pass
272-
ms_body = postwalk(ms_body) do expression
273-
# Step 2.1 Convert datavar calls
274-
if @capture(expression, varexpr_ = datavar(arguments__; kwarguments__))
275-
@assert varexpr varids "Invalid model specification: '$varexpr' id is duplicated"
276-
@assert length(arguments) >= 1 "datavar() call requires type specification as a first argument"
277-
278-
push!(varids, varexpr)
279-
280-
type_argument = arguments[1]
281-
tail_arguments = arguments[2:end]
282-
283-
return write_datavar_expression(backend, model, varexpr, type_argument, tail_arguments, kwarguments)
284-
# Step 2.2 Convert randomvar calls
285-
elseif @capture(expression, varexpr_ = randomvar(arguments__; kwarguments__))
286-
@assert varexpr varids "Invalid model specification: '$varexpr' id is duplicated"
287-
push!(varids, varexpr)
288-
289-
return write_randomvar_expression(backend, model, varexpr, arguments, kwarguments)
290-
# Step 2.3 Conver constvar calls
291-
elseif @capture(expression, varexpr_ = constvar(arguments__; kwarguments__))
292-
@assert varexpr varids "Invalid model specification: '$varexpr' id is duplicated"
293-
push!(varids, varexpr)
294-
295-
return write_constvar_expression(backend, model, varexpr, arguments, kwarguments)
296-
# Step 2.2 Convert tilde expressions
297-
elseif @capture(expression, (nodeexpr_, varexpr_) ~ fform_(arguments__; kwarguments__))
298-
# println(expression)
299-
varexpr, short_id, full_id = parse_varexpr(varexpr)
300-
301-
if short_id bannedids
302-
error("Invalid name '$(short_id)' for new random variable. '$(short_id)' was already initialized with '=' operator before.")
303-
end
304-
305-
variables = map((argexpr) -> write_as_variable(backend, model, argexpr), arguments)
306-
options = write_node_options(backend, fform, [ varexpr, arguments... ], kwarguments)
307-
308-
if short_id varids
309-
return write_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr)
310-
else
311-
push!(varids, short_id)
312-
return write_autovar_make_node_expression(backend, model, fform, variables, options, nodeexpr, varexpr, full_id)
313-
end
314-
else
315-
return expression
316-
end
317-
end
318-
319-
# Step 3: Final pass
320-
final_pass_exceptions = (x) -> @capture(x, (some_ -> body_) | (function some_(args__) body_ end) | (some_(args__) = body_))
321-
final_pass_target = (x) -> @capture(x, return ret_)
322-
323-
ms_body = conditioned_walk(final_pass_exceptions, final_pass_target, ms_body) do expression
324-
@capture(expression, return ret_) ? quote activate!($model); return $model, ($ret) end : expression
325-
end
326-
327-
res = quote
328-
329-
function $ms_name($(ms_args...); $(ms_kwargs...), options = $(ms_options))
330-
$(ms_args_checks...)
331-
options = merge($(ms_options), options)
332-
$model = Model(options)
333-
$(ms_args_const_init_block...)
334-
$ms_body
335-
error("'return' statement is missing")
336-
end
337-
end
338-
339-
return esc(res)
340-
end
9+
include("utils.jl")
10+
include("model.jl")
11+
include("constraints.jl")
12+
include("meta.jl")
34113

34214
end # module

0 commit comments

Comments
 (0)