Skip to content

Commit 27a2df0

Browse files
committed
Get rid of one more Val
1 parent c9394fd commit 27a2df0

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

src/compiler.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,18 @@ Otherwise, the value of `x[1]` is returned.
3939
macro preprocess(data_vars, missing_vars, ex)
4040
ex
4141
end
42-
macro preprocess(data_vars, missing_vars, ex::Union{Symbol, Expr})
42+
macro preprocess(model, ex::Union{Symbol, Expr})
4343
sym = gensym(:sym)
4444
lhs = gensym(:lhs)
4545
return esc(quote
4646
# Extract symbol
4747
$sym = Val($(vsym(ex)))
4848
# This branch should compile nicely in all cases except for partial missing data
4949
# For example, when `ex` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
50-
if !DynamicPPL.inparams($sym, $data_vars) || DynamicPPL.inparams($sym, $missing_vars)
50+
if !DynamicPPL.inargnames($sym, $model) || DynamicPPL.inmissings($sym, $model)
5151
$(varname(ex)), $(vinds(ex))
5252
else
53-
if DynamicPPL.inparams($sym, $data_vars)
53+
if DynamicPPL.inargnames($sym, $model)
5454
# Evaluate the lhs
5555
$lhs = $ex
5656
if $lhs === missing
@@ -64,9 +64,7 @@ macro preprocess(data_vars, missing_vars, ex::Union{Symbol, Expr})
6464
end
6565
end)
6666
end
67-
@generated function inparams(::Val{s}, ::Val{t}) where {s, t}
68-
return (s in t) ? :(true) : :(false)
69-
end
67+
7068

7169
#################
7270
# Main Compiler #
@@ -305,7 +303,6 @@ The `tilde` function generates `observe` expression for data variables and `assu
305303
expressions for parameter variables, updating `model_info` in the process.
306304
"""
307305
function generate_tilde(left, right, model_info)
308-
arg_syms = Val((model_info[:arg_syms]...,))
309306
model = model_info[:main_body_names][:model]
310307
vi = model_info[:main_body_names][:vi]
311308
ctx = model_info[:main_body_names][:ctx]
@@ -321,7 +318,7 @@ function generate_tilde(left, right, model_info)
321318
ex = quote
322319
$temp_right = $right
323320
$assert_ex
324-
$preprocessed = DynamicPPL.@preprocess($arg_syms, Val{DynamicPPL.getmissings($model)}(), $left)
321+
$preprocessed = DynamicPPL.@preprocess($model, $left)
325322
if $preprocessed isa Tuple
326323
$vn, $inds = $preprocessed
327324
$out = DynamicPPL.tilde($ctx, $sampler, $temp_right, $vn, $inds, $vi)
@@ -353,7 +350,6 @@ end
353350
This function returns the expression that replaces `left .~ right` in the model body. If `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block will be run.
354351
"""
355352
function generate_dot_tilde(left, right, model_info)
356-
arg_syms = Val((model_info[:arg_syms]...,))
357353
model = model_info[:main_body_names][:model]
358354
vi = model_info[:main_body_names][:vi]
359355
ctx = model_info[:main_body_names][:ctx]
@@ -370,7 +366,7 @@ function generate_dot_tilde(left, right, model_info)
370366
ex = quote
371367
$temp_right = $right
372368
$assert_ex
373-
$preprocessed = DynamicPPL.@preprocess($arg_syms, Val{DynamicPPL.getmissings($model)}(), $left)
369+
$preprocessed = DynamicPPL.@preprocess($model, $left)
374370
if $preprocessed isa Tuple
375371
$vn, $inds = $preprocessed
376372
$temp_left = $left

src/model.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ Get a tuple of the argument names of the `model`.
4444
getargnames(model::Model) = getargnames(typeof(model))
4545
getargnames(::Type{<:Model{_G, argnames} where {_G}}) where {argnames} = argnames
4646

47+
@generated function inargnames(::Val{s}, ::Model{_G, argnames}) where {s, _G, argnames}
48+
return s in argnames
49+
end
50+
4751

4852
"""
4953
getmissings(model::Model)
@@ -55,6 +59,10 @@ getmissings(model::Model{_G, _a, missings}) where {missings, _G, _a} = missings
5559
getmissing(model::Model) = getmissings(model)
5660
@deprecate getmissing(model) getmissings(model)
5761

62+
@generated function inmissings(::Val{s}, ::Model{_G, _a, missings}) where {s, missings, _G, _a}
63+
return s in missings
64+
end
65+
5866

5967
"""
6068
getgenerator(model::Model)

0 commit comments

Comments
 (0)