Skip to content

Commit c685f7d

Browse files
authored
Merge pull request #76 from phipsgabler/phg/tilde_returns
Make tilde_{assume,observe} functions behave more uniformly
2 parents 068a887 + dfb4b74 commit c685f7d

File tree

2 files changed

+50
-67
lines changed

2 files changed

+50
-67
lines changed

src/compiler.jl

Lines changed: 15 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ function generate_mainbody!(found, expr::Expr, args)
208208
return Expr(expr.head, map(x -> generate_mainbody!(found, x, args), expr.args)...)
209209
end
210210

211-
# """ Unbreak code highlighting in Emacs julia-mode
212211

213212

214213
"""
@@ -218,7 +217,7 @@ Generate an `observe` expression for data variables and `assume` expression for
218217
variables.
219218
"""
220219
function generate_tilde(left, right, args)
221-
@gensym tmpright tmpleft
220+
@gensym tmpright
222221
top = [:($tmpright = $right),
223222
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
224223
|| throw(ArgumentError($DISTMSG)))]
@@ -227,48 +226,32 @@ function generate_tilde(left, right, args)
227226
@gensym out vn inds
228227
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
229228

230-
assumption = [
231-
:($out = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds,
232-
_varinfo)),
233-
:($(DynamicPPL.acclogp!)(_varinfo, $out[2])),
234-
:($left = $out[1])
235-
]
236-
237229
# It can only be an observation if the LHS is an argument of the model
238230
if vsym(left) in args
239231
@gensym isassumption
240232
return quote
241233
$(top...)
242234
$isassumption = $(DynamicPPL.isassumption(left))
243235
if $isassumption
244-
$(assumption...)
236+
$left = $(DynamicPPL.tilde_assume)(
237+
_context, _sampler, $tmpright, $vn, $inds, _varinfo)
245238
else
246-
$tmpleft = $left
247-
$(DynamicPPL.acclogp!)(
248-
_varinfo,
249-
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $tmpleft,
250-
$vn, $inds, _varinfo)
251-
)
252-
$tmpleft
239+
$(DynamicPPL.tilde_observe)(
240+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
253241
end
254242
end
255243
end
256244

257245
return quote
258246
$(top...)
259-
$(assumption...)
247+
$left = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds, _varinfo)
260248
end
261249
end
262250

263251
# If the LHS is a literal, it is always an observation
264252
return quote
265253
$(top...)
266-
$tmpleft = $left
267-
$(DynamicPPL.acclogp!)(
268-
_varinfo,
269-
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $tmpleft, _varinfo)
270-
)
271-
$tmpleft
254+
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
272255
end
273256
end
274257

