Skip to content

Commit 4e45516

Browse files
ArmavicaricardoV94
authored andcommitted
Distribute test_step into step_methods
1 parent 8adcd4a commit 4e45516

File tree

9 files changed

+829
-724
lines changed

9 files changed

+829
-724
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ jobs:
148148
python-version: ["3.8"]
149149
test-subset:
150150
- pymc/tests/variational/test_approximations.py pymc/tests/variational/test_callbacks.py pymc/tests/variational/test_inference.py pymc/tests/variational/test_opvi.py pymc/tests/test_initial_point.py
151-
- pymc/tests/test_model.py pymc/tests/test_step.py
151+
- pymc/tests/test_model.py pymc/tests/step_methods/test_compound.py pymc/tests/step_methods/hmc/test_hmc.py
152152
- pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/test_smc.py pymc/tests/test_parallel_sampling.py
153153
- pymc/tests/test_sampling.py pymc/tests/step_methods/test_metropolis.py pymc/tests/step_methods/test_slicer.py pymc/tests/step_methods/hmc/test_nuts.py
154154

pymc/tests/helpers.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# limitations under the License.
1414

1515
import contextlib
16+
import shutil
17+
import tempfile
18+
import warnings
1619

1720
from logging.handlers import BufferingHandler
1821

@@ -24,7 +27,11 @@
2427
from aesara.graph.rewriting.basic import in2out
2528
from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
2629

30+
import pymc as pm
31+
2732
from pymc.aesaraf import at_rng, local_check_parameter_to_ninf_switch, set_at_rng
33+
from pymc.tests.checks import close_to
34+
from pymc.tests.models import mv_simple, mv_simple_coarse
2835

2936

3037
class SeededTest:
@@ -148,3 +155,66 @@ def assert_random_state_equal(state1, state2):
148155
(in2out(local_check_parameter_to_ninf_switch), -1)
149156
)
150157
)
158+
159+
160+
class StepMethodTester:
161+
def setup_class(self):
162+
self.temp_dir = tempfile.mkdtemp()
163+
164+
def teardown_class(self):
165+
shutil.rmtree(self.temp_dir)
166+
167+
def check_stat(self, check, idata, name):
168+
group = idata.posterior
169+
for (var, stat, value, bound) in check:
170+
s = stat(group[var].sel(chain=0), axis=0)
171+
close_to(s, value, bound, name)
172+
173+
def check_stat_dtype(self, step, idata):
174+
# TODO: This check does not confirm the announced dtypes are correct as the
175+
# sampling machinery will convert them automatically.
176+
for stats_dtypes in getattr(step, "stats_dtypes", []):
177+
for stat, dtype in stats_dtypes.items():
178+
if stat == "tune":
179+
continue
180+
assert idata.sample_stats[stat].dtype == np.dtype(dtype)
181+
182+
def step_continuous(self, step_fn, draws):
183+
start, model, (mu, C) = mv_simple()
184+
unc = np.diag(C) ** 0.5
185+
check = (("x", np.mean, mu, unc / 10), ("x", np.std, unc, unc / 10))
186+
_, model_coarse, _ = mv_simple_coarse()
187+
with model:
188+
step = step_fn(C, model_coarse)
189+
with warnings.catch_warnings():
190+
warnings.filterwarnings("ignore", "More chains .* than draws .*", UserWarning)
191+
idata = pm.sample(
192+
tune=1000,
193+
draws=draws,
194+
chains=1,
195+
step=step,
196+
initvals=start,
197+
model=model,
198+
random_seed=1,
199+
)
200+
self.check_stat(check, idata, step.__class__.__name__)
201+
self.check_stat_dtype(idata, step)
202+
203+
204+
class RVsAssignmentStepsTester:
205+
"""
206+
Test that step methods convert input RVs to respective value vars
207+
Step methods are tested with one and two variables to cover compound
208+
the special branches in `BlockedStep.__new__`
209+
"""
210+
211+
def continuous_steps(self, step, step_kwargs):
212+
with pm.Model() as m:
213+
c1 = pm.HalfNormal("c1")
214+
c2 = pm.HalfNormal("c2")
215+
216+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
217+
assert [m.rvs_to_values[c1]] == step([c1], **step_kwargs).vars
218+
assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(
219+
step([c1, c2], **step_kwargs).vars
220+
)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2020 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from pymc.step_methods.hmc import HamiltonianMC
18+
from pymc.tests.helpers import RVsAssignmentStepsTester, StepMethodTester
19+
20+
21+
class TestStepHamiltonianMC(StepMethodTester):
22+
@pytest.mark.parametrize(
23+
"step_fn, draws",
24+
[
25+
(lambda C, _: HamiltonianMC(scaling=C, is_cov=True, blocked=False), 1000),
26+
(lambda C, _: HamiltonianMC(scaling=C, is_cov=True), 1000),
27+
],
28+
)
29+
def test_step_continuous(self, step_fn, draws):
30+
self.step_continuous(step_fn, draws)
31+
32+
33+
class TestRVsAssignmentHamiltonianMC(RVsAssignmentStepsTester):
34+
@pytest.mark.parametrize("step, step_kwargs", [(HamiltonianMC, {})])
35+
def test_continuous_steps(self, step, step_kwargs):
36+
self.continuous_steps(step, step_kwargs)

