|
112 | 112 | test_base(SimpleVarInfo(DynamicPPL.VarNamedVector()))
|
113 | 113 | end
|
114 | 114 |
|
115 |
| - @testset "get/set/acc/resetlogp" begin |
| 115 | + @testset "get/set/acclogp" begin |
116 | 116 | function test_varinfo_logp!(vi)
|
117 | 117 | @test DynamicPPL.getlogjoint(vi) === 0.0
|
118 | 118 | vi = DynamicPPL.setlogprior!!(vi, 1.0)
|
|
131 | 131 | @test DynamicPPL.getlogprior(vi) === 2.0
|
132 | 132 | @test DynamicPPL.getloglikelihood(vi) === 2.0
|
133 | 133 | @test DynamicPPL.getlogjoint(vi) === 4.0
|
134 |
| - vi = DynamicPPL.resetlogp!!(vi) |
135 |
| - @test DynamicPPL.getlogjoint(vi) === 0.0 |
136 | 134 | end
|
137 | 135 |
|
138 | 136 | vi = VarInfo()
|
|
143 | 141 | test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector()))
|
144 | 142 | end
|
145 | 143 |
|
146 |
| - @testset "accumulators" begin |
| 144 | + @testset "logp accumulators" begin |
147 | 145 | @model function demo()
|
148 | 146 | a ~ Normal()
|
149 | 147 | b ~ Normal()
|
|
227 | 225 | @test_throws r"has no field `?LogPrior" getlogjoint(vi)
|
228 | 226 | end
|
229 | 227 |
|
| 228 | + @testset "resetaccs" begin |
| 229 | + # Put in a bunch of accumulators, check that they're all reset either |
| 230 | + # when we call resetaccs!!, empty!!, or evaluate!!. |
| 231 | + @model function demo() |
| 232 | + a ~ Normal() |
| 233 | + return x ~ Normal(a) |
| 234 | + end |
| 235 | + model = demo() |
| 236 | + vi_orig = VarInfo(model) |
| 237 | + # It already has the logp accumulators, so let's add in some more. |
| 238 | + vi_orig = DynamicPPL.setacc!!(vi_orig, DynamicPPL.DebugUtils.DebugAccumulator(true)) |
| 239 | + vi_orig = DynamicPPL.setacc!!(vi_orig, DynamicPPL.ValuesAsInModelAccumulator(true)) |
| 240 | + vi_orig = DynamicPPL.setacc!!(vi_orig, DynamicPPL.PriorDistributionAccumulator()) |
| 241 | + vi_orig = DynamicPPL.setacc!!( |
| 242 | + vi_orig, DynamicPPL.PointwiseLogProbAccumulator{:both}() |
| 243 | + ) |
| 244 | + # And evaluate the model once so that they are populated. |
| 245 | + _, vi_orig = DynamicPPL.evaluate!!(model, vi_orig) |
| 246 | + |
| 247 | + function all_accs_empty(vi::AbstractVarInfo) |
| 248 | + for acc_key in keys(DynamicPPL.getaccs(vi)) |
| 249 | + acc = DynamicPPL.getacc(vi, Val(acc_key)) |
| 250 | + acc == DynamicPPL.reset(acc) || return false |
| 251 | + end |
| 252 | + return true |
| 253 | + end |
| 254 | + |
| 255 | + @test !all_accs_empty(vi_orig) |
| 256 | + |
| 257 | + vi = DynamicPPL.resetaccs!!(deepcopy(vi_orig)) |
| 258 | + @test all_accs_empty(vi) |
| 259 | + @test getlogjoint(vi) == 0.0 # for good measure |
| 260 | + @test getlogprior(vi) == 0.0 |
| 261 | + @test getloglikelihood(vi) == 0.0 |
| 262 | + |
| 263 | + vi = DynamicPPL.empty!!(deepcopy(vi_orig)) |
| 264 | + @test all_accs_empty(vi) |
| 265 | + @test getlogjoint(vi) == 0.0 |
| 266 | + @test getlogprior(vi) == 0.0 |
| 267 | + @test getloglikelihood(vi) == 0.0 |
| 268 | + |
| 269 | + function all_accs_same(vi1::AbstractVarInfo, vi2::AbstractVarInfo) |
| 270 | + # Check that they have the same accs |
| 271 | + keys1 = Set(keys(DynamicPPL.getaccs(vi1))) |
| 272 | + keys2 = Set(keys(DynamicPPL.getaccs(vi2))) |
| 273 | + keys1 == keys2 || return false |
| 274 | + # Check that they have the same values |
| 275 | + for acc_key in keys1 |
| 276 | + acc1 = DynamicPPL.getacc(vi1, Val(acc_key)) |
| 277 | + acc2 = DynamicPPL.getacc(vi2, Val(acc_key)) |
| 278 | + if acc1 != acc2 |
| 279 | + @show acc1, acc2 |
| 280 | + end |
| 281 | + acc1 == acc2 || return false |
| 282 | + end |
| 283 | + return true |
| 284 | + end |
| 285 | + # Hopefully this doesn't matter |
| 286 | + @test all_accs_same(vi_orig, deepcopy(vi_orig)) |
| 287 | + # If we re-evaluate, then we expect the accs to be reset prior to evaluation. |
| 288 | + # Thus after re-evaluation, the accs should be exactly the same as before. |
| 289 | + _, vi = DynamicPPL.evaluate!!(model, deepcopy(vi_orig)) |
| 290 | + @test all_accs_same(vi, vi_orig) |
| 291 | + end |
| 292 | + |
230 | 293 | @testset "flags" begin
|
231 | 294 | # Test flag setting:
|
232 | 295 | # is_flagged, set_flag!, unset_flag!
|
|
0 commit comments