Skip to content

Commit eccfa39

Browse files
Assign each step method a unique id, and track ids via stats emitted by Step.step
1 parent a8af2e8 commit eccfa39

File tree

8 files changed

+92
-40
lines changed

8 files changed

+92
-40
lines changed

pymc/backends/ndarray.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ def record(self, point, sampler_stats=None) -> None:
113113
if sampler_stats is not None:
114114
for data, vars in zip(self._stats, sampler_stats):
115115
for key, val in vars.items():
116+
# step_meta is a key used by the progress bars to track which draw came from which step instance. It
117+
# should never be stored as a sampler statistic.
118+
if key == "step_meta":
119+
continue
116120
data[key][draw_idx] = val
117121
elif self._stats is not None:
118122
raise ValueError("Expected sampler_stats")

pymc/sampling/mcmc.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Functions for MCMC sampling."""
1616

1717
import contextlib
18+
import itertools
1819
import logging
1920
import pickle
2021
import sys
@@ -111,6 +112,7 @@ def instantiate_steppers(
111112
step_kwargs: dict[str, dict] | None = None,
112113
initial_point: PointType | None = None,
113114
compile_kwargs: dict | None = None,
115+
step_id_generator: Iterator[int] | None = None,
114116
) -> Step | list[Step]:
115117
"""Instantiate steppers assigned to the model variables.
116118
@@ -139,6 +141,9 @@ def instantiate_steppers(
139141
if step_kwargs is None:
140142
step_kwargs = {}
141143

144+
if step_id_generator is None:
145+
step_id_generator = itertools.count()
146+
142147
used_keys = set()
143148
if selected_steps:
144149
if initial_point is None:
@@ -154,6 +159,7 @@ def instantiate_steppers(
154159
model=model,
155160
initial_point=initial_point,
156161
compile_kwargs=compile_kwargs,
162+
step_id_generator=step_id_generator,
157163
**kwargs,
158164
)
159165
steps.append(step)
@@ -853,16 +859,19 @@ def joined_blas_limiter():
853859
initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed_list)]
854860

855861
# Instantiate automatically selected steps
862+
# Use a counter to generate a unique id for each stepper used in the model.
863+
step_id_generator = itertools.count()
856864
step = instantiate_steppers(
857865
model,
858866
steps=provided_steps,
859867
selected_steps=selected_steps,
860868
step_kwargs=kwargs,
861869
initial_point=initial_points[0],
862870
compile_kwargs=compile_kwargs,
871+
step_id_generator=step_id_generator,
863872
)
864873
if isinstance(step, list):
865-
step = CompoundStep(step)
874+
step = CompoundStep(step, step_id_generator=step_id_generator)
866875

867876
if var_names is not None:
868877
trace_vars = [v for v in model.unobserved_RVs if v.name in var_names]

pymc/step_methods/arraystep.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from abc import abstractmethod
16-
from collections.abc import Callable
16+
from collections.abc import Callable, Iterator
1717
from typing import cast
1818

1919
import numpy as np
@@ -43,14 +43,25 @@ class ArrayStep(BlockedStep):
4343
:py:func:`pymc.util.get_random_generator` for more information.
4444
"""
4545

46-
def __init__(self, vars, fs, allvars=False, blocked=True, rng: RandomGenerator = None):
46+
def __init__(
47+
self,
48+
vars,
49+
fs,
50+
allvars=False,
51+
blocked=True,
52+
rng: RandomGenerator = None,
53+
step_id_generator: Iterator[int] | None = None,
54+
):
4755
self.vars = vars
4856
self.fs = fs
4957
self.allvars = allvars
5058
self.blocked = blocked
5159
self.rng = get_random_generator(rng)
60+
self._step_id = next(step_id_generator) if step_id_generator else None
5261

53-
def step(self, point: PointType) -> tuple[PointType, StatsType]:
62+
def step(
63+
self, point: PointType, step_parent_id: int | None = None
64+
) -> tuple[PointType, StatsType]:
5465
partial_funcs_and_point: list[Callable | PointType] = [
5566
DictToArrayBijection.mapf(x, start_point=point) for x in self.fs
5667
]
@@ -61,6 +72,9 @@ def step(self, point: PointType) -> tuple[PointType, StatsType]:
6172
apoint = DictToArrayBijection.map(var_dict)
6273
apoint_new, stats = self.astep(apoint, *partial_funcs_and_point)
6374

75+
for sts in stats:
76+
sts["step_meta"] = {"step_id": self._step_id, "step_parent_id": step_parent_id}
77+
6478
if not isinstance(apoint_new, RaveledVars):
6579
# We assume that the mapping has stayed the same
6680
apoint_new = RaveledVars(apoint_new, apoint.point_map_info)
@@ -84,7 +98,14 @@ class ArrayStepShared(BlockedStep):
8498
and unmapping overhead as well as moving fewer variables around.
8599
"""
86100

