2121import warnings
2222
2323from abc import ABC , abstractmethod
24- from collections .abc import Callable , Iterable , Mapping , Sequence
24+ from collections .abc import Callable , Iterable , Iterator , Mapping , Sequence
2525from dataclasses import field
2626from enum import IntEnum , unique
2727from typing import Any
@@ -125,6 +125,8 @@ class BlockedStep(ABC, WithSamplingState):
125125
126126 def __new__ (cls , * args , ** kwargs ):
127127 blocked = kwargs .get ("blocked" )
128+ step_id_generator = kwargs .pop ("step_id_generator" , None )
129+
128130 if blocked is None :
129131 # Try to look up default value from class
130132 blocked = getattr (cls , "default_blocked" , True )
@@ -168,16 +170,19 @@ def __new__(cls, *args, **kwargs):
168170 # call __init__
169171 _kwargs = kwargs .copy ()
170172 _kwargs ["rng" ] = rng
173+ _kwargs ["step_id_generator" ] = step_id_generator
171174 step .__init__ ([var ], * args , ** _kwargs )
172175 # Hack for creating the class correctly when unpickling.
173176 step .__newargs = ([var ], * args ), _kwargs
174177 steps .append (step )
175178
176- return CompoundStep (steps )
179+ return CompoundStep (steps , step_id_generator = step_id_generator )
177180 else :
178181 step = super ().__new__ (cls )
179182 step .stats_dtypes = stats_dtypes
180183 step .stats_dtypes_shapes = stats_dtypes_shapes
184+ step ._step_id = next (step_id_generator ) if step_id_generator else None
185+
181186 # Hack for creating the class correctly when unpickling.
182187 step .__newargs = (vars , * args ), kwargs
183188 return step
@@ -223,7 +228,7 @@ def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]:
223228
224229 def update_stats (
225230 displayed_stats : dict [str , np .ndarray ],
226- step_stats : dict [str , str | float | int | bool | None ],
231+ step_stats_dict : dict [int , dict [ str , str | float | int | bool | None ] ],
227232 chain_idx : int ,
228233 ) -> dict [str , np .ndarray ]:
229234 """
@@ -235,7 +240,7 @@ def update_stats(
235240 Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
236241 the values are the current values of the statistics, with one value per chain being sampled.
237242
238- step_stats : dict
243+ step_stats_dict : dict
239244 Dictionary of statistics generated by the step sampler when taking the current step. The keys are the
240245 names of the statistics and the values are the values of the statistics generated by the step sampler.
241246
@@ -256,7 +261,9 @@ def __getnewargs_ex__(self):
256261 return self .__newargs
257262
258263 @abstractmethod
259- def step (self , point : PointType ) -> tuple [PointType , StatsType ]:
264+ def step (
265+ self , point : PointType , step_parent_id : int | None = None
266+ ) -> tuple [PointType , StatsType ]:
260267 """Perform a single step of the sampler."""
261268
262269 @staticmethod
@@ -315,7 +322,7 @@ class CompoundStep(WithSamplingState):
315322
316323 _state_class = CompoundStepState
317324
318- def __init__ (self , methods ):
325+ def __init__ (self , methods , step_id_generator : Iterator [ int ] | None = None ):
319326 self .methods = list (methods )
320327 self .stats_dtypes = []
321328 for method in self .methods :
@@ -325,11 +332,12 @@ def __init__(self, methods):
325332 f"Compound[{ ', ' .join (getattr (m , 'name' , 'UNNAMED_STEP' ) for m in self .methods )} ]"
326333 )
327334 self .tune = True
335+ self ._step_id = next (step_id_generator ) if step_id_generator else None
328336
329- def step (self , point ) -> tuple [PointType , StatsType ]:
337+ def step (self , point , step_parent_id : int | None = None ) -> tuple [PointType , StatsType ]:
330338 stats = []
331339 for method in self .methods :
332- point , sts = method .step (point )
340+ point , sts = method .step (point , step_parent_id = self . _step_id )
333341 stats .extend (sts )
334342 # Model logp can only be the logp of the _last_ stats,
335343 # if there is one. Pop all others.
@@ -409,7 +417,7 @@ def _progressbar_config(
409417
410418 return columns , stats
411419
412- def _make_update_stats_function (self ) -> Callable [[dict , list [ dict ], int ], dict ]:
420+ def _make_update_stats_function (self ) -> Callable [[dict , dict [ int , dict ], int ], dict ]:
413421 """
414422 Create an update function used by the progress bar to update statistics during sampling.
415423
@@ -419,11 +427,13 @@ def _make_update_stats_function(self) -> Callable[[dict, list[dict], int], dict]
419427 Function that updates displayed statistics for the current chain, given statistics generated by the step
420428 during the most recent step.
421429 """
422- update_fns = [method ._make_update_stats_function () for method in self .methods ]
430+ update_fns = {
431+ method ._step_id : method ._make_update_stats_function () for method in self .methods
432+ }
423433
424434 def update_stats (
425435 displayed_stats : dict [str , np .ndarray ],
426- step_stats : list [ dict [str , str | float | int | bool | None ]],
436+ step_stats_dict : dict [ int , dict [str , str | float | int | bool | None ]],
427437 chain_idx : int ,
428438 ) -> dict [str , np .ndarray ]:
429439 """
@@ -435,7 +445,7 @@ def update_stats(
435445 Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and
436446 the values are the current values of the statistics, with one value per chain being sampled.
437447
438- step_stats: list of dict
448+ step_stats_dict: dict of dict
439449 List of dictionaries containing statistics generated by **each** step sampler in the CompoundStep when
440450 taking the current step. For each dictionary, the keys are names of statistics and the values are
441451 the values of the statistics generated by the step sampler.
@@ -452,11 +462,9 @@ def update_stats(
452462 # In this case, the current loop logic is just overriding each Metropolis steps' stats with those of the
453463 # next step (so the user only ever sees the 3rd step's stats). We should have a better way to aggregate
454464 # the stats from each step.
455- if not isinstance (step_stats , list ):
456- step_stats = [step_stats ]
457465
458- for step_stat , update_fn in zip ( step_stats , update_fns ):
459- displayed_stats = update_fn (displayed_stats , step_stat , chain_idx )
466+ for step_id , update_fn in update_fns . items ( ):
467+ displayed_stats = update_fn (displayed_stats , step_stats_dict , chain_idx )
460468
461469 return displayed_stats
462470
0 commit comments