Skip to content

Commit b020736

Browse files
committed
Explicitly importing objects from module
1 parent 81ec3ab commit b020736

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

tests/sampling/test_mcmc_external.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,20 @@
1616
import numpy.testing as npt
1717
import pytest
1818

19-
import pymc as pm
19+
from pymc import Data, Deterministic, HalfNormal, Model, Normal, sample
2020

2121

2222
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
2323
def test_external_nuts_sampler(recwarn, nuts_sampler):
2424
if nuts_sampler != "pymc":
2525
pytest.importorskip(nuts_sampler)
2626

27-
with pm.Model():
28-
x = pm.Normal("x", 100, 5)
29-
y = pm.Data("y", [1, 2, 3, 4])
30-
pm.Data("z", [100, 190, 310, 405])
27+
with Model():
28+
x = Normal("x", 100, 5)
29+
y = Data("y", [1, 2, 3, 4])
30+
Data("z", [100, 190, 310, 405])
3131

32-
pm.Normal("L", mu=x, sigma=0.1, observed=y)
32+
Normal("L", mu=x, sigma=0.1, observed=y)
3333

3434
kwargs = {
3535
"nuts_sampler": nuts_sampler,
@@ -41,12 +41,12 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
4141
"initvals": {"x": 0.0},
4242
}
4343

44-
idata1 = pm.sample(**kwargs)
45-
idata2 = pm.sample(**kwargs)
44+
idata1 = sample(**kwargs)
45+
idata2 = sample(**kwargs)
4646

4747
reference_kwargs = kwargs.copy()
4848
reference_kwargs["nuts_sampler"] = "pymc"
49-
idata_reference = pm.sample(**reference_kwargs)
49+
idata_reference = sample(**reference_kwargs)
5050

5151
warns = {
5252
(warn.category, warn.message.args[0])
@@ -75,9 +75,9 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
7575

7676

7777
def test_step_args():
78-
with pm.Model() as model:
79-
a = pm.Normal("a")
80-
idata = pm.sample(
78+
with Model() as model:
79+
a = Normal("a")
80+
idata = sample(
8181
nuts_sampler="numpyro",
8282
target_accept=0.5,
8383
nuts={"max_treedepth": 10},
@@ -108,17 +108,17 @@ def test_sample_var_names(nuts_sampler):
108108
coords = {"group": group_values}
109109

110110
# Create model
111-
with pm.Model(coords=coords) as model:
112-
b_group = pm.Normal("b_group", dims="group")
113-
b_x = pm.Normal("b_x")
114-
mu = pm.Deterministic("mu", b_group[group_idx] + b_x * x)
115-
sigma = pm.HalfNormal("sigma")
116-
pm.Normal("y", mu=mu, sigma=sigma, observed=y)
111+
with Model(coords=coords) as model:
112+
b_group = Normal("b_group", dims="group")
113+
b_x = Normal("b_x")
114+
mu = Deterministic("mu", b_group[group_idx] + b_x * x)
115+
sigma = HalfNormal("sigma")
116+
Normal("y", mu=mu, sigma=sigma, observed=y)
117117

118118
# Sample with and without var_names, but always with the same seed
119119
with model:
120-
idata_1 = pm.sample(tune=100, draws=100, random_seed=seed, **kwargs)
121-
idata_2 = pm.sample(
120+
idata_1 = sample(tune=100, draws=100, random_seed=seed, **kwargs)
121+
idata_2 = sample(
122122
tune=100, draws=100, var_names=["b_group", "b_x", "sigma"], random_seed=seed, **kwargs
123123
)
124124

0 commit comments

Comments
 (0)