12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import warnings
16
+
17
+ import numpy as np
18
+ import numpy .testing as npt
15
19
import pytest
16
20
21
+ import pymc as pm
22
+
23
+ from pymc .aesaraf import floatX
24
+ from pymc .blocking import DictToArrayBijection , RaveledVars
17
25
from pymc .step_methods .hmc import HamiltonianMC
26
+ from pymc .step_methods .hmc .base_hmc import BaseHMC
27
+ from pymc .tests import models
18
28
from pymc .tests .helpers import RVsAssignmentStepsTester , StepMethodTester
19
29
20
30
@@ -34,3 +44,48 @@ class TestRVsAssignmentHamiltonianMC(RVsAssignmentStepsTester):
34
44
@pytest .mark .parametrize ("step, step_kwargs" , [(HamiltonianMC , {})])
35
45
def test_continuous_steps (self , step , step_kwargs ):
36
46
self .continuous_steps (step , step_kwargs )
47
+
48
+
49
+ def test_leapfrog_reversible ():
50
+ n = 3
51
+ np .random .seed (42 )
52
+ start , model , _ = models .non_normal (n )
53
+ size = sum (start [n .name ].size for n in model .value_vars )
54
+ scaling = floatX (np .random .rand (size ))
55
+
56
+ class HMC (BaseHMC ):
57
+ def _hamiltonian_step (self , * args , ** kwargs ):
58
+ pass
59
+
60
+ step = HMC (vars = model .value_vars , model = model , scaling = scaling )
61
+
62
+ step .integrator ._logp_dlogp_func .set_extra_values ({})
63
+ astart = DictToArrayBijection .map (start )
64
+ p = RaveledVars (floatX (step .potential .random ()), astart .point_map_info )
65
+ q = RaveledVars (floatX (np .random .randn (size )), astart .point_map_info )
66
+ start = step .integrator .compute_state (p , q )
67
+ for epsilon in [0.01 , 0.1 ]:
68
+ for n_steps in [1 , 2 , 3 , 4 , 20 ]:
69
+ state = start
70
+ for _ in range (n_steps ):
71
+ state = step .integrator .step (epsilon , state )
72
+ for _ in range (n_steps ):
73
+ state = step .integrator .step (- epsilon , state )
74
+ npt .assert_allclose (state .q .data , start .q .data , rtol = 1e-5 )
75
+ npt .assert_allclose (state .p .data , start .p .data , rtol = 1e-5 )
76
+
77
+
78
+ def test_nuts_tuning ():
79
+ with pm .Model ():
80
+ pm .Normal ("mu" , mu = 0 , sigma = 1 )
81
+ step = pm .NUTS ()
82
+ with warnings .catch_warnings ():
83
+ warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
84
+ idata = pm .sample (
85
+ 10 , step = step , tune = 5 , discard_tuned_samples = False , progressbar = False , chains = 1
86
+ )
87
+
88
+ assert not step .tune
89
+ ss_tuned = idata .warmup_sample_stats ["step_size" ][0 , - 1 ]
90
+ ss_posterior = idata .sample_stats ["step_size" ][0 , :]
91
+ np .testing .assert_array_equal (ss_posterior , ss_tuned )
0 commit comments