87-
def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None):
101+
def __init__(
102+
self,
103+
vars,
104+
shared,
105+
blocked=True,
106+
rng: RandomGenerator = None,
107+
step_id_generator: Iterator[int] | None = None,
108+
):
88109
"""
89110
Create the ArrayStepShared object.
90111
@@ -103,8 +124,11 @@ def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None):
103124
self.shared = {get_var_name(var): shared for var, shared in shared.items()}
104125
self.blocked = blocked
105126
self.rng = get_random_generator(rng)
127+
self._step_id = next(step_id_generator) if step_id_generator else None
106128

107-
def step(self, point: PointType) -> tuple[PointType, StatsType]:
129+
def step(
130+
self, point: PointType, step_parent_id: int | None = None
131+
) -> tuple[PointType, StatsType]:
108132
full_point = None
109133
if self.shared:
110134
for name, shared_var in self.shared.items():
@@ -115,6 +139,9 @@ def step(self, point: PointType) -> tuple[PointType, StatsType]:
115139
q = DictToArrayBijection.map(point)
116140
apoint, stats = self.astep(q)
117141

142+
for sts in stats:
143+
sts["step_meta"] = {"step_id": self._step_id, "step_parent_id": step_parent_id}
144+
118145
if not isinstance(apoint, RaveledVars):
119146
# We assume that the mapping has stayed the same
120147
apoint = RaveledVars(apoint, q.point_map_info)

pymc/step_methods/compound.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import warnings
2222

2323
from abc import ABC, abstractmethod
24-
from collections.abc import Callable, Iterable, Mapping, Sequence
24+
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
2525
from dataclasses import field
2626
from enum import IntEnum, unique
2727
from typing import Any
@@ -125,6 +125,8 @@ class BlockedStep(ABC, WithSamplingState):
125125

126126
def __new__(cls, *args, **kwargs):
127127
blocked = kwargs.get("blocked")
128+
step_id_generator = kwargs.pop("step_id_generator", None)
129+
128130
if blocked is None:
129131
# Try to look up default value from class
130132
blocked = getattr(cls, "default_blocked", True)
@@ -168,16 +170,19 @@ def __new__(cls, *args, **kwargs):
168170
# call __init__
169171
_kwargs = kwargs.copy()
170172
_kwargs["rng"] = rng
173+
_kwargs["step_id_generator"] = step_id_generator
171174
step.__init__([var], *args, **_kwargs)
172175
# Hack for creating the class correctly when unpickling.
173176
step.__newargs = ([var], *args), _kwargs
174177
steps.append(step)
175178

176-
return CompoundStep(steps)
179+
return CompoundStep(steps, step_id_generator=step_id_generator)
177180
else:
178181
step = super().__new__(cls)
179182
step.stats_dtypes = stats_dtypes
180183
step.stats_dtypes_shapes = stats_dtypes_shapes
184+
step._step_id = next(step_id_generator) if step_id_generator else None
185+
181186
# Hack for creating the class correctly when unpickling.
182187
step.__newargs = (vars, *args), kwargs
183188
return step
@@ -223,7 +228,7 @@ def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]:
223228

224229
def update_stats(
225230
displayed_stats: dict[str, np.ndarray],
226-
step_stats: dict[str, str | float | int | bool | None],
231+
step_stats_dict: dict[int, dict[str, str | float | int | bool | None]],
227232
chain_idx: int,
228233
) -> dict[str, np.ndarray]:
229234
"""
@@ -235,7 +240,7 @@ def update_stats(
235240
Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
236241
the values are the current values of the statistics, with one value per chain being sampled.
237242
238-
step_stats: dict
243+
step_stats_dict: dict
239244
Dictionary of statistics generated by the step sampler when taking the current step. The keys are the
240245
names of the statistics and the values are the values of the statistics generated by the step sampler.
241246
@@ -256,7 +261,9 @@ def __getnewargs_ex__(self):
256261
return self.__newargs
257262

258263
@abstractmethod
259-
def step(self, point: PointType) -> tuple[PointType, StatsType]:
264+
def step(
265+
self, point: PointType, step_parent_id: int | None = None
266+
) -> tuple[PointType, StatsType]:
260267
"""Perform a single step of the sampler."""
261268

262269
@staticmethod
@@ -315,7 +322,7 @@ class CompoundStep(WithSamplingState):
315322

316323
_state_class = CompoundStepState
317324

318-
def __init__(self, methods):
325+
def __init__(self, methods, step_id_generator: Iterator[int] | None = None):
319326
self.methods = list(methods)
320327
self.stats_dtypes = []
321328
for method in self.methods:
@@ -325,11 +332,12 @@ def __init__(self, methods):
325332
f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]"
326333
)
327334
self.tune = True
335+
self._step_id = next(step_id_generator) if step_id_generator else None
328336

