@@ -73,11 +73,26 @@ function Base.push!(
73
73
return context. logdensities[vn] = logp
74
74
end
75
75
76
+
77
+ function _include_prior (context:: PointwiseLogdensityContext )
78
+ return leafcontext (context) isa Union{PriorContext,DefaultContext}
79
+ end
80
+ function _include_likelihood (context:: PointwiseLogdensityContext )
81
+ return leafcontext (context) isa Union{LikelihoodContext,DefaultContext}
82
+ end
83
+
84
+
85
+
76
86
function tilde_observe!! (context:: PointwiseLogdensityContext , right, left, vi)
77
87
# Defer literal `observe` to child-context.
78
88
return tilde_observe!! (context. context, right, left, vi)
79
89
end
80
90
function tilde_observe!! (context:: PointwiseLogdensityContext , right, left, vn, vi)
91
+ # Completely defer to child context if we are not tracking likelihoods.
92
+ if ! (_include_likelihood (context))
93
+ return tilde_observe!! (context. context, right, left, vn, vi)
94
+ end
95
+
81
96
# Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e.
82
97
# we have to intercept the call to `tilde_observe!`.
83
98
logp, vi = tilde_observe (context. context, right, left, vi)
@@ -93,6 +108,11 @@ function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, v
93
108
return dot_tilde_observe!! (context. context, right, left, vi)
94
109
end
95
110
function dot_tilde_observe!! (context:: PointwiseLogdensityContext , right, left, vn, vi)
111
+ # Completely defer to child context if we are not tracking likelihoods.
112
+ if ! (_include_likelihood (context))
113
+ return dot_tilde_observe!! (context. context, right, left, vn, vi)
114
+ end
115
+
96
116
# Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e.
97
117
# we have to intercept the call to `dot_tilde_observe!`.
98
118
@@ -132,16 +152,19 @@ end
132
152
function tilde_assume (context:: PointwiseLogdensityContext , right:: Distribution , vn, vi)
133
153
# @info "PointwiseLogdensityContext tilde_assume called for $vn"
134
154
value, logp, vi = tilde_assume (context. context, right, vn, vi)
135
- push! (context, vn, logp)
155
+ if _include_prior (context)
156
+ push! (context, vn, logp)
157
+ end
136
158
return value, logp, vi
137
159
end
138
160
139
161
function dot_tilde_assume (context:: PointwiseLogdensityContext , right, left, vns, vi)
140
162
# @info "PointwiseLogdensityContext dot_tilde_assume called for $vns"
141
163
value, logp, vi_new = dot_tilde_assume (context. context, right, left, vns, vi)
142
- # dispatch recording of log-densities based on type of right
143
- logps = record_dot_tilde_assume (context, right, left, vns, vi, logp)
144
- sum (logps) ≈ logp || error (" Expected sum of individual logp equal origina, but differed sum($(join (logps, " ," )) ) != $logp_orig " )
164
+ if _include_prior (context)
165
+ logps = record_dot_tilde_assume (context, right, left, vns, vi, logp)
166
+ sum (logps) ≈ logp || error (" Expected sum of individual logp equal origina, but differed sum($(join (logps, " ," )) ) != $logp_orig " )
167
+ end
145
168
return value, logp, vi
146
169
end
147
170
172
195
pointwise_logdensities(model::Model, chain::Chains, keytype = String)
173
196
174
197
Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
175
- with keys corresponding to symbols of the observations , and values being matrices
198
+ with keys corresponding to symbols of the variables , and values being matrices
176
199
of shape `(num_chains, num_samples)`.
177
200
178
201
`keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
0 commit comments