312
312
313
313
# `FixedContext`
314
314
function dot_tilde_assume (context:: FixedContext , right, left, vns, vi)
315
+ if ! has_fixed_symbol (context, first (vns))
316
+ # Defer to `childcontext`.
317
+ return tilde_assume (childcontext (context), right, left, vns, vi)
318
+ end
319
+
315
320
# If we're reached here, then we didn't hit the initial `getfixed` call in the model body.
316
- # So we need to check each of the vns .
321
+ # We _might_ also have some of the variables fixed, but not all .
317
322
logp = 0
318
323
# TODO (torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`.
319
324
# If the `Symbol` is not present, we can just skip this check completely. Such a check can
@@ -327,7 +332,9 @@ function dot_tilde_assume(context::FixedContext, right, left, vns, vi)
327
332
left[I_left... ] = getfixed (context, vn)
328
333
else
329
334
# Defer to `tilde_assume`.
330
- left[I_left... ], logp_inner, vi = tilde_assume (context, right_bc[I_right... ], vn, vi)
335
+ left[I_left... ], logp_inner, vi = tilde_assume (
336
+ childcontext (context), right_bc[I_right... ], vn, vi
337
+ )
331
338
logp += logp_inner
332
339
end
333
340
end
@@ -336,7 +343,14 @@ function dot_tilde_assume(context::FixedContext, right, left, vns, vi)
336
343
return left, logp, vi
337
344
end
338
345
339
- function dot_tilde_assume (rng:: Random.AbstractRNG , context:: FixedContext , sampler, right, left, vns, vi)
346
+ function dot_tilde_assume (
347
+ rng:: Random.AbstractRNG , context:: FixedContext , sampler, right, left, vns, vi
348
+ )
349
+
350
+ if ! has_fixed_symbol (context, first (vns))
351
+ # Defer to `childcontext`.
352
+ return tilde_assume (rng, childcontext (context), sampler, right, left, vns, vi)
353
+ end
340
354
# If we're reached here, then we didn't hit the initial `getfixed` call in the model body.
341
355
# So we need to check each of the vns.
342
356
logp = 0
@@ -352,7 +366,9 @@ function dot_tilde_assume(rng::Random.AbstractRNG, context::FixedContext, sample
352
366
left[I_left... ] = getfixed (context, vn)
353
367
else
354
368
# Defer to `tilde_assume`.
355
- left[I_left... ], logp_inner, vi = tilde_assume (rng, context, sampler, right_bc[I_right... ], vn, vi)
369
+ left[I_left... ], logp_inner, vi = tilde_assume (
370
+ rng, childcontext (context), sampler, right_bc[I_right... ], vn, vi
371
+ )
356
372
logp += logp_inner
357
373
end
358
374
end
0 commit comments