Skip to content

Commit 6ca6a20

Browse files
committed
Add varname info to chain
1 parent a2159c5 commit 6ca6a20

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

test/model.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
495495

496496
# Construct a chain with 'sampled values' of β
497497
ground_truth_β = 2
498-
β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [])
498+
β_chain = MCMCChains.Chains(
499+
rand(Normal(ground_truth_β, 0.002), 1000),
500+
[];
501+
info=(; varname_to_symbol=Dict(@varname(β) => )),
502+
)
499503

500504
# Generate predictions from that chain
501505
xs_test = [10 + 0.1, 10 + 2 * 0.1]
@@ -541,7 +545,9 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
541545
@testset "prediction from multiple chains" begin
542546
# Normal linreg model
543547
multiple_β_chain = MCMCChains.Chains(
544-
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), []
548+
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2),
549+
[];
550+
info=(; varname_to_symbol=Dict(@varname(β) => )),
545551
)
546552
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
547553
@test size(multiple_β_chain, 3) == size(predictions, 3)

0 commit comments

Comments
 (0)