Skip to content

Commit 3391a97

Browse files
twiecki and claude committed
fix: correctly map parameter names in NutPie external sampler
- Change max_tree_depth to maxdepth to match NutPie's API
- Fix coords_and_dims_for_inferencedata function call
- Handle progressbar parameter correctly
- Skip conversion to InferenceData since NutPie already returns one
- Simplify convergence checks to handle NutPie's different structure

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent ad583ae commit 3391a97

File tree

4 files changed

+498
-0
lines changed

4 files changed

+498
-0
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2024 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""External samplers integration for PyMC."""
16+
17+
from pymc.step_methods.external.base import ExternalSampler
18+
from pymc.step_methods.external.nutpie import NutPie
19+
20+
__all__ = ["ExternalSampler", "NutPie"]

pymc/step_methods/external/base.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2024 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from abc import ABC, abstractmethod
16+
17+
from arviz import InferenceData
18+
19+
from pymc.step_methods.compound import BlockedStep, Competence
20+
21+
22+
class ExternalSampler(BlockedStep, ABC):
    """Common interface for samplers that drive their own MCMC loop.

    Subclasses implement :meth:`sample`, which performs a complete run and
    hands back the result, instead of advancing one point at a time through
    :meth:`step`.

    Attributes
    ----------
    is_external : bool
        Class-level marker, ``True`` on this base class, identifying the
        step method as an external sampler.
    """

    is_external = True

    def __init__(self, vars=None, model=None, **kwargs):
        """Store the constructor arguments on the instance.

        Parameters
        ----------
        vars : list, optional
            Variables to be sampled; kept as ``self._vars``.
        model : Model, optional
            PyMC model; kept as ``self.model``.
        **kwargs
            Sampler-specific arguments; kept as ``self._kwargs``.
        """
        self._kwargs = kwargs
        self._vars = vars
        self.model = model

    @abstractmethod
    def sample(
        self,
        draws: int,
        tune: int = 1000,
        chains: int = 4,
        random_seed=None,
        initvals=None,
        progressbar=True,
        cores=None,
        **kwargs,
    ) -> InferenceData:
        """Run the external sampler and return its results.

        Concrete subclasses must override this method.

        Parameters
        ----------
        draws : int
            Number of draws per chain.
        tune : int
            Number of tuning draws per chain.
        chains : int
            Number of chains to sample.
        random_seed : int or sequence, optional
            Random seed(s) for reproducibility.
        initvals : dict or list of dict, optional
            Initial values for variables.
        progressbar : bool
            Whether to display a progress bar.
        cores : int, optional
            Number of CPU cores to use.
        **kwargs
            Additional sampler-specific parameters.

        Returns
        -------
        InferenceData
            ArviZ InferenceData object with the sampling results.
        """

    def step(self, point):
        """Always raise: external samplers are not stepped point by point.

        Raises
        ------
        NotImplementedError
            On every call, regardless of ``point``.
        """
        message = "External samplers use their own sampling loop rather than PyMC's step() method."
        raise NotImplementedError(message)

    @staticmethod
    def competence(var, has_grad):
        """Report how suitable this sampler is for ``var``.

        Parameters
        ----------
        var : Variable
            Variable to be sampled (not inspected by this base implementation).
        has_grad : bool
            Whether gradient information is available (not inspected here).

        Returns
        -------
        Competence
            Always ``Competence.COMPATIBLE`` for the base class.
        """
        return Competence.COMPATIBLE
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
# Copyright 2024 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
import warnings
17+
18+
from typing import Literal
19+
20+
from arviz import InferenceData
21+
22+
from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations
23+
from pymc.model import Model
24+
from pymc.step_methods.compound import Competence
25+
from pymc.step_methods.external.base import ExternalSampler
26+
from pymc.vartypes import continuous_types
27+
28+
logger = logging.getLogger("pymc")

# nutpie is an optional dependency. Only the import itself is guarded by the
# ``try`` so that an unrelated error raised while probing the module is not
# mistaken for "not installed".
try:
    import nutpie
except ImportError:
    NUTPIE_AVAILABLE = False
else:
    # Check if it's actually installed and not just an empty mock module
    NUTPIE_AVAILABLE = hasattr(nutpie, "compile_pymc_model")
37+
38+
39+
class NutPie(ExternalSampler):
    """NutPie No-U-Turn Sampler.

    Thin adapter that compiles a PyMC model with
    ``nutpie.compile_pymc_model`` and runs ``nutpie.sample`` on it.

    Parameters
    ----------
    vars : list, optional
        Variables to be sampled
    model : Model, optional
        PyMC model
    backend : {"numba", "jax"}, default="numba"
        Which backend to use for computation (forwarded to
        ``nutpie.compile_pymc_model``)
    target_accept : float, default=0.8
        Target acceptance rate for step size adaptation
    max_treedepth : int, default=10
        Maximum tree depth for NUTS (passed as 'maxdepth' to NutPie)
    **kwargs
        Additional parameters passed to nutpie.sample()

    Notes
    -----
    Requires the nutpie package to be installed:
        pip install nutpie
    """

    name = "nutpie"

    def __init__(
        self,
        vars=None,
        *,
        model=None,
        backend: Literal["numba", "jax"] = "numba",
        target_accept: float = 0.8,
        max_treedepth: int = 10,
        **kwargs,
    ):
        """Initialize NutPie sampler.

        Raises
        ------
        ImportError
            If the nutpie package (with ``compile_pymc_model``) is not available.
        """
        if not NUTPIE_AVAILABLE:
            raise ImportError("nutpie not found. Install it with: pip install nutpie")

        super().__init__(vars=vars, model=model)

        self.backend = backend
        self.target_accept = target_accept
        self.max_treedepth = max_treedepth
        # Extra keyword arguments are kept separately and merged into every
        # nutpie.sample() call made by this instance.
        self.nutpie_kwargs = kwargs

    def _build_nutpie_kwargs(self, random_seed, progressbar, extra_kwargs) -> dict:
        """Assemble the keyword arguments forwarded to ``nutpie.sample``.

        Precedence, lowest to highest: the instance's ``target_accept`` /
        ``max_treedepth``, the constructor ``**kwargs``, the per-call
        ``extra_kwargs``, and finally the explicit ``random_seed`` and
        ``progressbar`` arguments (when they are not ``None``).
        """
        settings = {
            "target_accept": self.target_accept,
            "maxdepth": self.max_treedepth,
            **self.nutpie_kwargs,
            **extra_kwargs,
        }

        if random_seed is not None:
            # Forwarded as-is under nutpie's ``seed`` keyword.
            settings["seed"] = random_seed

        # Translate the PyMC spelling to nutpie's ``progress_bar`` keyword.
        if "progressbar" in settings:
            settings["progress_bar"] = settings.pop("progressbar")

        # The explicit ``progressbar`` argument of sample() wins over any
        # value supplied through keyword arguments.
        if progressbar is not None:
            settings["progress_bar"] = progressbar

        return settings

    def sample(
        self,
        draws: int,
        tune: int = 1000,
        chains: int = 4,
        random_seed=None,
        initvals=None,
        progressbar=True,
        cores=None,
        idata_kwargs=None,
        compute_convergence_checks=True,
        **kwargs,
    ) -> InferenceData:
        """Run NutPie sampler and return its result.

        Parameters
        ----------
        draws : int
            Number of draws per chain
        tune : int
            Number of tuning draws per chain
        chains : int
            Number of chains to sample
        random_seed : int or sequence, optional
            Random seed; when not None it is forwarded unchanged as
            nutpie's ``seed`` argument
        initvals : dict or list of dict, optional
            Initial values for variables. Not passed to nutpie; a
            ``UserWarning`` is emitted when a value is supplied
        progressbar : bool
            Whether to display progress bar (forwarded as ``progress_bar``)
        cores : int, optional
            Accepted for interface compatibility; not forwarded to nutpie
            by this method
        idata_kwargs : dict, optional
            Accepted for interface compatibility; currently not applied,
            because the object returned by ``nutpie.sample`` is returned
            directly without a separate conversion step
        compute_convergence_checks : bool
            Accepted for interface compatibility; convergence checks are
            currently not run by this method
        **kwargs
            Additional parameters forwarded to ``nutpie.sample``. The
            special keys ``model`` and ``vars`` are consumed here and not
            forwarded

        Returns
        -------
        InferenceData
            The object returned by ``nutpie.sample``, with
            ``posterior.attrs["tuning_steps"]`` set to ``tune`` when possible
        """
        model = kwargs.pop("model", self.model)
        if model is None:
            model = Model.get_context()

        # ``vars`` is not needed for compilation (the whole model is
        # compiled); remove it so it is not forwarded to nutpie.sample().
        kwargs.pop("vars", None)

        logger.info("Compiling NutPie model")
        nutpie_model = nutpie.compile_pymc_model(
            model,
            backend=self.backend,
        )

        if initvals is not None:
            warnings.warn(
                "`initvals` are currently not passed to nutpie sampler. "
                "Use `init_mean` kwarg following nutpie specification instead.",
                UserWarning,
                stacklevel=2,
            )

        sampler_kwargs = self._build_nutpie_kwargs(random_seed, progressbar, kwargs)

        logger.info(
            "Running NutPie sampler with %s chains, %s tuning steps, and %s draws",
            chains,
            tune,
            draws,
        )

        # NutPie already returns an InferenceData object, so no conversion
        # step is performed here.
        idata = nutpie.sample(
            nutpie_model,
            draws=draws,
            tune=tune,
            chains=chains,
            **sampler_kwargs,
        )

        # Set tuning steps attribute if possible; a result object without a
        # ``posterior`` group is returned untouched.
        try:
            idata.posterior.attrs["tuning_steps"] = tune
        except (AttributeError, KeyError):
            logger.warning("Could not set tuning_steps attribute on InferenceData")

        return idata

    @staticmethod
    def competence(var, has_grad):
        """Determine competence level for sampling var.

        Parameters
        ----------
        var : Variable
            Variable to be sampled
        has_grad : bool
            Whether gradient information is available

        Returns
        -------
        Competence
            ``Competence.IDEAL`` for continuous variables with gradients,
            ``Competence.INCOMPATIBLE`` otherwise
        """
        if var.dtype in continuous_types and has_grad:
            return Competence.IDEAL
        return Competence.INCOMPATIBLE

0 commit comments

Comments
 (0)