Skip to content

Commit 55b09f1

Browse files
authored
Merge pull request #57 from TuringLang/bugfix
Fix error `Model not defined`
2 parents e0f68b5 + 1c5087a commit 55b09f1

File tree

4 files changed

+154
-159
lines changed

4 files changed

+154
-159
lines changed

src/compiler.jl

Lines changed: 141 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -11,59 +11,42 @@ function _error_msg()
1111
return "This macro is only for use in the `@model` macro and not for external use."
1212
end
1313

14-
15-
16-
# Check if the right-hand side is a distribution.
17-
function assert_dist(dist; msg)
18-
isa(dist, Distribution) || throw(ArgumentError(msg))
19-
end
20-
function assert_dist(dist::AbstractVector; msg)
21-
all(d -> isa(d, Distribution), dist) || throw(ArgumentError(msg))
22-
end
23-
24-
function wrong_dist_errormsg(l)
25-
return "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
26-
"Distributions on line $(l)."
27-
end
14+
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
15+
"Distributions."
2816

2917
"""
30-
@isassumption(model, expr)
18+
isassumption(model, expr)
19+
20+
Return an expression that can be evaluated to check if `expr` is an assumption in the
21+
`model`.
3122
32-
Let `expr` be `x[1]`. `vn` is an assumption in the following cases:
33-
1. `x` was not among the input data to the model,
34-
2. `x` was among the input data to the model but with a value `missing`, or
35-
3. `x` was among the input data to the model with a value other than missing,
23+
Let `expr` be `:(x[1])`. It is an assumption in the following cases:
24+
1. `x` is not among the input data to the `model`,
25+
2. `x` is among the input data to the `model` but with a value `missing`, or
26+
3. `x` is among the input data to the `model` with a value other than missing,
3627
but `x[1] === missing`.
28+
3729
When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
3830
"""
39-
macro isassumption(model, expr::Union{Symbol, Expr})
40-
# Note: never put a return in this... don't forget it's a macro!
31+
function isassumption(model, expr::Union{Symbol, Expr})
4132
vn = gensym(:vn)
42-
33+
4334
return quote
44-
$vn = @varname($expr)
45-
46-
# This branch should compile nicely in all cases except for partial missing data
47-
# For example, when `expr` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
48-
if !DynamicPPL.inargnames($vn, $model) || DynamicPPL.inmissings($vn, $model)
49-
true
50-
else
51-
if DynamicPPL.inargnames($vn, $model)
52-
# Evaluate the lhs
53-
$expr === missing
35+
let $vn = $(varname(expr))
36+
# This branch should compile nicely in all cases except for partial missing data
37+
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
38+
if !$(DynamicPPL.inargnames)($vn, $model) || $(DynamicPPL.inmissings)($vn, $model)
39+
true
5440
else
55-
throw("This point should not be reached. Please report this error.")
41+
# Evaluate the LHS
42+
$expr === missing
5643
end
5744
end
58-
end |> esc
59-
end
60-
61-
macro isassumption(model, expr)
62-
# failsafe: a literal is never an assumption
63-
false
45+
end
6446
end
6547

66-
48+
# failsafe: a literal is never an assumption
49+
isassumption(model, expr) = :(false)
6750

6851
#################
6952
# Main Compiler #
@@ -128,7 +111,7 @@ function build_model_info(input_expr)
128111
Expr(:tuple, QuoteNode.(arg_syms)...),
129112
Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in arg_syms]...)
130113
)
131-
args_nt = Expr(:call, :(DynamicPPL.namedtuple), nt_type, Expr(:tuple, arg_syms...))
114+
args_nt = Expr(:call, :($namedtuple), nt_type, Expr(:tuple, arg_syms...))
132115
end
133116
args = map(modeldef[:args]) do arg
134117
if (arg isa Symbol)
@@ -217,7 +200,7 @@ function replace_logpdf!(model_info)
217200
vi = model_info[:main_body_names][:vi]
218201
ex = MacroTools.postwalk(ex) do x
219202
if @capture(x, @logpdf())
220-
:($vi.logp[])
203+
:($(vi).logp[])
221204
else
222205
x
223206
end
@@ -261,14 +244,14 @@ function replace_tilde!(model_info)
261244
dotargs = getargs_dottilde(x)
262245
if dotargs !== nothing
263246
L, R = dotargs
264-
return generate_dot_tilde(L, R, model_info)
247+
return Base.remove_linenums!(generate_dot_tilde(L, R, model_info))
265248
end
266249

