Skip to content

Commit bf35de4

Browse files
committed
added to_sampleable and limited ~ handling for submodels
1 parent 946fa6d commit bf35de4

File tree

2 files changed

+33
-20
lines changed

2 files changed

+33
-20
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ export AbstractVarInfo,
129129
value_iterator_from_chain,
130130
check_model,
131131
check_model_and_trace,
132+
to_sampleable,
132133
# Deprecated.
133134
@logprob_str,
134135
@prob_str

src/compiler.jl

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)
22

3+
struct SampleableModelWrapper{M}
4+
model::M
5+
end
6+
7+
to_sampleable(model::DynamicPPL.Model) = SampleableModelWrapper(model)
8+
39
"""
410
need_concretize(expr)
511
@@ -427,28 +433,34 @@ function generate_tilde(left, right)
427433
# more selective with our escape. Until that's the case, we remove them all.
428434
return quote
429435
$dist = $right
430-
$vn = $(DynamicPPL.resolve_varnames)(
431-
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist
432-
)
433-
$isassumption = $(DynamicPPL.isassumption(left, vn))
434-
if $(DynamicPPL.isfixed(left, vn))
435-
$left = $(DynamicPPL.getfixed_nested)(__context__, $vn)
436-
elseif $isassumption
437-
$(generate_tilde_assume(left, dist, vn))
438-
else
439-
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
440-
if !$(DynamicPPL.inargnames)($vn, __model__)
441-
$left = $(DynamicPPL.getconditioned_nested)(__context__, $vn)
442-
end
443436

444-
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
445-
__context__,
446-
$(DynamicPPL.check_tilde_rhs)($dist),
447-
$(maybe_view(left)),
448-
$vn,
449-
__varinfo__,
437+
if $dist isa $(SampleableModelWrapper)
438+
$left, __varinfo__ = $(_evaluate!!)($dist.model, __varinfo__, __context__)
439+
$left
440+
else
441+
$vn = $(DynamicPPL.resolve_varnames)(
442+
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist
450443
)
451-
$value
444+
$isassumption = $(DynamicPPL.isassumption(left, vn))
445+
if $(DynamicPPL.isfixed(left, vn))
446+
$left = $(DynamicPPL.getfixed_nested)(__context__, $vn)
447+
elseif $isassumption
448+
$(generate_tilde_assume(left, dist, vn))
449+
else
450+
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
451+
if !$(DynamicPPL.inargnames)($vn, __model__)
452+
$left = $(DynamicPPL.getconditioned_nested)(__context__, $vn)
453+
end
454+
455+
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
456+
__context__,
457+
$(DynamicPPL.check_tilde_rhs)($dist),
458+
$(maybe_view(left)),
459+
$vn,
460+
__varinfo__,
461+
)
462+
$value
463+
end
452464
end
453465
end
454466
end

0 commit comments

Comments
 (0)