|
| 1 | +# Copyright 2024 - present The PyMC Developers |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +from typing import Any, NamedTuple |
| 18 | + |
| 19 | +import numpy as np |
| 20 | + |
| 21 | +from pymc.stats.convergence import SamplerWarning |
| 22 | +from pymc.step_methods.compound import Competence |
| 23 | +from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData |
| 24 | +from pymc.step_methods.hmc.integration import IntegrationError, State |
| 25 | +from pymc.vartypes import continuous_types |
| 26 | + |
| 27 | +__all__ = ["WALNUTS"] |
| 28 | + |
| 29 | + |
class WalnutsStepData(NamedTuple):
    """State during adaptive integration.

    NOTE(review): this tuple is not referenced anywhere else in this module;
    presumably intended for future use by the adaptive-integration loop —
    confirm before removing.
    """

    # Phase-space state reached by the integrator (position/momentum/energy).
    state: State
    # Number of integrator sub-steps taken to produce ``state`` (assumed — TODO confirm).
    n_steps: int
    # Hamiltonian (energy) error associated with this step (assumed — TODO confirm).
    energy_error: float
| 36 | + |
| 37 | + |
class WalnutsTree:
    """Binary tree for WALNUTS algorithm.

    Similar to NUTS tree but with adaptive step size within orbits: each
    extension first searches for a number of sub-steps (powers of two) that
    keeps the Hamiltonian error below ``max_error``, then integrates with
    that refined step size.

    Parameters
    ----------
    integrator
        Object exposing ``step(epsilon, state) -> State``; may raise
        ``IntegrationError``.
    start : State
        Initial phase-space state; both tree endpoints begin here.
    step_size : float
        Macro step size covered by each tree extension (split into sub-steps).
    Emax : float
        Hard divergence threshold on the absolute energy error vs. ``start``.
    max_error : float
        Per-extension stability threshold used when choosing sub-step counts.
    rng : np.random.Generator
        Random generator used to pick the proposal endpoint.
    """

    def __init__(
        self,
        integrator,
        start: State,
        step_size: float,
        Emax: float,
        max_error: float,
        rng: np.random.Generator,
    ):
        self.integrator = integrator
        self.start = start
        self.step_size = step_size
        self.Emax = Emax
        self.max_error = max_error
        self.rng = rng

        # Both trajectory endpoints start at the initial state.
        self.left = self.right = start
        self.depth = 0
        # Count of successful extensions (used to normalize accept stats).
        self.n_proposals = 0
        self.max_energy_error = 0.0
        self.sum_accept_stat = 0.0

        # WALNUTS-specific
        # Total integrator sub-steps taken across all extensions.
        self.n_steps_total = 0
        # Sub-steps taken by extensions that passed the stability search.
        # NOTE(review): currently always equal to n_steps_total.
        self.n_stable_steps = 0

    def _find_stable_steps(self, state: State, direction: int) -> tuple[bool, int]:
        """Find minimum number of steps for stable integration.

        Trial-integrates from ``state`` over one macro step of
        ``direction * step_size``, split into 1, 2, 4, or 8 sub-steps, and
        returns ``(True, n_steps)`` for the first split whose maximum energy
        error stays within ``max_error``. Returns ``(False, 1)`` if no split
        is stable.  The trial states are discarded; the caller re-integrates.
        """
        initial_energy = state.energy

        # Try powers of 2: 1, 2, 4, 8
        for n in range(4):  # Simplified range
            n_steps = 2**n
            test_state = state
            max_error = 0.0

            try:
                for _ in range(n_steps):
                    # Sub-steps shrink so the total distance stays step_size.
                    test_state = self.integrator.step(
                        direction * self.step_size / n_steps, test_state
                    )
                    energy_error = abs(test_state.energy - initial_energy)
                    max_error = max(max_error, energy_error)

                    # Give up on this split early once the threshold is hit.
                    if max_error > self.max_error:
                        break

                if max_error <= self.max_error:
                    return True, n_steps

            except IntegrationError:
                # Numerical failure at this resolution; try a finer split.
                continue

        return False, 1

    def _extend_adaptive(self, state: State, direction: int) -> tuple[State | None, bool, bool]:
        """Extend tree with adaptive step size.

        Returns ``(new_state, diverged, turning)``.  ``new_state`` is None
        when the extension diverged; the third element is always False here
        (the U-turn check happens in :meth:`extend`).
        """
        # Find stable number of steps
        is_stable, n_steps = self._find_stable_steps(state, direction)

        if not is_stable:
            return None, True, False  # diverged

        # Perform integration with adaptive steps
        current_state = state
        actual_step = direction * self.step_size / n_steps

        try:
            for _ in range(n_steps):
                current_state = self.integrator.step(actual_step, current_state)

            # Divergence is judged against the trajectory's *initial* energy,
            # unlike the stability search which compares to the local state.
            energy_error = abs(current_state.energy - self.start.energy)
            if energy_error > self.Emax:
                return None, True, False  # diverged

            self.max_energy_error = max(self.max_energy_error, energy_error)

            self.n_proposals += 1
            self.n_steps_total += n_steps
            self.n_stable_steps += n_steps

            # Acceptance statistic
            # Metropolis ratio capped at 1 (exp may overflow to inf for very
            # negative energy differences; min() still yields 1.0 then).
            accept_stat = min(1.0, np.exp(self.start.energy - current_state.energy))
            self.sum_accept_stat += accept_stat

            return current_state, False, False

        except IntegrationError:
            return None, True, False

    def extend(self, direction: int) -> tuple[DivergenceInfo | None, bool]:
        """Extend the tree in given direction.

        Returns ``(divergence_info, turning)``; ``divergence_info`` is None
        when the extension succeeded, and ``turning`` is True when the
        trajectory made a U-turn.
        """
        if direction > 0:
            new_state, diverged, turning = self._extend_adaptive(self.right, direction)
            if not diverged and new_state is not None:
                self.right = new_state
        else:
            new_state, diverged, turning = self._extend_adaptive(self.left, direction)
            if not diverged and new_state is not None:
                self.left = new_state

        self.depth += 1

        divergence_info = None
        if diverged:
            divergence_info = DivergenceInfo(
                "Energy error exceeded threshold in WALNUTS",
                None,
                # Last valid endpoint on the side that diverged.
                self.left if direction < 0 else self.right,
                None,
            )
            return divergence_info, False

        # Check for U-turn
        turning = False
        if new_state is not None:
            # NUTS-style criterion on the endpoints: stop when momentum at
            # either end points back along the left-to-right displacement.
            # NOTE(review): uses raw momentum p, not velocity — for a
            # non-identity mass matrix the standard criterion uses velocity;
            # confirm against State's fields.
            delta_q = self.right.q.data - self.left.q.data
            turning = np.dot(self.left.p, delta_q) <= 0 or np.dot(self.right.p, delta_q) <= 0

        return divergence_info, turning

    def get_proposal(self) -> State:
        """Return one of the two trajectory endpoints, chosen uniformly at random.

        NOTE(review): this is a simplification — full NUTS/WALNUTS samples
        (multinomially) from the whole orbit, not just its endpoints.
        """
        return self.right if self.rng.random() < 0.5 else self.left

    def stats(self) -> dict[str, Any]:
        """Get tree statistics.

        Energy-related entries are computed from the *right* endpoint only;
        eigenvalue entries are NaN placeholders (not computed here).
        """
        # Guard against division by zero when no extension succeeded.
        mean_accept = self.sum_accept_stat / max(1, self.n_proposals)
        return {
            "depth": self.depth,
            "mean_tree_accept": mean_accept,
            "energy_error": self.right.energy - self.start.energy,
            "energy": self.right.energy,
            "tree_size": self.n_proposals,
            "max_energy_error": self.max_energy_error,
            "model_logp": self.right.model_logp,
            "index_in_trajectory": self.right.index_in_trajectory,
            "n_steps_total": self.n_steps_total,
            "avg_steps_per_proposal": self.n_steps_total / max(1, self.n_proposals),
            "largest_eigval": np.nan,
            "smallest_eigval": np.nan,
        }
