@@ -488,7 +488,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
488
488
489
489
# Construct a chain with 'sampled values' of β
490
490
ground_truth_β = 2
491
- β_chain = MCMCChains. Chains (rand (Normal (ground_truth_β, 0.002 ), 1000 ), [:β ])
491
+ β_chain = MCMCChains. Chains (
492
+ rand (Normal (ground_truth_β, 0.002 ), 1000 ),
493
+ [:β ];
494
+ info= (; varname_to_symbol= Dict (@varname (β) => :β )),
495
+ )
492
496
493
497
# Generate predictions from that chain
494
498
xs_test = [10 + 0.1 , 10 + 2 * 0.1 ]
@@ -534,7 +538,9 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
534
538
@testset " prediction from multiple chains" begin
535
539
# Normal linreg model
536
540
multiple_β_chain = MCMCChains. Chains (
537
- reshape (rand (Normal (ground_truth_β, 0.002 ), 1000 , 2 ), 1000 , 1 , 2 ), [:β ]
541
+ reshape (rand (Normal (ground_truth_β, 0.002 ), 1000 , 2 ), 1000 , 1 , 2 ),
542
+ [:β ];
543
+ info= (; varname_to_symbol= Dict (@varname (β) => :β )),
538
544
)
539
545
predictions = DynamicPPL. predict (m_lin_reg_test, multiple_β_chain)
540
546
@test size (multiple_β_chain, 3 ) == size (predictions, 3 )
0 commit comments