Skip to content

Commit 7b8112a

Browse files
committed
Sketch of WALNUTS sampler
1 parent 3ae5095 commit 7b8112a

File tree

3 files changed

+289
-1
lines changed

3 files changed

+289
-1
lines changed

pymc/step_methods/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Step methods."""
1616

1717
from pymc.step_methods.compound import BlockedStep, CompoundStep
18-
from pymc.step_methods.hmc import NUTS, HamiltonianMC
18+
from pymc.step_methods.hmc import NUTS, WALNUTS, HamiltonianMC
1919
from pymc.step_methods.metropolis import (
2020
BinaryGibbsMetropolis,
2121
BinaryMetropolis,
@@ -35,6 +35,7 @@
3535
# Other step methods can be added by appending to this list
3636
STEP_METHODS: list[type[BlockedStep]] = [
3737
NUTS,
38+
WALNUTS,
3839
HamiltonianMC,
3940
Metropolis,
4041
BinaryMetropolis,

pymc/step_methods/hmc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616

1717
from pymc.step_methods.hmc.hmc import HamiltonianMC
1818
from pymc.step_methods.hmc.nuts import NUTS
19+
from pymc.step_methods.hmc.walnuts import WALNUTS

pymc/step_methods/hmc/walnuts.py

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
# Copyright 2024 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Any, NamedTuple
18+
19+
import numpy as np
20+
21+
from pymc.stats.convergence import SamplerWarning
22+
from pymc.step_methods.compound import Competence
23+
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
24+
from pymc.step_methods.hmc.integration import IntegrationError, State
25+
from pymc.vartypes import continuous_types
26+
27+
__all__ = ["WALNUTS"]
28+
29+
30+
class WalnutsStepData(NamedTuple):
    """Result of one adaptively-integrated trajectory segment.

    Bundles the integrator state with the number of leapfrog substeps
    taken and the energy error observed, for use by the WALNUTS
    tree-building routines.
    """

    # Phase-space point produced by the integrator.
    state: State
    # Number of leapfrog substeps used for this segment.
    n_steps: int
    # Observed Hamiltonian (energy) error for the segment.
    # NOTE(review): not constructed anywhere in this file yet — sketch-stage API.
    energy_error: float
36+
37+
38+
class WalnutsTree:
    """Binary tree for WALNUTS algorithm.

    Similar to NUTS tree but with adaptive step size within orbits.

    Unlike the NUTS tree, this tree is grown one segment at a time (see
    ``extend``), so ``depth`` counts segments rather than doublings.
    """

    def __init__(
        self,
        integrator,
        start: State,
        step_size: float,
        Emax: float,
        max_error: float,
        rng: np.random.Generator,
    ):
        # Leapfrog integrator and the initial phase-space point.
        self.integrator = integrator
        self.start = start
        # Macro step size; each segment covers `step_size` split into substeps.
        self.step_size = step_size
        # Hard divergence threshold on |energy(state) - energy(start)|.
        self.Emax = Emax
        # Per-segment stability threshold used when choosing the substep count.
        self.max_error = max_error
        self.rng = rng

        # Both trajectory endpoints start at the initial point.
        self.left = self.right = start
        self.depth = 0
        self.n_proposals = 0
        self.max_energy_error = 0.0
        self.sum_accept_stat = 0.0

        # WALNUTS-specific
        self.n_steps_total = 0
        self.n_stable_steps = 0

    def _find_stable_steps(self, state: State, direction: int) -> tuple[bool, int]:
        """Find minimum number of steps for stable integration.

        Tries substep counts 1, 2, 4, 8 (each covering the same macro step
        ``direction * step_size``) and returns ``(True, n_steps)`` for the
        first count whose per-substep energy error stays within
        ``self.max_error`` relative to the local starting energy.
        Returns ``(False, 1)`` if none is stable.
        """
        initial_energy = state.energy

        # Try powers of 2: 1, 2, 4, 8
        for n in range(4):  # Simplified range
            n_steps = 2**n
            test_state = state
            max_error = 0.0

            try:
                for _ in range(n_steps):
                    test_state = self.integrator.step(
                        direction * self.step_size / n_steps, test_state
                    )
                    energy_error = abs(test_state.energy - initial_energy)
                    max_error = max(max_error, energy_error)

                    # Abandon this substep count as soon as it proves unstable.
                    if max_error > self.max_error:
                        break

                if max_error <= self.max_error:
                    # NOTE(review): the trial trajectory is discarded here and
                    # re-integrated in _extend_adaptive — duplicated work that a
                    # later revision could avoid by returning test_state.
                    return True, n_steps

            except IntegrationError:
                # Integration blew up at this resolution; try more substeps.
                continue

        return False, 1

    def _extend_adaptive(self, state: State, direction: int) -> tuple[State | None, bool, bool]:
        """Extend tree with adaptive step size.

        Returns ``(new_state, diverged, turning)``. ``turning`` is always
        False here (a placeholder); the U-turn check lives in ``extend``.
        """
        # Find stable number of steps
        is_stable, n_steps = self._find_stable_steps(state, direction)

        if not is_stable:
            return None, True, False  # diverged

        # Perform integration with adaptive steps
        current_state = state
        actual_step = direction * self.step_size / n_steps

        try:
            for _ in range(n_steps):
                current_state = self.integrator.step(actual_step, current_state)

            # Divergence check is against the *tree root* energy (self.start),
            # unlike the local check in _find_stable_steps.
            energy_error = abs(current_state.energy - self.start.energy)
            if energy_error > self.Emax:
                return None, True, False  # diverged

            self.max_energy_error = max(self.max_energy_error, energy_error)

            self.n_proposals += 1
            self.n_steps_total += n_steps
            self.n_stable_steps += n_steps

            # Acceptance statistic
            # Metropolis-style ratio exp(H_start - H_current), capped at 1.
            accept_stat = min(1.0, np.exp(self.start.energy - current_state.energy))
            self.sum_accept_stat += accept_stat

            return current_state, False, False

        except IntegrationError:
            return None, True, False

    def extend(self, direction: int) -> tuple[DivergenceInfo | None, bool]:
        """Extend the tree in given direction.

        Grows the trajectory by one adaptive segment off the right endpoint
        (``direction > 0``) or the left endpoint (``direction < 0``) and
        returns ``(divergence_info, turning)``.
        """
        if direction > 0:
            new_state, diverged, turning = self._extend_adaptive(self.right, direction)
            if not diverged and new_state is not None:
                self.right = new_state
        else:
            new_state, diverged, turning = self._extend_adaptive(self.left, direction)
            if not diverged and new_state is not None:
                self.left = new_state

        self.depth += 1

        divergence_info = None
        if diverged:
            divergence_info = DivergenceInfo(
                "Energy error exceeded threshold in WALNUTS",
                None,
                self.left if direction < 0 else self.right,
                None,
            )
            return divergence_info, False

        # Check for U-turn
        # The placeholder `turning` from _extend_adaptive is discarded and
        # recomputed from the generalized NUTS criterion on both endpoints.
        turning = False
        if new_state is not None:
            delta_q = self.right.q.data - self.left.q.data
            # NOTE(review): q is accessed via .data but p is used directly —
            # confirm State.p is a plain ndarray (elsewhere in pymc both are
            # RaveledVars); otherwise this should be self.left.p.data etc.
            turning = np.dot(self.left.p, delta_q) <= 0 or np.dot(self.right.p, delta_q) <= 0

        return divergence_info, turning

    def get_proposal(self) -> State:
        """Return one of the two trajectory endpoints, chosen uniformly."""
        return self.right if self.rng.random() < 0.5 else self.left

    def stats(self) -> dict[str, Any]:
        """Get tree statistics.

        Keys mirror the sampler-stat names declared in
        ``WALNUTS.stats_dtypes_shapes``; eigenvalue entries are placeholders.
        """
        # Guard against division by zero when no segment was ever accepted.
        mean_accept = self.sum_accept_stat / max(1, self.n_proposals)
        return {
            "depth": self.depth,
            "mean_tree_accept": mean_accept,
            # Energy drift of the right endpoint relative to the root.
            "energy_error": self.right.energy - self.start.energy,
            "energy": self.right.energy,
            "tree_size": self.n_proposals,
            "max_energy_error": self.max_energy_error,
            "model_logp": self.right.model_logp,
            "index_in_trajectory": self.right.index_in_trajectory,
            "n_steps_total": self.n_steps_total,
            "avg_steps_per_proposal": self.n_steps_total / max(1, self.n_proposals),
            # Not computed by this sampler; kept for stat-schema compatibility.
            "largest_eigval": np.nan,
            "smallest_eigval": np.nan,
        }
186+
187+
188+
class WALNUTS(BaseHMC):
    """Within-orbit Adaptive Step-length No-U-Turn Sampler.

    WALNUTS extends NUTS by adapting the integration step size within
    each trajectory. This can improve numerical stability in models
    with varying curvature.

    Parameters
    ----------
    vars : list, optional
        Variables to sample. If None, all continuous variables in the model.
    max_error : float, default=1.0
        Maximum allowed Hamiltonian error for adaptive steps.
    max_treedepth : int, default=10
        Maximum depth of the binary tree.
    early_max_treedepth : int, default=8
        Maximum depth during tuning phase.
    **kwargs
        Additional arguments passed to BaseHMC.
    """

    name = "walnuts"

    default_blocked = True

    # Declares dtype and shape of every per-draw sampler statistic this step
    # method emits; keys must match what _hamiltonian_step / WalnutsTree.stats
    # produce (plus the timing/tuning stats filled in by the base machinery).
    stats_dtypes_shapes = {
        "depth": (np.int64, []),
        "step_size": (np.float64, []),
        "tune": (bool, []),
        "mean_tree_accept": (np.float64, []),
        "step_size_bar": (np.float64, []),
        "tree_size": (np.float64, []),
        "diverging": (bool, []),
        "energy_error": (np.float64, []),
        "energy": (np.float64, []),
        "max_energy_error": (np.float64, []),
        "model_logp": (np.float64, []),
        "process_time_diff": (np.float64, []),
        "perf_counter_diff": (np.float64, []),
        "perf_counter_start": (np.float64, []),
        "largest_eigval": (np.float64, []),
        "smallest_eigval": (np.float64, []),
        "index_in_trajectory": (np.int64, []),
        "reached_max_treedepth": (bool, []),
        "warning": (SamplerWarning, None),
        # WALNUTS-specific stats.
        "n_steps_total": (np.int64, []),
        "avg_steps_per_proposal": (np.float64, []),
    }

    def __init__(
        self,
        vars=None,
        max_error=1.0,
        max_treedepth=10,
        early_max_treedepth=8,
        **kwargs,
    ):
        """Initialize WALNUTS sampler."""
        # Set local attributes before BaseHMC.__init__, which may run model
        # setup that the step method relies on.
        self.max_error = max_error
        self.max_treedepth = max_treedepth
        self.early_max_treedepth = early_max_treedepth

        super().__init__(vars, **kwargs)

    def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData:
        """Perform a single WALNUTS iteration.

        Builds a WalnutsTree from ``start`` one segment at a time until
        divergence, a U-turn, or the depth cap is hit, then returns a
        proposal endpoint with its acceptance statistic and stats.
        """
        # Use a shallower tree early in tuning, while step size is still poor.
        if self.tune and self.iter_count < 200:
            max_treedepth = self.early_max_treedepth
        else:
            max_treedepth = self.max_treedepth

        tree = WalnutsTree(self.integrator, start, step_size, self.Emax, self.max_error, self.rng)

        reached_max_treedepth = False
        divergence_info = None
        for _ in range(max_treedepth):
            # Uniformly random direction: maps {False, True} to {-1, +1}.
            direction = (self.rng.random() < 0.5) * 2 - 1
            divergence_info, turning = tree.extend(direction)

            if divergence_info or turning:
                break
        else:  # no-break
            # Only report hitting the cap outside tuning, where it matters.
            reached_max_treedepth = not self.tune

        stats = tree.stats()
        stats["reached_max_treedepth"] = reached_max_treedepth

        # Get proposal from tree
        proposal = tree.get_proposal()
        mean_accept = stats["mean_tree_accept"]

        return HMCStepData(proposal, mean_accept, divergence_info, stats)

    @staticmethod
    def competence(var, has_grad):
        """Check if WALNUTS can sample this variable."""
        # Gradient-based sampler: requires a continuous variable with grads.
        if var.dtype in continuous_types and has_grad:
            return Competence.COMPATIBLE
        return Competence.INCOMPATIBLE

0 commit comments

Comments
 (0)