| 186 | + |
| 187 | + |
class WALNUTS(BaseHMC):
    """Within-orbit Adaptive Step-length No-U-Turn Sampler.

    A NUTS variant that refines the integration step size inside each
    trajectory (see :class:`WalnutsTree`), which helps numerical stability
    when the posterior's curvature varies along the orbit.

    Parameters
    ----------
    vars : list, optional
        Variables to sample. If None, all continuous variables in the model.
    max_error : float, default=1.0
        Maximum allowed Hamiltonian error for adaptive steps.
    max_treedepth : int, default=10
        Maximum depth of the binary tree.
    early_max_treedepth : int, default=8
        Maximum depth during tuning phase.
    **kwargs
        Additional arguments passed to BaseHMC.
    """

    name = "walnuts"

    default_blocked = True

    # Dtypes/shapes of per-draw sampler statistics; keys beyond those produced
    # by WalnutsTree.stats() are filled in by the BaseHMC machinery.
    stats_dtypes_shapes = {
        "depth": (np.int64, []),
        "step_size": (np.float64, []),
        "tune": (bool, []),
        "mean_tree_accept": (np.float64, []),
        "step_size_bar": (np.float64, []),
        "tree_size": (np.float64, []),
        "diverging": (bool, []),
        "energy_error": (np.float64, []),
        "energy": (np.float64, []),
        "max_energy_error": (np.float64, []),
        "model_logp": (np.float64, []),
        "process_time_diff": (np.float64, []),
        "perf_counter_diff": (np.float64, []),
        "perf_counter_start": (np.float64, []),
        "largest_eigval": (np.float64, []),
        "smallest_eigval": (np.float64, []),
        "index_in_trajectory": (np.int64, []),
        "reached_max_treedepth": (bool, []),
        "warning": (SamplerWarning, None),
        "n_steps_total": (np.int64, []),
        "avg_steps_per_proposal": (np.float64, []),
    }

    def __init__(
        self,
        vars=None,
        max_error=1.0,
        max_treedepth=10,
        early_max_treedepth=8,
        **kwargs,
    ):
        """Initialize WALNUTS sampler."""
        # Store WALNUTS-specific knobs before BaseHMC finishes setup.
        self.max_error = max_error
        self.max_treedepth = max_treedepth
        self.early_max_treedepth = early_max_treedepth

        super().__init__(vars, **kwargs)

    def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData:
        """Build one WALNUTS trajectory and return its proposal and stats."""
        # Use a shallower tree during early tuning to keep warm-up cheap.
        in_early_tuning = self.tune and self.iter_count < 200
        depth_limit = self.early_max_treedepth if in_early_tuning else self.max_treedepth

        orbit = WalnutsTree(self.integrator, start, step_size, self.Emax, self.max_error, self.rng)

        hit_max_depth = False
        divergence_info = None
        for _ in range(depth_limit):
            # Extend backward (-1) or forward (+1), chosen uniformly.
            step_direction = 1 if self.rng.random() < 0.5 else -1
            divergence_info, turning = orbit.extend(step_direction)

            if divergence_info or turning:
                break
        else:
            # Loop exhausted without divergence or U-turn; only report it
            # outside of tuning.
            hit_max_depth = not self.tune

        stats = orbit.stats()
        stats["reached_max_treedepth"] = hit_max_depth

        # Proposal is drawn from the finished orbit.
        proposal = orbit.get_proposal()

        return HMCStepData(proposal, stats["mean_tree_accept"], divergence_info, stats)

    @staticmethod
    def competence(var, has_grad):
        """Check if WALNUTS can sample this variable."""
        # Gradient-based sampling requires a differentiable continuous variable.
        if not has_grad:
            return Competence.INCOMPATIBLE
        if var.dtype not in continuous_types:
            return Competence.INCOMPATIBLE
        return Competence.COMPATIBLE
0 commit comments