|
38 | 38 | """
|
39 | 39 | tilde_assume(ctx, sampler, right, vn, inds, vi)
|
40 | 40 |
|
41 |
| -This method is applied in the generated code for assumed variables, e.g., `x ~ Normal()` where |
42 |
| -`x` does not occur in the model inputs. |
| 41 | +Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), |
| 42 | +accumulate the log probability, and return the sampled value. |
43 | 43 |
|
44 | 44 | Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
|
45 | 45 | """
|
46 | 46 | function tilde_assume(ctx, sampler, right, vn, inds, vi)
|
47 |
| - return tilde(ctx, sampler, right, vn, inds, vi) |
| 47 | + value, logp = tilde(ctx, sampler, right, vn, inds, vi) |
| 48 | + acclogp!(vi, logp) |
| 49 | + return value |
48 | 50 | end
|
49 | 51 |
|
50 | 52 |
|
|
72 | 74 | """
|
73 | 75 | tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
|
74 | 76 |
|
75 |
| -This method is applied in the generated code for observed variables, e.g., `x ~ Normal()` where |
76 |
| -`x` does occur in the model inputs. |
| 77 | +Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), |
| 78 | +accumulate the log probability, and return the observed value. |
77 | 79 |
|
78 |
| -Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable |
79 |
| -name and indices; if needed, these can be accessed through this function, though. |
| 80 | +Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name |
| 81 | +and indices; if needed, these can be accessed through this function, though. |
80 | 82 | """
|
81 | 83 | function tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
|
82 |
| - return tilde(ctx, sampler, right, left, vi) |
| 84 | + logp = tilde(ctx, sampler, right, left, vi) |
| 85 | + acclogp!(vi, logp) |
| 86 | + return left |
83 | 87 | end
|
84 | 88 |
|
85 | 89 | """
|
86 | 90 | tilde_observe(ctx, sampler, right, left, vi)
|
87 | 91 |
|
88 |
| -This method is applied in the generated code for observed constants, e.g., `1.0 ~ Normal()`. |
| 92 | +Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the |
| 93 | +observed value. |
| 94 | +
|
89 | 95 | Falls back to `tilde(ctx, sampler, right, left, vi)`.
|
90 | 96 | """
|
91 | 97 | function tilde_observe(ctx, sampler, right, left, vi)
|
92 |
| - return tilde(ctx, sampler, right, left, vi) |
| 98 | + logp = tilde(ctx, sampler, right, left, vi) |
| 99 | + acclogp!(vi, logp) |
| 100 | + return left |
93 | 101 | end
|
94 | 102 |
|
95 | 103 |
|
@@ -191,13 +199,15 @@ end
|
191 | 199 | """
|
192 | 200 | dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
|
193 | 201 |
|
194 |
| -This method is applied in the generated code for assumed vectorized variables, e.g., `x .~ |
195 |
| -MvNormal()` where `x` does not occur in the model inputs. |
| 202 | +Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the |
| 203 | +model inputs), accumulate the log probability, and return the sampled value. |
196 | 204 |
|
197 | 205 | Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
|
198 | 206 | """
|
199 | 207 | function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
|
200 |
| - return dot_tilde(ctx, sampler, right, left, vn, inds, vi) |
| 208 | + value, logp = dot_tilde(ctx, sampler, right, left, vn, inds, vi) |
| 209 | + acclogp!(vi, logp) |
| 210 | + return value |
201 | 211 | end
|
202 | 212 |
|
203 | 213 |
|
@@ -367,24 +377,30 @@ end
|
367 | 377 | """
|
368 | 378 | dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
|
369 | 379 |
|
370 |
| -This method is applied in the generated code for vectorized observed variables, e.g., `x .~ |
371 |
| -MvNormal()` where `x` does occur the model inputs. |
| 380 | +Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), |
| 381 | +accumulate the log probability, and return the observed value. |
372 | 382 |
|
373 | 383 | Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
|
374 | 384 | name and indices; if needed, these can be accessed through this function, though.
|
375 | 385 | """
|
376 | 386 | function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi)
|
377 |
| - return dot_tilde(ctx, sampler, right, left, vi) |
| 387 | + logp = dot_tilde(ctx, sampler, right, left, vi) |
| 388 | + acclogp!(vi, logp) |
| 389 | + return left |
378 | 390 | end
|
379 | 391 |
|
380 | 392 | """
|
381 | 393 | dot_tilde_observe(ctx, sampler, right, left, vi)
|
382 | 394 |
|
383 |
| -This method is applied in the generated code for vectorized observed constants, e.g., `[1.0] .~ |
384 |
| -MvNormal()`. Falls back to `dot_tilde(ctx, sampler, right, left, vi)`. |
| 395 | +Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log |
| 396 | +probability, and return the observed value. |
| 397 | +
|
| 398 | +Falls back to `dot_tilde(ctx, sampler, right, left, vi)`. |
385 | 399 | """
|
386 | 400 | function dot_tilde_observe(ctx, sampler, right, left, vi)
|
387 |
| - return dot_tilde(ctx, sampler, right, left, vi) |
| 401 | + logp = dot_tilde(ctx, sampler, right, left, vi) |
| 402 | + acclogp!(vi, logp) |
| 403 | + return left |
388 | 404 | end
|
389 | 405 |
|
390 | 406 |
|
|
0 commit comments