|
1 | 1 | const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) |
2 | 2 |
|
| 3 | +struct SampleableModelWrapper{M} |
| 4 | + model::M |
| 5 | +end |
| 6 | + |
| 7 | +to_sampleable(model::DynamicPPL.Model) = SampleableModelWrapper(model) |
| 8 | + |
3 | 9 | """ |
4 | 10 | need_concretize(expr) |
5 | 11 |
|
@@ -427,28 +433,34 @@ function generate_tilde(left, right) |
427 | 433 | # more selective with our escape. Until that's the case, we remove them all. |
428 | 434 | return quote |
429 | 435 | $dist = $right |
430 | | - $vn = $(DynamicPPL.resolve_varnames)( |
431 | | - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist |
432 | | - ) |
433 | | - $isassumption = $(DynamicPPL.isassumption(left, vn)) |
434 | | - if $(DynamicPPL.isfixed(left, vn)) |
435 | | - $left = $(DynamicPPL.getfixed_nested)(__context__, $vn) |
436 | | - elseif $isassumption |
437 | | - $(generate_tilde_assume(left, dist, vn)) |
438 | | - else |
439 | | - # If `vn` is not in `argnames`, we need to make sure that the variable is defined. |
440 | | - if !$(DynamicPPL.inargnames)($vn, __model__) |
441 | | - $left = $(DynamicPPL.getconditioned_nested)(__context__, $vn) |
442 | | - end |
443 | 436 |
|
444 | | - $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( |
445 | | - __context__, |
446 | | - $(DynamicPPL.check_tilde_rhs)($dist), |
447 | | - $(maybe_view(left)), |
448 | | - $vn, |
449 | | - __varinfo__, |
| 437 | + if $dist isa $(SampleableModelWrapper) |
| 438 | + $left, __varinfo__ = $(_evaluate!!)($dist.model, __varinfo__, __context__) |
| 439 | + $left |
| 440 | + else |
| 441 | + $vn = $(DynamicPPL.resolve_varnames)( |
| 442 | + $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist |
450 | 443 | ) |
451 | | - $value |
| 444 | + $isassumption = $(DynamicPPL.isassumption(left, vn)) |
| 445 | + if $(DynamicPPL.isfixed(left, vn)) |
| 446 | + $left = $(DynamicPPL.getfixed_nested)(__context__, $vn) |
| 447 | + elseif $isassumption |
| 448 | + $(generate_tilde_assume(left, dist, vn)) |
| 449 | + else |
| 450 | + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. |
| 451 | + if !$(DynamicPPL.inargnames)($vn, __model__) |
| 452 | + $left = $(DynamicPPL.getconditioned_nested)(__context__, $vn) |
| 453 | + end |
| 454 | + |
| 455 | + $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( |
| 456 | + __context__, |
| 457 | + $(DynamicPPL.check_tilde_rhs)($dist), |
| 458 | + $(maybe_view(left)), |
| 459 | + $vn, |
| 460 | + __varinfo__, |
| 461 | + ) |
| 462 | + $value |
| 463 | + end |
452 | 464 | end |
453 | 465 | end |
454 | 466 | end |
|
0 commit comments