Skip to content

Commit b3b97da

Browse files
committed
Allow user to return things from inside pobserve
1 parent afe3ba2 commit b3b97da

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

src/pobserve_macro.jl

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@ function _pobserve(expr::Expr)
2323
$(process_tilde_statements(block))
2424
end
2525
end
26-
total_likelihoods = sum(fetch.(likelihood_tasks))
26+
retvals_and_likelihoods = fetch.(likelihood_tasks)
27+
total_likelihoods = sum(last, retvals_and_likelihoods)
2728
# println("Total likelihoods: ", total_likelihoods)
2829
$(esc(:(__varinfo__))) = $(DynamicPPL.accloglikelihood!!)(
2930
$(esc(:(__varinfo__))), total_likelihoods
3031
)
31-
nothing
32+
map(first, retvals_and_likelihoods)
3233
end
3334
return return_expr
3435
end
@@ -50,16 +51,34 @@ function process_tilde_statements(expr::Expr)
5051
@gensym loglike
5152
beginning_statement =
5253
:($loglike = zero($(DynamicPPL.getloglikelihood)($(esc(:(__varinfo__))))))
53-
transformed_statements = map(statements) do stmt
54-
# skip non-tilde statements
55-
# TODO: dot-tilde
56-
@capture(stmt, lhs_ ~ rhs_) || return :($(esc(stmt)))
57-
# if the above matched, we transform the tilde statement
58-
# TODO: We should probably perform some checks to make sure that this
59-
# indeed was meant to be an observe statement.
60-
:($loglike += $(Distributions.logpdf)($(esc(rhs)), $(esc(lhs))))
54+
n_statements = length(statements)
55+
transformed_statements::Vector{Vector{Expr}} = map(enumerate(statements)) do (i, stmt)
56+
is_last = i == n_statements
57+
if @capture(stmt, lhs_ ~ rhs_)
58+
# TODO: We should probably perform some checks to make sure that this
59+
# indeed was meant to be an observe statement.
60+
@gensym left
61+
e = quote
62+
$left = $(esc(lhs))
63+
$loglike += $(Distributions.logpdf)($(esc(rhs)), $left)
64+
end
65+
is_last && push!(e.args, :(($left, $loglike)))
66+
e.args
67+
elseif @capture(stmt, lhs_ .~ rhs_)
68+
@gensym val
69+
e = [
70+
# TODO: dot-tilde
71+
:($val = $(esc(stmt))),
72+
]
73+
is_last && push!(e, :(($val, $loglike)))
74+
e
75+
else
76+
@gensym val
77+
e = [:($val = $(esc(stmt)))]
78+
is_last && push!(e, :(($val, $loglike)))
79+
e
80+
end
6181
end
62-
ending_statement = loglike
63-
new_statements = [beginning_statement, transformed_statements..., ending_statement]
82+
new_statements = [beginning_statement, reduce(vcat, transformed_statements)...]
6483
return Expr(:block, new_statements...)
6584
end

0 commit comments

Comments
 (0)