Skip to content

Commit a425c41

Browse files
committed
fix test error by discard burn-in's
1 parent c7d08b0 commit a425c41

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

test/ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@ end
3535
# Infer
3636
m_lin_reg = linear_reg(xs_train, ys_train)
3737
chain_lin_reg = sample(
38-
DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)),
38+
DynamicPPL.LogDensityFunction(m_lin_reg),
3939
AdvancedHMC.NUTS(0.65),
40-
200;
40+
1000;
4141
chain_type=MCMCChains.Chains,
4242
param_names=[],
43+
discard_initial=100,
44+
n_adapt=100,
4345
)
4446

4547
# Predict on two last indices
@@ -156,9 +158,11 @@ end
156158
chain = sample(
157159
DynamicPPL.LogDensityFunction(m, DynamicPPL.VarInfo(m)),
158160
AdvancedHMC.NUTS(0.65),
159-
100;
161+
1000;
160162
chain_type=MCMCChains.Chains,
161163
param_names=param_names[model],
164+
discard_initial=100,
165+
n_adapt=100,
162166
)
163167
chain_predict = DynamicPPL.predict(model(x, missing), chain)
164168
mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)]

0 commit comments

Comments
 (0)