Skip to content

Commit 40e2b94

Browse files
committed
Add tune, draws, random_seed to kwargs
1 parent b020736 commit 40e2b94

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

tests/sampling/test_mcmc_external.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,16 @@ def test_step_args():
9090

9191
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
9292
def test_sample_var_names(nuts_sampler):
93+
seed = 1234
9394
kwargs = {
9495
"nuts_sampler": nuts_sampler,
9596
"chains": 1,
97+
"tune": 100,
98+
"draws": 100,
99+
"random_seed": seed,
96100
}
97101

98102
# Generate data
99-
seed = 1234
100103
rng = np.random.default_rng(seed)
101104

102105
group = rng.choice(list("ABCD"), size=100)
@@ -117,10 +120,8 @@ def test_sample_var_names(nuts_sampler):
117120

118121
# Sample with and without var_names, but always with the same seed
119122
with model:
120-
idata_1 = sample(tune=100, draws=100, random_seed=seed, **kwargs)
121-
idata_2 = sample(
122-
tune=100, draws=100, var_names=["b_group", "b_x", "sigma"], random_seed=seed, **kwargs
123-
)
123+
idata_1 = sample(**kwargs)
124+
idata_2 = sample(var_names=["b_group", "b_x", "sigma"], **kwargs)
124125

125126
assert "mu" in idata_1.posterior
126127
assert "mu" not in idata_2.posterior

0 commit comments

Comments
 (0)