Skip to content

Commit cbbfb46

Browse files
committed
Add varname info to chain
1 parent 0f8804f commit cbbfb46

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

test/model.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
488488

489489
# Construct a chain with 'sampled values' of β
490490
ground_truth_β = 2
491-
β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [])
491+
β_chain = MCMCChains.Chains(
492+
rand(Normal(ground_truth_β, 0.002), 1000),
493+
[];
494+
info=(; varname_to_symbol=Dict(@varname(β) => )),
495+
)
492496

493497
# Generate predictions from that chain
494498
xs_test = [10 + 0.1, 10 + 2 * 0.1]
@@ -534,7 +538,9 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
534538
@testset "prediction from multiple chains" begin
535539
# Normal linreg model
536540
multiple_β_chain = MCMCChains.Chains(
537-
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), []
541+
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2),
542+
[];
543+
info=(; varname_to_symbol=Dict(@varname(β) => )),
538544
)
539545
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
540546
@test size(multiple_β_chain, 3) == size(predictions, 3)

0 commit comments

Comments
 (0)