Skip to content

Commit 50d056d

Browse files
michaelosthegericardoV94
authored andcommitted
Assert warmup/posterior lengths in StepMethodTesteer
1 parent 55079d1 commit 50d056d

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

tests/helpers.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def check_stat_dtype(self, step, idata):
133133
continue
134134
assert idata.sample_stats[stat].dtype == np.dtype(dtype)
135135

136-
def step_continuous(self, step_fn, draws):
136+
def step_continuous(self, step_fn, draws, chains=1, tune=1000):
137137
start, model, (mu, C) = mv_simple()
138138
unc = np.diag(C) ** 0.5
139139
check = (("x", np.mean, mu, unc / 10), ("x", np.std, unc, unc / 10))
@@ -143,14 +143,19 @@ def step_continuous(self, step_fn, draws):
143143
with warnings.catch_warnings():
144144
warnings.filterwarnings("ignore", "More chains .* than draws .*", UserWarning)
145145
idata = pm.sample(
146-
tune=1000,
146+
tune=tune,
147147
draws=draws,
148-
chains=1,
148+
chains=chains,
149149
step=step,
150150
initvals=start,
151151
model=model,
152152
random_seed=1,
153+
discard_tuned_samples=False,
153154
)
155+
assert idata.warmup_posterior.sizes["chain"] == chains
156+
assert idata.warmup_posterior.sizes["draw"] == tune
157+
assert idata.posterior.sizes["chain"] == chains
158+
assert idata.posterior.sizes["draw"] == draws
154159
self.check_stat(check, idata, step.__class__.__name__)
155160
self.check_stat_dtype(idata, step)
156161

0 commit comments

Comments
 (0)