Skip to content

Commit 5fbe319

Browse files
committed
Make inassumption macro a bit nicer
1 parent 989b648 commit 5fbe319

File tree

3 files changed

+32
-31
lines changed

3 files changed

+32
-31
lines changed

src/compiler.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,44 +27,44 @@ function wrong_dist_errormsg(l)
2727
end
2828

2929
"""
30-
@isassumption(data_vars, missing_vars, ex)
30+
@isassumption(model, expr)
3131
32-
Let `ex` be `x[1]`. This macro returns `true` in any of the following cases:
32+
Let `expr` be `x[1]`. `vn` is an assumption in the following cases:
3333
1. `x` was not among the input data to the model,
3434
2. `x` was among the input data to the model but with a value `missing`, or
3535
3. `x` was among the input data to the model with a value other than missing,
3636
but `x[1] === missing`.
37-
When `ex` is not a variable (e.g., a literal), the function returns `false` as well.
37+
When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
3838
"""
39-
macro isassumption(data_vars, missing_vars, ex)
40-
:false
41-
end
42-
macro isassumption(model, ex::Union{Symbol, Expr})
43-
sym = gensym(:sym)
44-
lhs = gensym(:lhs)
45-
return esc(quote
46-
# Extract symbol
47-
$sym = Val($(vsym(ex)))
39+
macro isassumption(model, expr::Union{Symbol, Expr})
40+
# Note: never put a return in this... don't forget it's a macro!
41+
vn = gensym(:vn)
42+
43+
return quote
44+
$vn = @varname($expr)
45+
4846
# This branch should compile nicely in all cases except for partial missing data
49-
# For example, when `ex` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
50-
if !DynamicPPL.inargnames($sym, $model) || DynamicPPL.inmissings($sym, $model)
47+
# For example, when `expr` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
48+
if !DynamicPPL.inargnames($vn, $model) || DynamicPPL.inmissings($vn, $model)
5149
true
5250
else
53-
if DynamicPPL.inargnames($sym, $model)
51+
if DynamicPPL.inargnames($vn, $model)
5452
# Evaluate the lhs
55-
$lhs = $ex
56-
if $lhs === missing
57-
true
58-
else
59-
false
60-
end
53+
$expr === missing
6154
else
6255
throw("This point should not be reached. Please report this error.")
6356
end
6457
end
65-
end)
58+
end |> esc
6659
end
6760

61+
macro isassumption(model, expr)
62+
# failsafe: a literal is never an assumption
63+
false
64+
end
65+
66+
67+
6868
#################
6969
# Main Compiler #
7070
#################
@@ -301,7 +301,7 @@ function generate_tilde(left, right, model_info)
301301
inds = gensym(:inds)
302302
isassumption = gensym(:isassumption)
303303
assert_ex = :(DynamicPPL.assert_dist($temp_right, msg = $(wrong_dist_errormsg(@__LINE__))))
304-
304+
305305
if left isa Symbol || left isa Expr
306306
ex = quote
307307
$temp_right = $right

src/model.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,6 @@ Get a tuple of the argument names of the `model`.
147147
"""
148148
getargnames(model::Model{_F, argnames}) where {argnames, _F} = argnames
149149

150-
@generated function inargnames(::Val{s}, ::Model{_F, argnames}) where {s, argnames, _F}
151-
return s in argnames
152-
end
153-
154150

155151
"""
156152
getmissings(model::Model)
@@ -162,10 +158,6 @@ getmissings(model::Model{_F, _a, _T, missings}) where {missings, _F, _a, _T} = m
162158
getmissing(model::Model) = getmissings(model)
163159
@deprecate getmissing(model) getmissings(model)
164160

165-
@generated function inmissings(::Val{s}, ::Model{_F, _a, _T, missings}) where {s, missings, _F, _a, _T}
166-
return s in missings
167-
end
168-
169161

170162
"""
171163
getgenerator(model::Model)

src/varname.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,12 @@ function split_var_str(var_str, inds_as = Vector)
134134
end
135135
return sym, inds
136136
end
137+
138+
139+
@generated function inargnames(::VarName{s}, ::Model{_F, argnames}) where {s, argnames, _F}
140+
return s in argnames
141+
end
142+
143+
@generated function inmissings(::VarName{s}, ::Model{_F, _a, _T, missings}) where {s, missings, _F, _a, _T}
144+
return s in missings
145+
end

0 commit comments

Comments
 (0)