Skip to content

Commit 050d719

Browse files
committed
Fix tests
1 parent 6866091 commit 050d719

File tree

3 files changed

+21
-8
lines changed

3 files changed

+21
-8
lines changed

src/debug_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ function check_model_and_trace(
426426
issuccess = check_model_pre_evaluation(model)
427427

428428
# Force single-threaded execution.
429-
DynamicPPL.evaluate_threadunsafe!!(model, varinfo)
429+
_, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo)
430430

431431
# Perform checks after evaluating the model.
432432
debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME))

src/pointwise_logdensities.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,23 +230,36 @@ function pointwise_logdensities(
230230
# Get the data by executing the model once
231231
vi = VarInfo(model)
232232

233+
# This accumulator tracks the pointwise log-probabilities in a single iteration.
233234
AccType = PointwiseLogProbAccumulator{whichlogprob,KeyType}
234235
vi = setaccs!!(vi, (AccType(),))
235236

236237
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
238+
239+
# Maintain a separate accumulator that isn't tied to a VarInfo but rather
240+
# tracks _all_ iterations.
241+
all_logps = AccType()
237242
for (sample_idx, chain_idx) in iters
238243
# Update the values
239244
setval!(vi, chain, sample_idx, chain_idx)
240245

241246
# Execute model
242-
vi = last(evaluate!!(model, vi))
247+
vi = setaccs!!(vi, (AccType(),))
248+
vi = last(_evaluate!!(model, vi))
249+
250+
# Get the log-probabilities
251+
this_iter_logps = getacc(vi, Val(accumulator_name(AccType))).logps
252+
253+
# Merge into main acc
254+
for (varname, this_lp) in this_iter_logps
255+
push!(all_logps, varname, only(this_lp))
256+
end
243257
end
244258

245-
logps = getacc(vi, Val(accumulator_name(AccType))).logps
246259
niters = size(chain, 1)
247260
nchains = size(chain, 3)
248261
logdensities = OrderedDict(
249-
varname => reshape(vals, niters, nchains) for (varname, vals) in logps
262+
varname => reshape(vals, niters, nchains) for (varname, vals) in all_logps.logps
250263
)
251264
return logdensities
252265
end

test/accumulators.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ using DynamicPPL:
2424
LogPriorAccumulator() ==
2525
LogPriorAccumulator{Float64}() ==
2626
LogPriorAccumulator{Float64}(0.0) ==
27-
zero(LogPriorAccumulator(1.0))
27+
DynamicPPL.reset(LogPriorAccumulator(1.0))
2828
@test LogLikelihoodAccumulator(0.0) ==
2929
LogLikelihoodAccumulator() ==
3030
LogLikelihoodAccumulator{Float64}() ==
3131
LogLikelihoodAccumulator{Float64}(0.0) ==
32-
zero(LogLikelihoodAccumulator(1.0))
32+
DynamicPPL.reset(LogLikelihoodAccumulator(1.0))
3333
end
3434

3535
@testset "addition and incrementation" begin
@@ -136,7 +136,7 @@ using DynamicPPL:
136136
@testset "map_accumulator(s)!!" begin
137137
# map over all accumulators
138138
accs = AccumulatorTuple(lp_f32, ll_f32)
139-
@test map(zero, accs) == AccumulatorTuple(
139+
@test map(DynamicPPL.reset, accs) == AccumulatorTuple(
140140
LogPriorAccumulator(0.0f0), LogLikelihoodAccumulator(0.0f0)
141141
)
142142
# Test that the original wasn't modified.
@@ -147,7 +147,7 @@ using DynamicPPL:
147147
AccumulatorTuple(LogPriorAccumulator(1.0), LogLikelihoodAccumulator(1.0))
148148

149149
# only apply to a particular accumulator
150-
@test map_accumulator(zero, accs, Val(:LogLikelihood)) ==
150+
@test map_accumulator(DynamicPPL.reset, accs, Val(:LogLikelihood)) ==
151151
AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(0.0f0))
152152
@test map_accumulator(
153153
acc -> convert_eltype(Float64, acc), accs, Val(:LogLikelihood)

0 commit comments

Comments
 (0)