@@ -278,7 +261,7 @@ end
278261
Generate the expression that replaces `left .~ right` in the model body.
279262
"""
280263
function generate_dot_tilde(left, right, args)
281-
@gensym tmpright tmpleft
264+
@gensym tmpright
282265
top = [:($tmpright = $right),
283266
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
284267
|| throw(ArgumentError($DISTMSG)))]
@@ -287,49 +270,33 @@ function generate_dot_tilde(left, right, args)
287270
@gensym out vn inds
288271
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))
289272

290-
assumption = [
291-
:($out = $(DynamicPPL.dot_tilde_assume)(_context, _sampler, $tmpright, $left,
292-
$vn, $inds, _varinfo)),
293-
:($(DynamicPPL.acclogp!)(_varinfo, $out[2])),
294-
:($left .= $out[1])
295-
]
296-
297273
# It can only be an observation if the LHS is an argument of the model
298274
if vsym(left) in args
299275
@gensym isassumption
300276
return quote
301277
$(top...)
302278
$isassumption = $(DynamicPPL.isassumption(left))
303279
if $isassumption
304-
$(assumption...)
280+
$left .= $(DynamicPPL.dot_tilde_assume)(
281+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
305282
else
306-
$tmpleft = $left
307-
$(DynamicPPL.acclogp!)(
308-
_varinfo,
309-
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright,
310-
$tmpleft, $vn, $inds, _varinfo)
311-
)
312-
$tmpleft
283+
$(DynamicPPL.dot_tilde_observe)(
284+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
313285
end
314286
end
315287
end
316288

317289
return quote
318290
$(top...)
319-
$(assumption...)
291+
$left .= $(DynamicPPL.dot_tilde_assume)(
292+
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
320293
end
321294
end
322295

323296
# If the LHS is a literal, it is always an observation
324297
return quote
325298
$(top...)
326-
$tmpleft = $left
327-
$(DynamicPPL.acclogp!)(
328-
_varinfo,
329-
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $tmpleft,
330-
_varinfo)
331-
)
332-
$tmpleft
299+
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
333300
end
334301
end
335302

src/context_implementations.jl

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ end
3838
"""
3939
tilde_assume(ctx, sampler, right, vn, inds, vi)
4040
41-
This method is applied in the generated code for assumed variables, e.g., `x ~ Normal()` where
42-
`x` does not occur in the model inputs.
41+
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
42+
accumulate the log probability, and return the sampled value.
4343
4444
Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
4545
"""
4646
function tilde_assume(ctx, sampler, right, vn, inds, vi)
47-
return tilde(ctx, sampler, right, vn, inds, vi)
47+
value, logp = tilde(ctx, sampler, right, vn, inds, vi)
48+
acclogp!(vi, logp)
49+
return value
4850
end
4951

5052

@@ -72,24 +74,30 @@ end
7274
"""
7375
tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
7476
75-
This method is applied in the generated code for observed variables, e.g., `x ~ Normal()` where
76-
`x` does occur in the model inputs.
77+
Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
78+
accumulate the log probability, and return the observed value.
7779
78-
Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
79-
name and indices; if needed, these can be accessed through this function, though.
80+
Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name
81+
and indices; if needed, these can be accessed through this function, though.
8082
"""
8183
function tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
82-
return tilde(ctx, sampler, right, left, vi)
84+
logp = tilde(ctx, sampler, right, left, vi)
85+
acclogp!(vi, logp)
86+
return left
8387
end
8488

8589
"""
8690
tilde_observe(ctx, sampler, right, left, vi)
8791
88-
This method is applied in the generated code for observed constants, e.g., `1.0 ~ Normal()`.
92+
Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the
93+
observed value.
94+
8995
Falls back to `tilde(ctx, sampler, right, left, vi)`.
9096
"""
9197
function tilde_observe(ctx, sampler, right, left, vi)
92-
return tilde(ctx, sampler, right, left, vi)
98+
logp = tilde(ctx, sampler, right, left, vi)
99+
acclogp!(vi, logp)
100+
return left
93101
end
94102

95103

@@ -191,13 +199,15 @@ end
191199
"""
192200
dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
193201
194-
This method is applied in the generated code for assumed vectorized variables, e.g., `x .~
195-
MvNormal()` where `x` does not occur in the model inputs.
202+
Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the
203+
model inputs), accumulate the log probability, and return the sampled value.
196204
197205
Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
198206
"""
199207
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
200-
return dot_tilde(ctx, sampler, right, left, vn, inds, vi)
208+
value, logp = dot_tilde(ctx, sampler, right, left, vn, inds, vi)
209+
acclogp!(vi, logp)
210+
return value
201211
end
202212

203213

@@ -367,24 +377,30 @@ end
367377
"""
368378
dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
369379
370-
This method is applied in the generated code for vectorized observed variables, e.g., `x .~
371-
MvNormal()` where `x` does occur the model inputs.
380+
Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs),
381+
accumulate the log probability, and return the observed value.
372382
373383
Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
374384
name and indices; if needed, these can be accessed through this function, though.
375385
"""
376386
function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi)
377-
return dot_tilde(ctx, sampler, right, left, vi)
387+
logp = dot_tilde(ctx, sampler, right, left, vi)
388+
acclogp!(vi, logp)
389+
return left
378390
end
379391

380392
"""
381393
dot_tilde_observe(ctx, sampler, right, left, vi)
382394
383-
This method is applied in the generated code for vectorized observed constants, e.g., `[1.0] .~
384-
MvNormal()`. Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
395+
Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log
396+
probability, and return the observed value.
397+
398+
Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
385399
"""
386400
function dot_tilde_observe(ctx, sampler, right, left, vi)
387-
return dot_tilde(ctx, sampler, right, left, vi)
401+
logp = dot_tilde(ctx, sampler, right, left, vi)
402+
acclogp!(vi, logp)
403+
return left
388404
end
389405

390406

0 commit comments

Comments
 (0)