@@ -77,49 +77,11 @@ function tilde_assume(
7777 return tilde_assume (rng, childcontext (context), args... )
7878end
7979
80- function tilde_assume (context:: PriorContext{<:NamedTuple} , right, vn, vi)
81- if haskey (context. vars, getsym (vn))
82- vi = setindex!! (vi, tovec (get (context. vars, vn)), vn)
83- settrans!! (vi, false , vn)
84- end
85- return tilde_assume (PriorContext (), right, vn, vi)
86- end
87- function tilde_assume (
88- rng:: Random.AbstractRNG , context:: PriorContext{<:NamedTuple} , sampler, right, vn, vi
89- )
90- if haskey (context. vars, getsym (vn))
91- vi = setindex!! (vi, tovec (get (context. vars, vn)), vn)
92- settrans!! (vi, false , vn)
93- end
94- return tilde_assume (rng, PriorContext (), sampler, right, vn, vi)
95- end
96-
97- function tilde_assume (context:: LikelihoodContext{<:NamedTuple} , right, vn, vi)
98- if haskey (context. vars, getsym (vn))
99- vi = setindex!! (vi, tovec (get (context. vars, vn)), vn)
100- settrans!! (vi, false , vn)
101- end
102- return tilde_assume (LikelihoodContext (), right, vn, vi)
103- end
104- function tilde_assume (
105- rng:: Random.AbstractRNG ,
106- context:: LikelihoodContext{<:NamedTuple} ,
107- sampler,
108- right,
109- vn,
110- vi,
111- )
112- if haskey (context. vars, getsym (vn))
113- vi = setindex!! (vi, tovec (get (context. vars, vn)), vn)
114- settrans!! (vi, false , vn)
115- end
116- return tilde_assume (rng, LikelihoodContext (), sampler, right, vn, vi)
117- end
11880function tilde_assume (:: LikelihoodContext , right, vn, vi)
119- return assume (NoDist (right), vn, vi)
81+ return assume (nodist (right), vn, vi)
12082end
12183function tilde_assume (rng:: Random.AbstractRNG , :: LikelihoodContext , sampler, right, vn, vi)
122- return assume (rng, sampler, NoDist (right), vn, vi)
84+ return assume (rng, sampler, nodist (right), vn, vi)
12385end
12486
12587function tilde_assume (context:: PrefixContext , right, vn, vi)
@@ -257,7 +219,7 @@ function assume(
257219 else
258220 r = init (rng, dist, sampler)
259221 if istrans (vi)
260- f = to_linked_internal_transform (vi, dist)
222+ f = to_linked_internal_transform (vi, vn, dist)
261223 push!! (vi, vn, f (r), dist, sampler)
262224 # By default `push!!` sets the transformed flag to `false`.
263225 settrans!! (vi, true , vn)
@@ -328,37 +290,6 @@ function dot_tilde_assume(
328290end
329291
330292# `LikelihoodContext`
331- function dot_tilde_assume (context:: LikelihoodContext{<:NamedTuple} , right, left, vn, vi)
332- return if haskey (context. vars, getsym (vn))
333- var = get (context. vars, vn)
334- _right, _left, _vns = unwrap_right_left_vns (right, var, vn)
335- set_val! (vi, _vns, _right, _left)
336- settrans!! .((vi,), false , _vns)
337- dot_tilde_assume (LikelihoodContext (), _right, _left, _vns, vi)
338- else
339- dot_tilde_assume (LikelihoodContext (), right, left, vn, vi)
340- end
341- end
342- function dot_tilde_assume (
343- rng:: Random.AbstractRNG ,
344- context:: LikelihoodContext{<:NamedTuple} ,
345- sampler,
346- right,
347- left,
348- vn,
349- vi,
350- )
351- return if haskey (context. vars, getsym (vn))
352- var = get (context. vars, vn)
353- _right, _left, _vns = unwrap_right_left_vns (right, var, vn)
354- set_val! (vi, _vns, _right, _left)
355- settrans!! .((vi,), false , _vns)
356- dot_tilde_assume (rng, LikelihoodContext (), sampler, _right, _left, _vns, vi)
357- else
358- dot_tilde_assume (rng, LikelihoodContext (), sampler, right, left, vn, vi)
359- end
360- end
361-
362293function dot_tilde_assume (context:: LikelihoodContext , right, left, vn, vi)
363294 return dot_assume (nodist (right), left, vn, vi)
364295end
@@ -368,46 +299,16 @@ function dot_tilde_assume(
368299 return dot_assume (rng, sampler, nodist (right), vn, left, vi)
369300end
370301
371- # `PriorContext`
372- function dot_tilde_assume (context:: PriorContext{<:NamedTuple} , right, left, vn, vi)
373- return if haskey (context. vars, getsym (vn))
374- var = get (context. vars, vn)
375- _right, _left, _vns = unwrap_right_left_vns (right, var, vn)
376- set_val! (vi, _vns, _right, _left)
377- settrans!! .((vi,), false , _vns)
378- dot_tilde_assume (PriorContext (), _right, _left, _vns, vi)
379- else
380- dot_tilde_assume (PriorContext (), right, left, vn, vi)
381- end
382- end
383- function dot_tilde_assume (
384- rng:: Random.AbstractRNG ,
385- context:: PriorContext{<:NamedTuple} ,
386- sampler,
387- right,
388- left,
389- vn,
390- vi,
391- )
392- return if haskey (context. vars, getsym (vn))
393- var = get (context. vars, vn)
394- _right, _left, _vns = unwrap_right_left_vns (right, var, vn)
395- set_val! (vi, _vns, _right, _left)
396- settrans!! .((vi,), false , _vns)
397- dot_tilde_assume (rng, PriorContext (), sampler, _right, _left, _vns, vi)
398- else
399- dot_tilde_assume (rng, PriorContext (), sampler, right, left, vn, vi)
400- end
401- end
402-
403302# `PrefixContext`
404303function dot_tilde_assume (context:: PrefixContext , right, left, vn, vi)
405- return dot_tilde_assume (context. context, right, prefix .(Ref (context), vn), vi)
304+ return dot_tilde_assume (context. context, right, left, prefix .(Ref (context), vn), vi)
406305end
407306
408- function dot_tilde_assume (rng, context:: PrefixContext , sampler, right, left, vn, vi)
307+ function dot_tilde_assume (
308+ rng:: Random.AbstractRNG , context:: PrefixContext , sampler, right, left, vn, vi
309+ )
409310 return dot_tilde_assume (
410- rng, context. context, sampler, right, prefix .(Ref (context), vn), vi
311+ rng, context. context, sampler, right, left, prefix .(Ref (context), vn), vi
411312 )
412313end
413314
500401# HACK: These methods are only used in the `get_and_set_val!` methods below.
501402# FIXME : Remove these.
502403function _link_broadcast_new (vi, vn, dist, r)
503- b = to_linked_internal_transform (vi, dist)
404+ b = to_linked_internal_transform (vi, vn, dist)
504405 return b (r)
505406end
506407
@@ -591,7 +492,7 @@ function get_and_set_val!(
591492 push!! .((vi,), vns, _link_broadcast_new .((vi,), vns, dists, r), dists, (spl,))
592493 # NOTE: Need to add the correction.
593494 # FIXME : This is not great.
594- acclogp_assume !! (vi, sum (logabsdetjac .(link_transform .(dists), r)))
495+ acclogp !! (vi, sum (logabsdetjac .(link_transform .(dists), r)))
595496 # `push!!` sets the trans-flag to `false` by default.
596497 settrans!! .((vi,), true , vns)
597498 else
0 commit comments