Skip to content

Commit 69de1e6

Browse files
committed
WALNUTS unit tests
1 parent 7b8112a commit 69de1e6

File tree

1 file changed

+186
-0
lines changed

1 file changed

+186
-0
lines changed
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright 2024 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import warnings
16+
17+
import numpy as np
18+
import numpy.testing as npt
19+
import pytest
20+
21+
import pymc as pm
22+
23+
from pymc.exceptions import SamplingError
24+
from pymc.step_methods.hmc import WALNUTS
25+
from tests import sampler_fixtures as sf
26+
from tests.helpers import RVsAssignmentStepsTester, StepMethodTester
27+
28+
29+
class WalnutsFixture(sf.BaseSampler):
    """Shared sampler fixture that drives the base distribution tests with WALNUTS."""

    @classmethod
    def make_step(cls):
        """Build a WALNUTS step, reusing a tuned NUTS mass matrix unless scaling is given."""
        kwargs = dict(getattr(cls, "step_args", {}))
        if "scaling" in kwargs:
            return pm.WALNUTS(**kwargs)
        # No explicit scaling: tune a NUTS step first, then swap in WALNUTS
        # while keeping the adapted potential (mass matrix).
        _, nuts_step = pm.sampling.mcmc.init_nuts(n_init=10000, **kwargs)
        return pm.WALNUTS(
            potential=nuts_step.potential,
            target_accept=nuts_step.target_accept,
            **kwargs,
        )

    def test_target_accept(self):
        """The post-burn-in mean acceptance statistic should be near the target."""
        accept_stats = self.trace[self.burn :]["mean_tree_accept"]
        npt.assert_allclose(accept_stats.mean(), self.step.target_accept, 1)
46+
47+
48+
# Basic distribution tests - these are relevant for WALNUTS since it's a general HMC sampler
class TestWALNUTSUniform(WalnutsFixture, sf.UniformFixture):
    """Sample a uniform target with WALNUTS and check summary statistics."""

    # Sampling configuration (sample count reduced for faster testing).
    n_samples = 5000
    tune = 500
    burn = 500
    chains = 2
    # Convergence / accuracy thresholds used by the base fixture.
    min_n_eff = 2000
    rtol = 0.1
    atol = 0.05
    # Fixed seed so the run is reproducible.
    step_args = {"random_seed": 202010}
58+
59+
60+
class TestWALNUTSNormal(WalnutsFixture, sf.NormalFixture):
    """Sample a normal target with WALNUTS and check summary statistics."""

    # Sampling configuration (sample count reduced for faster testing).
    n_samples = 5000
    tune = 500
    burn = 0
    chains = 2
    # Convergence / accuracy thresholds used by the base fixture.
    min_n_eff = 4000
    rtol = 0.1
    atol = 0.05
    # Fixed seed so the run is reproducible.
    step_args = {"random_seed": 123456}
