@@ -132,37 +132,49 @@ end
132
132
function tilde_assume (context:: PointwiseLogdensityContext , right, vn, vi)
133
133
# @info "PointwiseLogdensityContext tilde_assume!! called for $vn"
134
134
value, logp, vi = tilde_assume (context. context, right, vn, vi)
135
- # sym = DynamicPPL.getsym(vn)
136
- new_context = acc_logp! (context, vn, logp)
135
+ push! (context, vn, logp)
137
136
return value, logp, vi
138
137
end
139
138
140
- function dot_tilde_assume (context:: PointwiseLogdensityContext , right, left, vn, vi)
141
- # @info "PointwiseLogdensityContext dot_tilde_assume!! called for $vn"
142
- # @show vn, left, right, typeof(context).name
143
- value, logp, vi = dot_tilde_assume (context. context, right, left, vn, vi)
144
- new_context = acc_logp! (context, vn, logp)
139
+ function dot_tilde_assume (context:: PointwiseLogdensityContext , right, left, vns, vi)
140
+ # @info "PointwiseLogdensityContext dot_tilde_assume called for $vns"
141
+ value, logp, vi_new = dot_tilde_assume (context. context, right, left, vns, vi)
142
+ # dispatch recording of log-densities based on type of right
143
+ logps = record_dot_tilde_assume (context, right, left, vns, vi, logp)
144
+ sum (logps) ≈ logp || error (" Expected sum of individual logp equal origina, but differed sum($(join (logps, " ," )) ) != $logp_orig " )
145
145
return value, logp, vi
146
146
end
147
147
148
- function acc_logp! (context:: PointwiseLogdensityContext , vn:: VarName , logp)
149
- push! (context, vn, logp)
150
- return (context)
148
+ function record_dot_tilde_assume (context:: PointwiseLogdensityContext , right:: UnivariateDistribution , left, vns, vi, logp)
149
+ # forward to tilde_assume for each variable
150
+ map (vns) do vn
151
+ value_i, logp_i, vi_i = tilde_assume (context, right, vn, vi)
152
+ logp_i
153
+ end
151
154
end
152
155
153
- function acc_logp! (context:: PointwiseLogdensityContext , vns :: AbstractVector{<:VarName} , logp)
154
- # construct a new VarName from given sequence of VarName
155
- # assume that all items in vns have an IndexLens optic
156
- indices = tuplejoin ( map (vn -> getoptic (vn) . indices, vns) ... )
157
- vn = VarName ( first (vns), Accessors . IndexLens (indices) )
158
- push! (context, vn, logp)
159
- return (context)
156
+ function record_dot_tilde_assume (context:: PointwiseLogdensityContext , rights :: AbstractVector{<:Distribution} , left, vns, vi , logp)
157
+ # forward to tilde_assume for each variable and distribution
158
+ logps = map ( vns, rights) do vn, right
159
+ # use current context to record vn
160
+ value_i, logp_i, vi_i = tilde_assume (context, right, vn, vi )
161
+ logp_i
162
+ end
160
163
end
161
164
162
- # https://discourse.julialang.org/t/efficient-tuple-concatenation/5398/8
163
- @inline tuplejoin (x) = x
164
- @inline tuplejoin (x, y) = (x... , y... )
165
- @inline tuplejoin (x, y, z... ) = (x... , tuplejoin (y, z... )... )
165
+ function record_dot_tilde_assume (context:: PointwiseLogdensityContext , right:: MultivariateDistribution , left, vns, vi, logp)
166
+ # @info "PointwiseLogdensityContext record_dot_tilde_assume multivariate called for $vns"
167
+ # For multivariate distribution on the right there is only a single density.
168
+ # Need to construct a combined VarName.
169
+ # Assume that all vns have an IndexLens with a Colon at the first position
170
+ # and a single number at the second position.
171
+ indices = map (vn -> getoptic (vn). indices[2 ], vns)
172
+ indices_combined = (:,indices)
173
+ # indices = tuplejoin(map(vn -> getoptic(vn).indices[2], vns)...)
174
+ vn = VarName (first (vns), Accessors. IndexLens (indices_combined))
175
+ push! (context, vn, logp)
176
+ return logp
177
+ end
166
178
167
179
() -> begin
168
180
# code that generates julia-repl in docstring below
0 commit comments