@@ -23,12 +23,13 @@ function _pobserve(expr::Expr)
23
23
$ (process_tilde_statements (block))
24
24
end
25
25
end
26
- total_likelihoods = sum (fetch .(likelihood_tasks))
26
+ retvals_and_likelihoods = fetch .(likelihood_tasks)
27
+ total_likelihoods = sum (last, retvals_and_likelihoods)
27
28
# println("Total likelihoods: ", total_likelihoods)
28
29
$ (esc (:(__varinfo__))) = $ (DynamicPPL. accloglikelihood!!)(
29
30
$ (esc (:(__varinfo__))), total_likelihoods
30
31
)
31
- nothing
32
+ map (first, retvals_and_likelihoods)
32
33
end
33
34
return return_expr
34
35
end
@@ -50,16 +51,34 @@ function process_tilde_statements(expr::Expr)
50
51
@gensym loglike
51
52
beginning_statement =
52
53
:($ loglike = zero ($ (DynamicPPL. getloglikelihood)($ (esc (:(__varinfo__))))))
53
- transformed_statements = map (statements) do stmt
54
- # skip non-tilde statements
55
- # TODO : dot-tilde
56
- @capture (stmt, lhs_ ~ rhs_) || return :($ (esc (stmt)))
57
- # if the above matched, we transform the tilde statement
58
- # TODO : We should probably perform some checks to make sure that this
59
- # indeed was meant to be an observe statement.
60
- :($ loglike += $ (Distributions. logpdf)($ (esc (rhs)), $ (esc (lhs))))
54
+ n_statements = length (statements)
55
+ transformed_statements:: Vector{Vector{Expr}} = map (enumerate (statements)) do (i, stmt)
56
+ is_last = i == n_statements
57
+ if @capture (stmt, lhs_ ~ rhs_)
58
+ # TODO : We should probably perform some checks to make sure that this
59
+ # indeed was meant to be an observe statement.
60
+ @gensym left
61
+ e = quote
62
+ $ left = $ (esc (lhs))
63
+ $ loglike += $ (Distributions. logpdf)($ (esc (rhs)), $ left)
64
+ end
65
+ is_last && push! (e. args, :(($ left, $ loglike)))
66
+ e. args
67
+ elseif @capture (stmt, lhs_ .~ rhs_)
68
+ @gensym val
69
+ e = [
70
+ # TODO : dot-tilde
71
+ :($ val = $ (esc (stmt))),
72
+ ]
73
+ is_last && push! (e, :(($ val, $ loglike)))
74
+ e
75
+ else
76
+ @gensym val
77
+ e = [:($ val = $ (esc (stmt)))]
78
+ is_last && push! (e, :(($ val, $ loglike)))
79
+ e
80
+ end
61
81
end
62
- ending_statement = loglike
63
- new_statements = [beginning_statement, transformed_statements... , ending_statement]
82
+ new_statements = [beginning_statement, reduce (vcat, transformed_statements)... ]
64
83
return Expr (:block , new_statements... )
65
84
end
0 commit comments