267250
# Check tilde.
268251
args = getargs_tilde(x)
269252
if args !== nothing
270253
L, R = args
271-
return generate_tilde(L, R, model_info)
254+
return Base.remove_linenums!(generate_tilde(L, R, model_info))
272255
end
273256

274257
return x
@@ -294,45 +277,55 @@ function generate_tilde(left, right, model_info)
294277
vi = model_info[:main_body_names][:vi]
295278
ctx = model_info[:main_body_names][:ctx]
296279
sampler = model_info[:main_body_names][:sampler]
297-
temp_right = gensym(:temp_right)
298-
out = gensym(:out)
299-
lp = gensym(:lp)
300-
vn = gensym(:vn)
301-
inds = gensym(:inds)
302-
isassumption = gensym(:isassumption)
303-
assert_ex = :(DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
304-
280+
281+
@gensym tmpright
282+
top = [:($tmpright = $right),
283+
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
284+
|| throw(ArgumentError($DISTMSG)))]
285+
305286
if left isa Symbol || left isa Expr
306-
ex = quote
307-
$temp_right = $right
308-
$assert_ex
309-
310-
$vn, $inds = $(varname(left)), $(vinds(left))
311-
$isassumption = DynamicPPL.@isassumption($model, $left)
312-
if $isassumption
313-
$out = DynamicPPL.tilde_assume($ctx, $sampler, $temp_right, $vn, $inds, $vi)
314-
$left = $out[1]
315-
DynamicPPL.acclogp!($vi, $out[2])
316-
else
317-
DynamicPPL.acclogp!(
318-
$vi,
319-
DynamicPPL.tilde_observe($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi),
320-
)
287+
@gensym out vn inds
288+
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
289+
290+
assumption = [
291+
:($out = $(DynamicPPL.tilde_assume)($ctx, $sampler, $tmpright, $vn, $inds,
292+
$vi)),
293+
:($left = $out[1]),
294+
:($(DynamicPPL.acclogp!)($vi, $out[2]))
295+
]
296+
297+
# It can only be an observation if the LHS is an argument of the model
298+
if vsym(left) in model_info[:args]
299+
@gensym isassumption
300+
return quote
301+
$(top...)
302+
$isassumption = $(DynamicPPL.isassumption(model, left))
303+
if $isassumption
304+
$(assumption...)
305+
else
306+
$(DynamicPPL.acclogp!)(
307+
$vi,
308+
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $left, $vn,
309+
$inds, $vi)
310+
)
311+
end
321312
end
322313
end
323-
else
324-
# we have a literal, which is automatically an observation
325-
ex = quote
326-
$temp_right = $right
327-
$assert_ex
328-
329-
DynamicPPL.acclogp!(
330-
$vi,
331-
DynamicPPL.tilde_observe($ctx, $sampler, $temp_right, $left, $vi),
332-
)
314+
315+
return quote
316+
$(top...)
317+
$(assumption...)
333318
end
334319
end
335-
return ex
320+
321+
# If the LHS is a literal, it is always an observation
322+
return quote
323+
$(top...)
324+
$(DynamicPPL.acclogp!)(
325+
$vi,
326+
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $left, $vi)
327+
)
328+
end
336329
end
337330

