Skip to content

Commit 0b7ba4b

Browse files
committed
added static checking to avoid the slow fixed branches unless we
really need to
1 parent 86fe1c6 commit 0b7ba4b

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

src/context_implementations.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,13 @@ end
312312

313313
# `FixedContext`
314314
function dot_tilde_assume(context::FixedContext, right, left, vns, vi)
315+
if !has_fixed_symbol(context, first(vns))
316+
# Defer to `childcontext`.
317+
return tilde_assume(childcontext(context), right, left, vns, vi)
318+
end
319+
315320
# If we're reached here, then we didn't hit the initial `getfixed` call in the model body.
316-
# So we need to check each of the vns.
321+
# We _might_ also have some of the variables fixed, but not all.
317322
logp = 0
318323
# TODO(torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`.
319324
# If the `Symbol` is not present, we can just skip this check completely. Such a check can
@@ -327,7 +332,9 @@ function dot_tilde_assume(context::FixedContext, right, left, vns, vi)
327332
left[I_left...] = getfixed(context, vn)
328333
else
329334
# Defer to `tilde_assume`.
330-
left[I_left...], logp_inner, vi = tilde_assume(context, right_bc[I_right...], vn, vi)
335+
left[I_left...], logp_inner, vi = tilde_assume(
336+
childcontext(context), right_bc[I_right...], vn, vi
337+
)
331338
logp += logp_inner
332339
end
333340
end
@@ -336,7 +343,14 @@ function dot_tilde_assume(context::FixedContext, right, left, vns, vi)
336343
return left, logp, vi
337344
end
338345

339-
function dot_tilde_assume(rng::Random.AbstractRNG, context::FixedContext, sampler, right, left, vns, vi)
346+
function dot_tilde_assume(
347+
rng::Random.AbstractRNG, context::FixedContext, sampler, right, left, vns, vi
348+
)
349+
350+
if !has_fixed_symbol(context, first(vns))
351+
# Defer to `childcontext`.
352+
return tilde_assume(rng, childcontext(context), sampler, right, left, vns, vi)
353+
end
340354
# If we're reached here, then we didn't hit the initial `getfixed` call in the model body.
341355
# So we need to check each of the vns.
342356
logp = 0
@@ -352,7 +366,9 @@ function dot_tilde_assume(rng::Random.AbstractRNG, context::FixedContext, sample
352366
left[I_left...] = getfixed(context, vn)
353367
else
354368
# Defer to `tilde_assume`.
355-
left[I_left...], logp_inner, vi = tilde_assume(rng, context, sampler, right_bc[I_right...], vn, vi)
369+
left[I_left...], logp_inner, vi = tilde_assume(
370+
rng, childcontext(context), sampler, right_bc[I_right...], vn, vi
371+
)
356372
logp += logp_inner
357373
end
358374
end

src/contexts.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,13 @@ NodeTrait(::FixedContext) = IsParent()
501501
childcontext(context::FixedContext) = context.context
502502
setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child)
503503

504+
has_fixed_symbol(context::FixedContext, vn::VarName) = has_symbol(context.values, vn)
505+
506+
has_symbol(d::AbstractDict, vn::VarName) = haskey(d, vn)
507+
@generated function has_symbol(::NamedTuple{names}, ::VarName{sym}) where {names, sym}
508+
return sym in names
509+
end
510+
504511
"""
505512
hasfixed(context::AbstractContext, vn::VarName)
506513

0 commit comments

Comments
 (0)