|
| 1 | +# Copyright 2024 - present The PyMC Developers |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import logging |
| 16 | +import warnings |
| 17 | + |
| 18 | +from typing import Literal |
| 19 | + |
| 20 | +from arviz import InferenceData |
| 21 | + |
| 22 | +from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations |
| 23 | +from pymc.model import Model |
| 24 | +from pymc.step_methods.compound import Competence |
| 25 | +from pymc.step_methods.external.base import ExternalSampler |
| 26 | +from pymc.vartypes import continuous_types |
| 27 | + |
logger = logging.getLogger("pymc")

# Probe for nutpie at import time so the sampler class can fail fast with a
# helpful message instead of erroring deep inside sampling.
try:
    import nutpie
except ImportError:
    NUTPIE_AVAILABLE = False
else:
    # A bare/mocked module may import fine yet lack the real API; require the
    # entry point we actually call before declaring nutpie usable.
    NUTPIE_AVAILABLE = hasattr(nutpie, "compile_pymc_model")
| 37 | + |
| 38 | + |
class NutPie(ExternalSampler):
    """NutPie No-U-Turn Sampler.

    Interface to the nutpie sampler, a high-performance implementation of
    the No-U-Turn Sampler (NUTS) that compiles the PyMC model through either
    the numba or the jax backend and runs sampling outside of PyMC's native
    step-method machinery.

    Parameters
    ----------
    vars : list, optional
        Variables to be sampled. Accepted for interface compatibility;
        nutpie always samples the full model, so this is not forwarded.
    model : Model, optional
        PyMC model.
    backend : {"numba", "jax"}, default="numba"
        Which backend nutpie uses to compile the model's log-probability.
    target_accept : float, default=0.8
        Target acceptance rate for step size adaptation.
    max_treedepth : int, default=10
        Maximum tree depth for NUTS (forwarded as ``maxdepth`` to nutpie).
    **kwargs
        Additional parameters passed through to ``nutpie.sample()``.

    Notes
    -----
    Requires the nutpie package to be installed::

        pip install nutpie
    """

    name = "nutpie"

    def __init__(
        self,
        vars=None,
        *,
        model=None,
        backend: Literal["numba", "jax"] = "numba",
        target_accept: float = 0.8,
        max_treedepth: int = 10,
        **kwargs,
    ):
        """Initialize the NutPie sampler wrapper.

        Raises
        ------
        ImportError
            If the nutpie package is not installed (or is a stub module
            without the real API).
        """
        if not NUTPIE_AVAILABLE:
            raise ImportError("nutpie not found. Install it with: pip install nutpie")

        super().__init__(vars=vars, model=model)

        self.backend = backend
        self.target_accept = target_accept
        self.max_treedepth = max_treedepth
        # Extra keyword arguments forwarded verbatim to nutpie.sample();
        # call-time kwargs in sample() take precedence over these.
        self.nutpie_kwargs = kwargs

    def sample(
        self,
        draws: int,
        tune: int = 1000,
        chains: int = 4,
        random_seed=None,
        initvals=None,
        progressbar=True,
        cores=None,
        idata_kwargs=None,
        compute_convergence_checks=True,
        **kwargs,
    ) -> InferenceData:
        """Run the nutpie sampler and return results as InferenceData.

        Parameters
        ----------
        draws : int
            Number of draws per chain.
        tune : int
            Number of tuning draws per chain.
        chains : int
            Number of chains to sample.
        random_seed : int or sequence, optional
            Random seed(s) for reproducibility; overrides any ``seed``
            passed through ``**kwargs``.
        initvals : dict or list of dict, optional
            Initial values for variables. Currently not used by nutpie;
            a UserWarning is emitted if provided.
        progressbar : bool
            Whether to display a progress bar (forwarded to nutpie as
            ``progress_bar``).
        cores : int, optional
            Number of CPU cores to use. Ignored: nutpie manages its own
            parallelism.
        idata_kwargs : dict, optional
            Currently unused; nutpie already returns an
            :class:`arviz.InferenceData` so no conversion is performed.
        compute_convergence_checks : bool
            Currently unused; nutpie's InferenceData structure differs
            from what PyMC's convergence checks expect.
        **kwargs
            Additional sampler-specific parameters forwarded to
            ``nutpie.sample()``.

        Returns
        -------
        InferenceData
            ArviZ InferenceData object with sampling results. When
            possible, the number of tuning steps is recorded in
            ``idata.posterior.attrs["tuning_steps"]``.
        """
        model = kwargs.pop("model", self.model)
        if model is None:
            model = Model.get_context()

        # nutpie samples the full model; pop `vars` so it is not forwarded
        # to nutpie.sample() as an unknown keyword.
        kwargs.pop("vars", None)

        logger.info("Compiling NutPie model")
        compiled_model = nutpie.compile_pymc_model(
            model,
            backend=self.backend,
        )

        if initvals is not None:
            warnings.warn(
                "`initvals` are currently not passed to nutpie sampler. "
                "Use `init_mean` kwarg following nutpie specification instead.",
                UserWarning,
            )

        logger.info(
            f"Running NutPie sampler with {chains} chains, {tune} tuning steps, and {draws} draws"
        )

        # Defaults first so instance-level nutpie_kwargs and call-time kwargs
        # (in that order) may override target_accept / maxdepth.
        sample_kwargs = {
            "target_accept": self.target_accept,
            "maxdepth": self.max_treedepth,
            **self.nutpie_kwargs,
            **kwargs,
        }

        # An explicit random_seed takes precedence over any `seed` in kwargs.
        if random_seed is not None:
            sample_kwargs["seed"] = random_seed

        # Translate PyMC's `progressbar` spelling to nutpie's `progress_bar`;
        # the explicit function argument wins when it is not None.
        if "progressbar" in sample_kwargs:
            sample_kwargs["progress_bar"] = sample_kwargs.pop("progressbar")
        if progressbar is not None:
            sample_kwargs["progress_bar"] = progressbar

        # nutpie already returns an arviz.InferenceData object.
        idata = nutpie.sample(
            compiled_model,
            draws=draws,
            tune=tune,
            chains=chains,
            **sample_kwargs,
        )

        # Record the tuning length for downstream tooling, if the structure
        # allows it.
        try:
            idata.posterior.attrs["tuning_steps"] = tune
        except (AttributeError, KeyError):
            logger.warning("Could not set tuning_steps attribute on InferenceData")

        # compute_convergence_checks is intentionally skipped for now:
        # nutpie's InferenceData structure differs from PyMC's expectations.
        return idata

    @staticmethod
    def competence(var, has_grad):
        """Determine competence level for sampling var.

        Parameters
        ----------
        var : Variable
            Variable to be sampled.
        has_grad : bool
            Whether gradient information is available.

        Returns
        -------
        Competence
            ``Competence.IDEAL`` for continuous variables with gradients,
            ``Competence.INCOMPATIBLE`` otherwise.
        """
        if var.dtype in continuous_types and has_grad:
            return Competence.IDEAL
        return Competence.INCOMPATIBLE
0 commit comments