338331
"""
@@ -347,46 +340,55 @@ function generate_dot_tilde(left, right, model_info)
347340
vi = model_info[:main_body_names][:vi]
348341
ctx = model_info[:main_body_names][:ctx]
349342
sampler = model_info[:main_body_names][:sampler]
350-
out = gensym(:out)
351-
temp_right = gensym(:temp_right)
352-
isassumption = gensym(:isassumption)
353-
lp = gensym(:lp)
354-
vn = gensym(:vn)
355-
inds = gensym(:inds)
356-
assert_ex = :(DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
357-
343+
344+
@gensym tmpright
345+
top = [:($tmpright = $right),
346+
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
347+
|| throw(ArgumentError($DISTMSG)))]
348+
358349
if left isa Symbol || left isa Expr
359-
ex = quote
360-
$temp_right = $right
361-
$assert_ex
362-
363-
$vn, $inds = $(varname(left)), $(vinds(left))
364-
$isassumption = DynamicPPL.@isassumption($model, $left)
365-
366-
if $isassumption
367-
$out = DynamicPPL.dot_tilde_assume($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi)
368-
$left .= $out[1]
369-
DynamicPPL.acclogp!($vi, $out[2])
370-
else
371-
DynamicPPL.acclogp!(
372-
$vi,
373-
DynamicPPL.dot_tilde_observe($ctx, $sampler, $temp_right, $left, $vn, $inds, $vi),
374-
)
350+
@gensym out vn inds
351+
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
352+
353+
assumption = [
354+
:($out = $(DynamicPPL.dot_tilde_assume)($ctx, $sampler, $tmpright, $left,
355+
$vn, $inds, $vi)),
356+
:($left .= $out[1]),
357+
:($(DynamicPPL.acclogp!)($vi, $out[2]))
358+
]
359+
360+
# It can only be an observation if the LHS is an argument of the model
361+
if vsym(left) in model_info[:args]
362+
@gensym isassumption
363+
return quote
364+
$(top...)
365+
$isassumption = $(DynamicPPL.isassumption(model, left))
366+
if $isassumption
367+
$(assumption...)
368+
else
369+
$(DynamicPPL.acclogp!)(
370+
$vi,
371+
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $left,
372+
$vn, $inds, $vi)
373+
)
374+
end
375375
end
376376
end
377-
else
378-
# we have a literal, which is automatically an observation
379-
ex = quote
380-
$temp_right = $right
381-
$assert_ex
382-
383-
DynamicPPL.acclogp!(
384-
$vi,
385-
DynamicPPL.dot_tilde_observe($ctx, $sampler, $temp_right, $left, $vi),
386-
)
377+
378+
return quote
379+
$(top...)
380+
$(assumption...)
387381
end
388382
end
389-
return ex
383+
384+
# If the LHS is a literal, it is always an observation
385+
return quote
386+
$(top...)
387+
$(DynamicPPL.acclogp!)(
388+
$vi,
389+
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $left, $vi)
390+
)
391+
end
390392
end
391393

392394
const FloatOrArrayType = Type{<:Union{AbstractFloat, AbstractArray}}
@@ -425,42 +427,29 @@ function build_output(model_info)
425427

426428
unwrap_data_expr = Expr(:block)
427429
for var in arg_syms
428-
temp_var = gensym(:temp_var)
429-
varT = gensym(:varT)
430-
push!(unwrap_data_expr.args, quote
431-
local $var
432-
$temp_var = $model.args.$var
433-
$varT = typeof($temp_var)
434-
if $temp_var isa DynamicPPL.FloatOrArrayType
435-
$var = DynamicPPL.get_matching_type($sampler, $vi, $temp_var)
436-
elseif DynamicPPL.hasmissing($varT)
437-
$var = DynamicPPL.get_matching_type($sampler, $vi, $varT)($temp_var)
438-
else
439-
$var = $temp_var
440-
end
441-
end)
430+
push!(unwrap_data_expr.args,
431+
:($var = $(DynamicPPL.matchingvalue)($sampler, $vi, $(model).args.$var)))
442432
end
443433

444434
@gensym(evaluator, generator)
445435
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
446-
model_gen_constructor = :(DynamicPPL.ModelGen{$(Tuple(arg_syms))}($generator, $defaults_nt))
447-
436+
model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt))
437+
448438
ex = quote
449439
function $evaluator(
450-
$model::Model,
451-
$vi::DynamicPPL.VarInfo,
452-
$sampler::DynamicPPL.AbstractSampler,
453-
$ctx::DynamicPPL.AbstractContext,
440+
$model::$(DynamicPPL.Model),
441+
$vi::$(DynamicPPL.VarInfo),
442+
$sampler::$(DynamicPPL.AbstractSampler),
443+
$ctx::$(DynamicPPL.AbstractContext),
454444
)
455445
$unwrap_data_expr
456-
DynamicPPL.resetlogp!($vi)
446+
$(DynamicPPL.resetlogp!)($vi)
457447
$main_body
458448
end
459-
460449

461-
$generator($(args...)) = DynamicPPL.Model($evaluator, $args_nt, $model_gen_constructor)
450+
$generator($(args...)) = $(DynamicPPL.Model)($evaluator, $args_nt, $model_gen_constructor)
462451
$(generator_kw_form...)
463-
452+
464453
$model_gen = $model_gen_constructor
465454
end
466455

@@ -475,6 +464,21 @@ function warn_empty(body)
475464
return
476465
end
477466

467+
"""
468+
matchingvalue(sampler, vi, value)
469+
470+
Convert the `value` to the correct type for the `sampler` and the `vi` object.
471+
"""
472+
function matchingvalue(sampler, vi, value)
473+
T = typeof(value)
474+
if hasmissing(T)
475+
return get_matching_type(sampler, vi, T)(value)
476+
else
477+
return value
478+
end
479+
end
480+
matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value)
481+
478482
"""
479483
get_matching_type(spl, vi, ::Type{T}) where {T}
480484
Get the specialized version of type `T` for sampler `spl`. For example,

0 commit comments

Comments
 (0)