@@ -304,12 +304,63 @@ function dot_tilde_assume(context::PrefixContext, right, left, vn, vi)
304
304
return dot_tilde_assume (context. context, right, left, prefix .(Ref (context), vn), vi)
305
305
end
306
306
307
- function dot_tilde_assume (rng, context:: PrefixContext , sampler, right, left, vn, vi)
307
+ function dot_tilde_assume (rng:: Random.AbstractRNG , context:: PrefixContext , sampler, right, left, vn, vi)
308
308
return dot_tilde_assume (
309
309
rng, context. context, sampler, right, left, prefix .(Ref (context), vn), vi
310
310
)
311
311
end
312
312
313
+ # `FixedContext`
314
+ function dot_tilde_assume (context:: FixedContext , right, left, vns, vi)
315
+ # If we're reached here, then we didn't hit the initial `getfixed` call in the model body.
316
+ # So we need to check each of the vns.
317
+ logp = 0
318
+ # TODO (torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`.
319
+ # If the `Symbol` is not present, we can just skip this check completely. Such a check can
320
+ # then be compiled away in cases where the `Symbol` is not present.
321
+ left_bc = Broadcast. broadcastable (left)
322
+ right_bc = Broadcast. broadcastable (right)
323
+ for I_left in Iterators. product (Broadcast. broadcast_axes (left_bc)... )
324
+ for I_right in Iterators. product (Broadcast. broadcast_axes (right_bc)... )
325
+ vn = vns[I_left... ]
326
+ if hasfixed (context, vn)
327
+ left[I_left... ] = getfixed (context, vn)
328
+ else
329
+ # Defer to `tilde_assume`.
330
+ left[I_left... ], logp_inner, vi = tilde_assume (context, right_bc[I_right... ], vn, vi)
331
+ logp += logp_inner
332
+ end
333
+ end
334
+ end
335
+
336
+ return left, logp, vi
337
+ end
338
+
339
+ function dot_tilde_assume (rng:: Random.AbstractRNG , context:: FixedContext , sampler, right, left, vns, vi)
340
+ # If we're reached here, then we didn't hit the initial `getfixed` call in the model body.
341
+ # So we need to check each of the vns.
342
+ logp = 0
343
+ # TODO (torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`.
344
+ # If the `Symbol` is not present, we can just skip this check completely. Such a check can
345
+ # then be compiled away in cases where the `Symbol` is not present.
346
+ left_bc = Broadcast. broadcastable (left)
347
+ right_bc = Broadcast. broadcastable (right)
348
+ for I_left in Iterators. product (Broadcast. broadcast_axes (left_bc)... )
349
+ for I_right in Iterators. product (Broadcast. broadcast_axes (right_bc)... )
350
+ vn = vns[I_left... ]
351
+ if hasfixed (context, vn)
352
+ left[I_left... ] = getfixed (context, vn)
353
+ else
354
+ # Defer to `tilde_assume`.
355
+ left[I_left... ], logp_inner, vi = tilde_assume (rng, context, sampler, right_bc[I_right... ], vn, vi)
356
+ logp += logp_inner
357
+ end
358
+ end
359
+ end
360
+
361
+ return left, logp, vi
362
+ end
363
+
313
364
"""
314
365
dot_tilde_assume!!(context, right, left, vn, vi)
315
366
0 commit comments