Skip to content

Commit 8e7d164

Browse files
committed
added dot_tilde_assume overloads for FixedContext to handle the
cases where current `fix` is failiing
1 parent 1d211c5 commit 8e7d164

File tree

1 file changed

+52
-1
lines changed

1 file changed

+52
-1
lines changed

src/context_implementations.jl

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,12 +304,63 @@ function dot_tilde_assume(context::PrefixContext, right, left, vn, vi)
304304
return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi)
305305
end
306306

307-
function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi)
307+
function dot_tilde_assume(rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi)
308308
return dot_tilde_assume(
309309
rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi
310310
)
311311
end
312312

313+
# `FixedContext`
314+
function dot_tilde_assume(context::FixedContext, right, left, vns, vi)
315+
# If we're reached here, then we didn't hit the initial `getfixed` call in the model body.
316+
# So we need to check each of the vns.
317+
logp = 0
318+
# TODO(torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`.
319+
# If the `Symbol` is not present, we can just skip this check completely. Such a check can
320+
# then be compiled away in cases where the `Symbol` is not present.
321+
left_bc = Broadcast.broadcastable(left)
322+
right_bc = Broadcast.broadcastable(right)
323+
for I_left in Iterators.product(Broadcast.broadcast_axes(left_bc)...)
324+
for I_right in Iterators.product(Broadcast.broadcast_axes(right_bc)...)
325+
vn = vns[I_left...]
326+
if hasfixed(context, vn)
327+
left[I_left...] = getfixed(context, vn)
328+
else
329+
# Defer to `tilde_assume`.
330+
left[I_left...], logp_inner, vi = tilde_assume(context, right_bc[I_right...], vn, vi)
331+
logp += logp_inner
332+
end
333+
end
334+
end
335+
336+
return left, logp, vi
337+
end
338+
339+
function dot_tilde_assume(rng::Random.AbstractRNG, context::FixedContext, sampler, right, left, vns, vi)
340+
# If we're reached here, then we didn't hit the initial `getfixed` call in the model body.
341+
# So we need to check each of the vns.
342+
logp = 0
343+
# TODO(torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`.
344+
# If the `Symbol` is not present, we can just skip this check completely. Such a check can
345+
# then be compiled away in cases where the `Symbol` is not present.
346+
left_bc = Broadcast.broadcastable(left)
347+
right_bc = Broadcast.broadcastable(right)
348+
for I_left in Iterators.product(Broadcast.broadcast_axes(left_bc)...)
349+
for I_right in Iterators.product(Broadcast.broadcast_axes(right_bc)...)
350+
vn = vns[I_left...]
351+
if hasfixed(context, vn)
352+
left[I_left...] = getfixed(context, vn)
353+
else
354+
# Defer to `tilde_assume`.
355+
left[I_left...], logp_inner, vi = tilde_assume(rng, context, sampler, right_bc[I_right...], vn, vi)
356+
logp += logp_inner
357+
end
358+
end
359+
end
360+
361+
return left, logp, vi
362+
end
363+
313364
"""
314365
dot_tilde_assume!!(context, right, left, vn, vi)
315366

0 commit comments

Comments
 (0)