@@ -90,12 +90,12 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
90
90
samples = (; samples_dict... )
91
91
samples = modify_value_representation (samples) # `modify_value_representation` defined in test/test_util.jl
92
92
@test logpriors[i] ≈
93
- DynamicPPL. TestUtils. logprior_true (model, samples[:s ], samples[:m ])
93
+ DynamicPPL. TestUtils. logprior_true (model, samples[:s ], samples[:m ])
94
94
@test loglikelihoods[i] ≈ DynamicPPL. TestUtils. loglikelihood_true (
95
95
model, samples[:s ], samples[:m ]
96
96
)
97
97
@test logjoints[i] ≈
98
- DynamicPPL. TestUtils. logjoint_true (model, samples[:s ], samples[:m ])
98
+ DynamicPPL. TestUtils. logjoint_true (model, samples[:s ], samples[:m ])
99
99
end
100
100
end
101
101
end
@@ -283,10 +283,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
283
283
# Ensure log-probability computations are implemented.
284
284
@test logprior (model, x) ≈ DynamicPPL. TestUtils. logprior_true (model, x... )
285
285
@test loglikelihood (model, x) ≈
286
- DynamicPPL. TestUtils. loglikelihood_true (model, x... )
286
+ DynamicPPL. TestUtils. loglikelihood_true (model, x... )
287
287
@test logjoint (model, x) ≈ DynamicPPL. TestUtils. logjoint_true (model, x... )
288
288
@test logjoint (model, x) !=
289
- DynamicPPL. TestUtils. logjoint_true_with_logabsdet_jacobian (model, x... )
289
+ DynamicPPL. TestUtils. logjoint_true_with_logabsdet_jacobian (model, x... )
290
290
# Ensure `varnames` is implemented.
291
291
vi = last (
292
292
DynamicPPL. evaluate!! (
@@ -383,7 +383,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
383
383
example_values = DynamicPPL. TestUtils. rand_prior_true (model)
384
384
varinfos = DynamicPPL. TestUtils. setup_varinfos (model, example_values, vns)
385
385
@testset " $(short_varinfo_name (varinfo)) " for varinfo in varinfos
386
- realizations = values_as_in_model (model, varinfo)
386
+ # We can set the include_colon_eq arg to false because none of
387
+ # the demo models contain :=. The behaviour when
388
+ # include_colon_eq is true is tested in test/compiler.jl
389
+ realizations = values_as_in_model (model, false , varinfo)
387
390
# Ensure that all variables are found.
388
391
vns_found = collect (keys (realizations))
389
392
@test vns ∩ vns_found == vns ∪ vns_found
@@ -432,72 +435,85 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
432
435
433
436
@testset " predict" begin
434
437
@testset " with MCMCChains.Chains" begin
435
- DynamicPPL. Random. seed! (100 )
436
-
437
438
@model function linear_reg (x, y, σ= 0.1 )
438
439
β ~ Normal (0 , 1 )
439
440
for i in eachindex (y)
440
441
y[i] ~ Normal (β * x[i], σ)
441
442
end
443
+ # Insert a := block to test that it is not included in predictions
444
+ σ2 := σ^ 2
442
445
end
443
446
444
- @model function linear_reg_vec (x, y, σ= 0.1 )
445
- β ~ Normal (0 , 1 )
446
- return y ~ MvNormal (β .* x, σ^ 2 * I)
447
- end
448
-
447
+ # Construct a chain with 'sampled values' of β
449
448
ground_truth_β = 2
450
449
β_chain = MCMCChains. Chains (rand (Normal (ground_truth_β, 0.002 ), 1000 ), [:β ])
451
450
451
+ # Generate predictions from that chain
452
452
xs_test = [10 + 0.1 , 10 + 2 * 0.1 ]
453
453
m_lin_reg_test = linear_reg (xs_test, fill (missing , length (xs_test)))
454
454
predictions = DynamicPPL. predict (m_lin_reg_test, β_chain)
455
455
456
- ys_pred = vec (mean (Array (group (predictions, :y )); dims= 1 ))
457
- @test ys_pred[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
458
- @test ys_pred[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
459
-
460
- # Ensure that `rng` is respected
461
- rng = MersenneTwister (42 )
462
- predictions1 = DynamicPPL. predict (rng, m_lin_reg_test, β_chain[1 : 2 ])
463
- predictions2 = DynamicPPL. predict (
464
- MersenneTwister (42 ), m_lin_reg_test, β_chain[1 : 2 ]
465
- )
466
- @test all (Array (predictions1) .== Array (predictions2))
467
-
468
- # Predict on two last indices for vectorized
469
- m_lin_reg_test = linear_reg_vec (xs_test, missing )
470
- predictions_vec = DynamicPPL. predict (m_lin_reg_test, β_chain)
471
- ys_pred_vec = vec (mean (Array (group (predictions_vec, :y )); dims= 1 ))
472
-
473
- @test ys_pred_vec[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
474
- @test ys_pred_vec[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
456
+ # Also test a vectorized model
457
+ @model function linear_reg_vec (x, y, σ= 0.1 )
458
+ β ~ Normal (0 , 1 )
459
+ return y ~ MvNormal (β .* x, σ^ 2 * I)
460
+ end
461
+ m_lin_reg_test_vec = linear_reg_vec (xs_test, missing )
475
462
476
- # Multiple chains
477
- multiple_β_chain = MCMCChains. Chains (
478
- reshape (rand (Normal (ground_truth_β, 0.002 ), 1000 , 2 ), 1000 , 1 , 2 ), [:β ]
479
- )
480
- m_lin_reg_test = linear_reg (xs_test, fill (missing , length (xs_test)))
481
- predictions = DynamicPPL. predict (m_lin_reg_test, multiple_β_chain)
482
- @test size (multiple_β_chain, 3 ) == size (predictions, 3 )
463
+ @testset " variables in chain" begin
464
+ # Note that this also checks that variables on the lhs of :=,
465
+ # such as σ2, are not included in the resulting chain
466
+ @test Set (keys (predictions)) == Set ([Symbol (" y[1]" ), Symbol (" y[2]" )])
467
+ end
483
468
484
- for chain_idx in MCMCChains . chains (multiple_β_chain)
485
- ys_pred = vec (mean (Array (group (predictions[:, :, chain_idx] , :y )); dims= 1 ))
469
+ @testset " accuracy " begin
470
+ ys_pred = vec (mean (Array (group (predictions, :y )); dims= 1 ))
486
471
@test ys_pred[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
487
472
@test ys_pred[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
488
473
end
489
474
490
- # Predict on two last indices for vectorized
491
- m_lin_reg_test = linear_reg_vec (xs_test, missing )
492
- predictions_vec = DynamicPPL. predict (m_lin_reg_test, multiple_β_chain)
493
-
494
- for chain_idx in MCMCChains. chains (multiple_β_chain)
495
- ys_pred_vec = vec (
496
- mean (Array (group (predictions_vec[:, :, chain_idx], :y )); dims= 1 )
475
+ @testset " ensure that rng is respected" begin
476
+ rng = MersenneTwister (42 )
477
+ predictions1 = DynamicPPL. predict (rng, m_lin_reg_test, β_chain[1 : 2 ])
478
+ predictions2 = DynamicPPL. predict (
479
+ MersenneTwister (42 ), m_lin_reg_test, β_chain[1 : 2 ]
497
480
)
481
+ @test all (Array (predictions1) .== Array (predictions2))
482
+ end
483
+
484
+ @testset " accuracy on vectorized model" begin
485
+ predictions_vec = DynamicPPL. predict (m_lin_reg_test_vec, β_chain)
486
+ ys_pred_vec = vec (mean (Array (group (predictions_vec, :y )); dims= 1 ))
487
+
498
488
@test ys_pred_vec[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
499
489
@test ys_pred_vec[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
500
490
end
491
+
492
+ @testset " prediction from multiple chains" begin
493
+ # Normal linreg model
494
+ multiple_β_chain = MCMCChains. Chains (
495
+ reshape (rand (Normal (ground_truth_β, 0.002 ), 1000 , 2 ), 1000 , 1 , 2 ), [:β ]
496
+ )
497
+ predictions = DynamicPPL. predict (m_lin_reg_test, multiple_β_chain)
498
+ @test size (multiple_β_chain, 3 ) == size (predictions, 3 )
499
+
500
+ for chain_idx in MCMCChains. chains (multiple_β_chain)
501
+ ys_pred = vec (mean (Array (group (predictions[:, :, chain_idx], :y )); dims= 1 ))
502
+ @test ys_pred[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
503
+ @test ys_pred[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
504
+ end
505
+
506
+ # Vectorized linreg model
507
+ predictions_vec = DynamicPPL. predict (m_lin_reg_test_vec, multiple_β_chain)
508
+
509
+ for chain_idx in MCMCChains. chains (multiple_β_chain)
510
+ ys_pred_vec = vec (
511
+ mean (Array (group (predictions_vec[:, :, chain_idx], :y )); dims= 1 )
512
+ )
513
+ @test ys_pred_vec[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
514
+ @test ys_pred_vec[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
515
+ end
516
+ end
501
517
end
502
518
503
519
@testset " with AbstractVector{<:AbstractVarInfo}" begin
@@ -524,7 +540,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
524
540
525
541
@test size (predicted_vis) == size (chain)
526
542
@test Set (keys (predicted_vis[1 ])) ==
527
- Set ([@varname (β), @varname (y[1 ]), @varname (y[2 ])])
543
+ Set ([@varname (β), @varname (y[1 ]), @varname (y[2 ])])
528
544
# because β samples are from the prior, the std will be larger
529
545
@test mean ([
530
546
predicted_vis[i][@varname (y[1 ])] for i in eachindex (predicted_vis)
0 commit comments