@@ -495,7 +495,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
495
495
496
496
# Construct a chain with 'sampled values' of β
497
497
ground_truth_β = 2
498
- β_chain = MCMCChains. Chains (rand (Normal (ground_truth_β, 0.002 ), 1000 ), [:β ])
498
+ β_chain = MCMCChains. Chains (
499
+ rand (Normal (ground_truth_β, 0.002 ), 1000 ),
500
+ [:β ];
501
+ info= (; varname_to_symbol= Dict (@varname (β) => :β )),
502
+ )
499
503
500
504
# Generate predictions from that chain
501
505
xs_test = [10 + 0.1 , 10 + 2 * 0.1 ]
@@ -541,7 +545,9 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
541
545
@testset " prediction from multiple chains" begin
542
546
# Normal linreg model
543
547
multiple_β_chain = MCMCChains. Chains (
544
- reshape (rand (Normal (ground_truth_β, 0.002 ), 1000 , 2 ), 1000 , 1 , 2 ), [:β ]
548
+ reshape (rand (Normal (ground_truth_β, 0.002 ), 1000 , 2 ), 1000 , 1 , 2 ),
549
+ [:β ];
550
+ info= (; varname_to_symbol= Dict (@varname (β) => :β )),
545
551
)
546
552
predictions = DynamicPPL. predict (m_lin_reg_test, multiple_β_chain)
547
553
@test size (multiple_β_chain, 3 ) == size (predictions, 3 )
0 commit comments