Skip to content

Commit de61d81

Browse files
ArmavicaricardoV94
authored andcommitted
Merge old test_hmc into new one
1 parent 4e45516 commit de61d81

File tree

3 files changed

+56
-73
lines changed

3 files changed

+56
-73
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
pymc/tests/test_aesaraf.py
4444
pymc/tests/test_math.py
4545
pymc/tests/backends/test_ndarray.py
46-
pymc/tests/test_hmc.py
46+
pymc/tests/step_methods/hmc/test_hmc.py
4747
pymc/tests/test_func_utils.py
4848
pymc/tests/distributions/test_shape_utils.py
4949
pymc/tests/distributions/test_mixture.py

pymc/tests/step_methods/hmc/test_hmc.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,19 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
16+
17+
import numpy as np
18+
import numpy.testing as npt
1519
import pytest
1620

21+
import pymc as pm
22+
23+
from pymc.aesaraf import floatX
24+
from pymc.blocking import DictToArrayBijection, RaveledVars
1725
from pymc.step_methods.hmc import HamiltonianMC
26+
from pymc.step_methods.hmc.base_hmc import BaseHMC
27+
from pymc.tests import models
1828
from pymc.tests.helpers import RVsAssignmentStepsTester, StepMethodTester
1929

2030

@@ -34,3 +44,48 @@ class TestRVsAssignmentHamiltonianMC(RVsAssignmentStepsTester):
3444
@pytest.mark.parametrize("step, step_kwargs", [(HamiltonianMC, {})])
3545
def test_continuous_steps(self, step, step_kwargs):
3646
self.continuous_steps(step, step_kwargs)
47+
48+
49+
def test_leapfrog_reversible():
50+
n = 3
51+
np.random.seed(42)
52+
start, model, _ = models.non_normal(n)
53+
size = sum(start[n.name].size for n in model.value_vars)
54+
scaling = floatX(np.random.rand(size))
55+
56+
class HMC(BaseHMC):
57+
def _hamiltonian_step(self, *args, **kwargs):
58+
pass
59+
60+
step = HMC(vars=model.value_vars, model=model, scaling=scaling)
61+
62+
step.integrator._logp_dlogp_func.set_extra_values({})
63+
astart = DictToArrayBijection.map(start)
64+
p = RaveledVars(floatX(step.potential.random()), astart.point_map_info)
65+
q = RaveledVars(floatX(np.random.randn(size)), astart.point_map_info)
66+
start = step.integrator.compute_state(p, q)
67+
for epsilon in [0.01, 0.1]:
68+
for n_steps in [1, 2, 3, 4, 20]:
69+
state = start
70+
for _ in range(n_steps):
71+
state = step.integrator.step(epsilon, state)
72+
for _ in range(n_steps):
73+
state = step.integrator.step(-epsilon, state)
74+
npt.assert_allclose(state.q.data, start.q.data, rtol=1e-5)
75+
npt.assert_allclose(state.p.data, start.p.data, rtol=1e-5)
76+
77+
78+
def test_nuts_tuning():
79+
with pm.Model():
80+
pm.Normal("mu", mu=0, sigma=1)
81+
step = pm.NUTS()
82+
with warnings.catch_warnings():
83+
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
84+
idata = pm.sample(
85+
10, step=step, tune=5, discard_tuned_samples=False, progressbar=False, chains=1
86+
)
87+
88+
assert not step.tune
89+
ss_tuned = idata.warmup_sample_stats["step_size"][0, -1]
90+
ss_posterior = idata.sample_stats["step_size"][0, :]
91+
np.testing.assert_array_equal(ss_posterior, ss_tuned)

pymc/tests/test_hmc.py

Lines changed: 0 additions & 72 deletions
This file was deleted.

0 commit comments

Comments
 (0)