Skip to content

Commit 0527a84

Browse files
committed
Make {dot}_tilde_{assume,observe} function behave more uniformly
1 parent b7159cb commit 0527a84

File tree

2 files changed

+33
-54
lines changed

2 files changed

+33
-54
lines changed

src/compiler.jl

Lines changed: 15 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ function generate_mainbody!(found, expr::Expr, args)
208208
return Expr(expr.head, map(x -> generate_mainbody!(found, x, args), expr.args)...)
209209
end
210210

211-
# """ Unbreak code highlighting in Emacs julia-mode
212211

213212

214213
"""
@@ -218,7 +217,7 @@ Generate an `observe` expression for data variables and `assume` expression for
218217
variables.
219218
"""
220219
function generate_tilde(left, right, args)
221-
@gensym tmpright tmpleft
220+
@gensym tmpright
222221
top = [:($tmpright = $right),
223222
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
224223
|| throw(ArgumentError($DISTMSG)))]
@@ -227,48 +226,32 @@ function generate_tilde(left, right, args)
227226
@gensym out vn inds
228227
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
229228

230-
assumption = [
231-
:($out = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds,
232-
_varinfo)),
233-
:($(DynamicPPL.acclogp!)(_varinfo, $out[2])),
234-
:($left = $out[1])
235-
]
236-
237229
# It can only be an observation if the LHS is an argument of the model
238230
if vsym(left) in args
239231
@gensym isassumption
240232
return quote
241233
$(top...)
242234
$isassumption = $(DynamicPPL.isassumption(left))
243235
if $isassumption
244-
$(assumption...)
236+
$left = $(DynamicPPL.tilde_assume)(
237+
_context, _sampler, $tmpright, $vn, $inds, _varinfo)
245238
else
246-
$tmpleft = $left
247-
$(DynamicPPL.acclogp!)(
248-
_varinfo,
249-
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $tmpleft,
250-
$vn, $inds, _varinfo)
251-
)
252-
$tmpleft
239+
$(DynamicPPL.tilde_observe)(
240+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
253241
end
254242
end
255243
end
256244

257245
return quote
258246
$(top...)
259-
$(assumption...)
247+
$left = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds, _varinfo)
260248
end
261249
end
262250

263251
# If the LHS is a literal, it is always an observation
264252
return quote
265253
$(top...)
266-
$tmpleft = $left
267-
$(DynamicPPL.acclogp!)(
268-
_varinfo,
269-
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $tmpleft, _varinfo)
270-
)
271-
$tmpleft
254+
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
272255
end
273256
end
274257

@@ -278,7 +261,7 @@ end
278261
Generate the expression that replaces `left .~ right` in the model body.
279262
"""
280263
function generate_dot_tilde(left, right, args)
281-
@gensym tmpright tmpleft
264+
@gensym tmpright
282265
top = [:($tmpright = $right),
283266
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
284267
|| throw(ArgumentError($DISTMSG)))]
@@ -287,49 +270,33 @@ function generate_dot_tilde(left, right, args)
287270
@gensym out vn inds
288271
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
289272

290-
assumption = [
291-
:($out = $(DynamicPPL.dot_tilde_assume)(_context, _sampler, $tmpright, $left,
292-
$vn, $inds, _varinfo)),
293-
:($(DynamicPPL.acclogp!)(_varinfo, $out[2])),
294-
:($left .= $out[1])
295-
]
296-
297273
# It can only be an observation if the LHS is an argument of the model
298274
if vsym(left) in args
299275
@gensym isassumption
300276
return quote
301277
$(top...)
302278
$isassumption = $(DynamicPPL.isassumption(left))
303279
if $isassumption
304-
$(assumption...)
280+
$left .= $(DynamicPPL.dot_tilde_assume)(
281+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
305282
else
306-
$tmpleft = $left
307-
$(DynamicPPL.acclogp!)(
308-
_varinfo,
309-
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright,
310-
$tmpleft, $vn, $inds, _varinfo)
311-
)
312-
$tmpleft
283+
$(DynamicPPL.dot_tilde_observe)(
284+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
313285
end
314286
end
315287
end
316288

317289
return quote
318290
$(top...)
319-
$(assumption...)
291+
$left .= $(DynamicPPL.dot_tilde_assume)(
292+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
320293
end
321294
end
322295

323296
# If the LHS is a literal, it is always an observation
324297
return quote
325298
$(top...)
326-
$tmpleft = $left
327-
$(DynamicPPL.acclogp!)(
328-
_varinfo,
329-
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $tmpleft,
330-
_varinfo)
331-
)
332-
$tmpleft
299+
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
333300
end
334301
end
335302

src/context_implementations.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ This method is applied in the generated code for assumed variables, e.g., `x ~ N
4444
Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
4545
"""
4646
function tilde_assume(ctx, sampler, right, vn, inds, vi)
47-
return tilde(ctx, sampler, right, vn, inds, vi)
47+
(value, logp) = tilde(ctx, sampler, right, vn, inds, vi)
48+
acclogp!(vi, logp)
49+
return value
4850
end
4951

5052

@@ -79,7 +81,9 @@ Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information ab
7981
name and indices; if needed, these can be accessed through this function, though.
8082
"""
8183
function tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
82-
return tilde(ctx, sampler, right, left, vi)
84+
logp = tilde(ctx, sampler, right, left, vi)
85+
acclogp!(vi, logp)
86+
return left
8387
end
8488

8589
"""
@@ -89,7 +93,9 @@ This method is applied in the generated code for observed constants, e.g., `1.0
8993
Falls back to `tilde(ctx, sampler, right, left, vi)`.
9094
"""
9195
function tilde_observe(ctx, sampler, right, left, vi)
92-
return tilde(ctx, sampler, right, left, vi)
96+
logp = tilde(ctx, sampler, right, left, vi)
97+
acclogp!(vi, logp)
98+
return left
9399
end
94100

95101

@@ -197,7 +203,9 @@ MvNormal()` where `x` does not occur in the model inputs.
197203
Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
198204
"""
199205
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
200-
return dot_tilde(ctx, sampler, right, left, vn, inds, vi)
206+
(value, logp) = dot_tilde(ctx, sampler, right, left, vn, inds, vi)
207+
acclogp!(vi, logp)
208+
return value
201209
end
202210

203211

@@ -374,7 +382,9 @@ Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the informatio
374382
name and indices; if needed, these can be accessed through this function, though.
375383
"""
376384
function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi)
377-
return dot_tilde(ctx, sampler, right, left, vi)
385+
logp = dot_tilde(ctx, sampler, right, left, vi)
386+
acclogp!(vi, logp)
387+
return left
378388
end
379389

380390
"""
@@ -384,7 +394,9 @@ This method is applied in the generated code for vectorized observed constants,
384394
MvNormal()`. Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
385395
"""
386396
function dot_tilde_observe(ctx, sampler, right, left, vi)
387-
return dot_tilde(ctx, sampler, right, left, vi)
397+
logp = dot_tilde(ctx, sampler, right, left, vi)
398+
acclogp!(vi, logp)
399+
return left
388400
end
389401

390402

0 commit comments

Comments
 (0)