pymc/tests/step_methods/hmc/test_nuts.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
16+
import warnings
17+
18+
import aesara.tensor as at
19+
import numpy as np
1520
import pytest
1621

22+
import pymc as pm
23+
24+
from pymc.aesaraf import floatX
25+
from pymc.exceptions import SamplingError
26+
from pymc.step_methods.hmc import NUTS
1727
from pymc.tests import sampler_fixtures as sf
28+
from pymc.tests.helpers import RVsAssignmentStepsTester, StepMethodTester
1829

1930

2031
class TestNUTSUniform(sf.NutsFixture, sf.UniformFixture):
@@ -81,3 +92,116 @@ class TestNUTSLKJCholeskyCov(sf.NutsFixture, sf.LKJCholeskyCovFixture):
8192
burn = 0
8293
chains = 2
8394
min_n_eff = 200
95+
96+
97+
class TestNutsCheckTrace:
98+
def test_multiple_samplers(self, caplog):
99+
with pm.Model():
100+
prob = pm.Beta("prob", alpha=5.0, beta=3.0)
101+
pm.Binomial("outcome", n=1, p=prob)
102+
caplog.clear()
103+
with warnings.catch_warnings():
104+
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
105+
pm.sample(3, tune=2, discard_tuned_samples=False, n_init=None, chains=1)
106+
messages = [msg.msg for msg in caplog.records]
107+
assert all("boolean index did not" not in msg for msg in messages)
108+
109+
def test_bad_init_nonparallel(self):
110+
with pm.Model():
111+
pm.HalfNormal("a", sigma=1, initval=-1, transform=None)
112+
with pytest.raises(SamplingError) as error:
113+
pm.sample(chains=1, random_seed=1)
114+
error.match("Initial evaluation")
115+
116+
@pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher")
117+
def test_bad_init_parallel(self):
118+
with pm.Model():
119+
pm.HalfNormal("a", sigma=1, initval=-1, transform=None)
120+
with pytest.raises(SamplingError) as error:
121+
pm.sample(cores=2, random_seed=1)
122+
error.match("Initial evaluation")
123+
124+
def test_linalg(self, caplog):
125+
with pm.Model():
126+
a = pm.Normal("a", size=2, initval=floatX(np.zeros(2)))
127+
a = at.switch(a > 0, np.inf, a)
128+
b = at.slinalg.solve(floatX(np.eye(2)), a, check_finite=False)
129+
pm.Normal("c", mu=b, size=2, initval=floatX(np.r_[0.0, 0.0]))
130+
caplog.clear()
131+
with warnings.catch_warnings():
132+
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
133+
trace = pm.sample(20, tune=5, chains=2, return_inferencedata=False, random_seed=526)
134+
warns = [msg.msg for msg in caplog.records]
135+
assert np.any(trace["diverging"])
136+
assert (
137+
any("divergence after tuning" in warn for warn in warns)
138+
or any("divergences after tuning" in warn for warn in warns)
139+
or any("only diverging samples" in warn for warn in warns)
140+
)
141+
142+
with pytest.raises(ValueError) as error:
143+
trace.report.raise_ok()
144+
error.match("issues during sampling")
145+
146+
assert not trace.report.ok
147+
148+
def test_sampler_stats(self):
149+
with pm.Model() as model:
150+
pm.Normal("x", mu=0, sigma=1)
151+
with warnings.catch_warnings():
152+
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
153+
trace = pm.sample(draws=10, tune=1, chains=1, return_inferencedata=False)
154+
155+
# Assert stats exist and have the correct shape.
156+
expected_stat_names = {
157+
"depth",
158+
"diverging",
159+
"energy",
160+
"energy_error",
161+
"model_logp",
162+
"max_energy_error",
163+
"mean_tree_accept",
164+
"step_size",
165+
"step_size_bar",
166+
"tree_size",
167+
"tune",
168+
"perf_counter_diff",
169+
"perf_counter_start",
170+
"process_time_diff",
171+
"index_in_trajectory",
172+
"largest_eigval",
173+
"smallest_eigval",
174+
}
175+
assert trace.stat_names == expected_stat_names
176+
for varname in trace.stat_names:
177+
assert trace.get_sampler_stats(varname).shape == (10,)
178+
179+
# Assert model logp is computed correctly: computing post-sampling
180+
# and tracking while sampling should give same results.
181+
model_logp_fn = model.compile_logp()
182+
model_logp_ = np.array(
183+
[
184+
model_logp_fn(trace.point(i, chain=c))
185+
for c in trace.chains
186+
for i in range(len(trace))
187+
]
188+
)
189+
assert (trace.model_logp == model_logp_).all()
190+
191+
192+
class TestStepNUTS(StepMethodTester):
193+
@pytest.mark.parametrize(
194+
"step_fn, draws",
195+
[
196+
(lambda C, _: NUTS(scaling=C, is_cov=True, blocked=False), 1000),
197+
(lambda C, _: NUTS(scaling=C, is_cov=True), 1000),
198+
],
199+
)
200+
def test_step_continuous(self, step_fn, draws):
201+
self.step_continuous(step_fn, draws)
202+
203+
204+
class TestRVsAssignmentNUTS(RVsAssignmentStepsTester):
205+
@pytest.mark.parametrize("step, step_kwargs", [(NUTS, {})])
206+
def test_continuous_steps(self, step, step_kwargs):
207+
self.continuous_steps(step, step_kwargs)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2020 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import aesara
16+
import pytest
17+
18+
import pymc as pm
19+
20+
from pymc.step_methods import (
21+
NUTS,
22+
CompoundStep,
23+
DEMetropolis,
24+
HamiltonianMC,
25+
Metropolis,
26+
Slice,
27+
)
28+
from pymc.tests.helpers import StepMethodTester, fast_unstable_sampling_mode
29+
from pymc.tests.models import simple_2model_continuous
30+
31+
32+
class TestCompoundStep:
33+
samplers = (Metropolis, Slice, HamiltonianMC, NUTS, DEMetropolis)
34+
35+
def test_non_blocked(self):
36+
"""Test that samplers correctly create non-blocked compound steps."""
37+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
38+
_, model = simple_2model_continuous()
39+
with model:
40+
for sampler in self.samplers:
41+
assert isinstance(sampler(blocked=False), CompoundStep)
42+
43+
def test_blocked(self):
44+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
45+
_, model = simple_2model_continuous()
46+
with model:
47+
for sampler in self.samplers:
48+
sampler_instance = sampler(blocked=True)
49+
assert not isinstance(sampler_instance, CompoundStep)
50+
assert isinstance(sampler_instance, sampler)
51+
52+
def test_name(self):
53+
with pm.Model() as m:
54+
c1 = pm.HalfNormal("c1")
55+
c2 = pm.HalfNormal("c2")
56+
57+
step1 = NUTS([c1])
58+
step2 = Slice([c2])
59+
step = CompoundStep([step1, step2])
60+
assert step.name == "Compound[nuts, slice]"
61+
62+
63+
class TestStepCompound(StepMethodTester):
64+
@pytest.mark.parametrize(
65+
"step_fn, draws",
66+
[
67+
(
68+
lambda C, _: CompoundStep(
69+
[
70+
HamiltonianMC(scaling=C, is_cov=True),
71+
HamiltonianMC(scaling=C, is_cov=True, blocked=False),
72+
]
73+
),
74+
1000,
75+
),
76+
],
77+
ids=str,
78+
)
79+
def test_step_continuous(self, step_fn, draws):
80+
self.step_continuous(step_fn, draws)
81+
82+
83+
class TestRVsAssignmentCompound:
84+
def test_compound_step(self):
85+
with pm.Model() as m:
86+
c1 = pm.HalfNormal("c1")
87+
c2 = pm.HalfNormal("c2")
88+
89+
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
90+
step1 = NUTS([c1])
91+
step2 = NUTS([c2])
92+
step = CompoundStep([step1, step2])
93+
assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(step.vars)

0 commit comments

Comments
 (0)