329-
def step(self, point) -> tuple[PointType, StatsType]:
337+
def step(self, point, step_parent_id: int | None = None) -> tuple[PointType, StatsType]:
330338
stats = []
331339
for method in self.methods:
332-
point, sts = method.step(point)
340+
point, sts = method.step(point, step_parent_id=self._step_id)
333341
stats.extend(sts)
334342
# Model logp can only be the logp of the _last_ stats,
335343
# if there is one. Pop all others.
@@ -409,7 +417,7 @@ def _progressbar_config(
409417

410418
return columns, stats
411419

412-
def _make_update_stats_function(self) -> Callable[[dict, list[dict], int], dict]:
420+
def _make_update_stats_function(self) -> Callable[[dict, dict[int, dict], int], dict]:
413421
"""
414422
Create an update function used by the progress bar to update statistics during sampling.
415423
@@ -419,11 +427,13 @@ def _make_update_stats_function(self) -> Callable[[dict, list[dict], int], dict]
419427
Function that updates displayed statistics for the current chain, given statistics generated by the step
420428
during the most recent step.
421429
"""
422-
update_fns = [method._make_update_stats_function() for method in self.methods]
430+
update_fns = {
431+
method._step_id: method._make_update_stats_function() for method in self.methods
432+
}
423433

424434
def update_stats(
425435
displayed_stats: dict[str, np.ndarray],
426-
step_stats: list[dict[str, str | float | int | bool | None]],
436+
step_stats_dict: dict[int, dict[str, str | float | int | bool | None]],
427437
chain_idx: int,
428438
) -> dict[str, np.ndarray]:
429439
"""
@@ -435,7 +445,7 @@ def update_stats(
435445
Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
436446
the values are the current values of the statistics, with one value per chain being sampled.
437447
438-
step_stats: list of dict
448+
step_stats_dict: dict of dict
439449
List of dictionaries containing statistics generated by **each** step sampler in the CompoundStep when
440450
taking the current step. For each dictionary, the keys are names of statistics and the values are
441451
the values of the statistics generated by the step sampler.
@@ -452,11 +462,9 @@ def update_stats(
452462
# In this case, the current loop logic is just overriding each Metropolis steps' stats with those of the
453463
# next step (so the user only ever sees the 3rd step's stats). We should have a better way to aggregate
454464
# the stats from each step.
455-
if not isinstance(step_stats, list):
456-
step_stats = [step_stats]
457465

458-
for step_stat, update_fn in zip(step_stats, update_fns):
459-
displayed_stats = update_fn(displayed_stats, step_stat, chain_idx)
466+
for step_id, update_fn in update_fns.items():
467+
displayed_stats = update_fn(displayed_stats, step_stats_dict, chain_idx)
460468

461469
return displayed_stats
462470

pymc/step_methods/hmc/base_hmc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import time
1919

2020
from abc import abstractmethod
21+
from collections.abc import Iterator
2122
from typing import Any, NamedTuple
2223

2324
import numpy as np
@@ -99,6 +100,7 @@ def __init__(
99100
step_rand=None,
100101
rng=None,
101102
initial_point: PointType | None = None,
103+
step_id_generator: Iterator[int] | None = None,
102104
**pytensor_kwargs,
103105
):
104106
"""Set up Hamiltonian samplers with common structures.
@@ -133,6 +135,7 @@ def __init__(
133135
**pytensor_kwargs: passed to PyTensor functions
134136
"""
135137
self._model = modelcontext(model)
138+
self._step_id = next(step_id_generator) if step_id_generator else None
136139

137140
if vars is None:
138141
vars = self._model.continuous_value_vars

pymc/step_methods/hmc/nuts.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]:
278278

279279
def update_stats(
280280
displayed_stats: dict[str, np.ndarray],
281-
step_stats: dict[str, str | float | int | bool | None],
281+
step_stats_dict: dict[int, dict[str, str | float | int | bool | None]],
282282
chain_idx: int,
283283
) -> dict[str, np.ndarray]:
284284
"""
@@ -290,7 +290,7 @@ def update_stats(
290290
Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
291291
the values are the current values of the statistics, with one value per chain being sampled.
292292
293-
step_stats: dict
293+
step_stats_dict: dict
294294
Dictionary of statistics generated by the step sampler when taking the current step. The keys are the
295295
names of the statistics and the values are the values of the statistics generated by the step sampler.
296296
@@ -302,8 +302,7 @@ def update_stats(
302302
dict
303303
The updated statistics dictionary to be displayed in the progress bar.
304304
"""
305-
if isinstance(step_stats, list):
306-
step_stats = step_stats[0]
305+
step_stats = step_stats_dict[self._step_id]
307306

308307
if not step_stats["tune"]:
309308
displayed_stats["divergences"][chain_idx] += step_stats["diverging"]

0 commit comments

Comments
 (0)