69+
70+
71+
# WALNUTS-specific functionality tests
class TestWalnutsSpecific:
    """Tests for behavior unique to the WALNUTS step method (stats, parameters, errors)."""

    def test_walnuts_specific_stats(self):
        """Test that WALNUTS produces its specific statistics."""
        with pm.Model():
            pm.Normal("x", mu=0, sigma=1)
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
                trace = pm.sample(
                    draws=10, tune=5, chains=1, return_inferencedata=False, step=pm.WALNUTS()
                )

        # Check WALNUTS-specific stats are present
        walnuts_stats = ["n_steps_total", "avg_steps_per_proposal"]
        for stat in walnuts_stats:
            assert stat in trace.stat_names, f"WALNUTS-specific stat '{stat}' missing"
            stats_values = trace.get_sampler_stats(stat)
            assert stats_values.shape == (10,), f"Wrong shape for {stat}"
            assert np.all(stats_values >= 0), f"{stat} should be non-negative"

        # Check that n_steps_total makes sense relative to tree_size
        n_steps = trace.get_sampler_stats("n_steps_total")
        tree_size = trace.get_sampler_stats("tree_size")
        # n_steps_total should generally be >= tree_size (adaptive steps might use more steps)
        assert np.all(n_steps >= tree_size), "n_steps_total should be >= tree_size"

    def test_walnuts_parameters(self):
        """Test WALNUTS-specific parameters."""
        with pm.Model():
            pm.Normal("x", mu=0, sigma=1)

            # Test custom max_error parameter
            step = pm.WALNUTS(max_error=0.5, max_treedepth=8)
            assert step.max_error == 0.5
            assert step.max_treedepth == 8

            # Test early_max_treedepth
            assert hasattr(step, "early_max_treedepth")

    def test_bad_init_handling(self):
        """Test that WALNUTS handles bad initialization properly."""
        with pm.Model():
            pm.HalfNormal("a", sigma=1, initval=-1, default_transform=None)
            # Idiomatic form: pass match= to pytest.raises instead of calling
            # error.match(...) on the ExceptionInfo after the with-block.
            with pytest.raises(SamplingError, match="Bad initial energy"):
                pm.sample(chains=1, random_seed=1, step=pm.WALNUTS())

    def test_competence_method(self):
        """Test WALNUTS competence for different variable types."""
        from pymc.step_methods.compound import Competence

        # Mock continuous variable with gradient
        class MockVar:
            dtype = "float64"  # continuous_types contains strings, not dtype objects

        var = MockVar()
        assert WALNUTS.competence(var, has_grad=True) == Competence.COMPATIBLE
        assert WALNUTS.competence(var, has_grad=False) == Competence.INCOMPATIBLE

    def test_required_attributes(self):
        """Test that WALNUTS has all required attributes."""
        with pm.Model():
            pm.Normal("x", mu=0, sigma=1)
            step = pm.WALNUTS()

        # Check required attributes
        assert hasattr(step, "name")
        assert step.name == "walnuts"
        assert hasattr(step, "default_blocked")
        assert step.default_blocked is True
        assert hasattr(step, "stats_dtypes_shapes")

        # Check WALNUTS-specific stats are defined
        required_stats = ["n_steps_total", "avg_steps_per_proposal"]
        for stat in required_stats:
            assert stat in step.stats_dtypes_shapes
147+
148+
149+
# Test step method functionality
class TestStepWALNUTS(StepMethodTester):
    """Run the generic continuous-step checks with WALNUTS step builders."""

    @pytest.mark.parametrize(
        "step_fn, draws",
        [
            # Unblocked and blocked variants, both with an explicit covariance scaling.
            (lambda cov, _: WALNUTS(scaling=cov, is_cov=True, blocked=False), 1000),
            (lambda cov, _: WALNUTS(scaling=cov, is_cov=True), 1000),
        ],
    )
    def test_step_continuous(self, step_fn, draws):
        """Delegate to the shared continuous-step test harness."""
        self.step_continuous(step_fn, draws)
160+
161+
162+
class TestRVsAssignmentWALNUTS(RVsAssignmentStepsTester):
    """Check that WALNUTS is assigned to continuous random variables."""

    @pytest.mark.parametrize("step, step_kwargs", [(WALNUTS, {})])
    def test_continuous_steps(self, step, step_kwargs):
        """Delegate to the shared RV-assignment test harness."""
        self.continuous_steps(step, step_kwargs)
166+
167+
168+
def test_walnuts_step_legacy_value_grad_function():
    """Test WALNUTS with legacy value grad function (compatibility test)."""
    with pm.Model() as model:
        x = pm.Normal("x", shape=(2,))
        y = pm.Normal("y", x, shape=(3, 2))

    # Build the legacy (non-raveled) logp/dlogp function and hand it to WALNUTS.
    value_grad_fn = model.logp_dlogp_function(ravel_inputs=False, mode="FAST_COMPILE")
    value_grad_fn.set_extra_values({})
    step = WALNUTS(model=model, logp_dlogp_func=value_grad_fn)

    # Confirm it is a function of multiple variables
    zero_point = [np.zeros((2,)), np.zeros((3, 2))]
    logp, dlogp = step._logp_dlogp_func(zero_point)
    np.testing.assert_allclose(dlogp, np.zeros(8))

    # Confirm we can perform a WALNUTS step
    start = model.initial_point()
    draw, _ = step.step(start)
    assert np.all(draw["x"] != start["x"])
    assert np.all(draw["y"] != start["y"])

0 commit comments

Comments
 (0)