|
18 | 18 | @author: johnsalvatier
|
19 | 19 | """
|
20 | 20 |
|
21 |
| - |
22 | 21 | from abc import ABC, abstractmethod
|
23 | 22 | from enum import IntEnum, unique
|
24 |
| -from typing import Dict, List, Tuple |
| 23 | +from typing import Dict, List, Sequence, Tuple, Union |
25 | 24 |
|
26 | 25 | import numpy as np
|
27 | 26 |
|
28 | 27 | from pytensor.graph.basic import Variable
|
29 | 28 |
|
30 |
| -from pymc.blocking import PointType, StatsType |
| 29 | +from pymc.blocking import PointType, StatsDict, StatsType |
31 | 30 | from pymc.model import modelcontext
|
32 | 31 |
|
33 | 32 | __all__ = ("Competence", "CompoundStep")
|
@@ -165,3 +164,43 @@ def reset_tuning(self):
|
165 | 164 | @property
|
166 | 165 | def vars(self):
|
167 | 166 | return [var for method in self.methods for var in method.vars]
|
| 167 | + |
| 168 | + |
| 169 | +def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]: |
| 170 | + """Flatten a hierarchy of step methods to a list.""" |
| 171 | + if isinstance(step, BlockedStep): |
| 172 | + return [step] |
| 173 | + steps = [] |
| 174 | + if not isinstance(step, CompoundStep): |
| 175 | + raise ValueError(f"Unexpected type of step method: {step}") |
| 176 | + for sm in step.methods: |
| 177 | + steps += flatten_steps(sm) |
| 178 | + return steps |
| 179 | + |
| 180 | + |
| 181 | +class StatsBijection: |
| 182 | + """Map between a `list` of stats to `dict` of stats.""" |
| 183 | + |
| 184 | + def __init__(self, sampler_stats_dtypes: Sequence[Dict[str, type]]) -> None: |
| 185 | + # Keep a list of flat vs. original stat names |
| 186 | + self._stat_groups: List[List[Tuple[str, str]]] = [ |
| 187 | + [(f"sampler_{s}__{statname}", statname) for statname, _ in names_dtypes.items()] |
| 188 | + for s, names_dtypes in enumerate(sampler_stats_dtypes) |
| 189 | + ] |
| 190 | + |
| 191 | + def map(self, stats_list: StatsType) -> StatsDict: |
| 192 | + """Combine stats dicts of multiple samplers into one dict.""" |
| 193 | + stats_dict = {} |
| 194 | + for s, sts in enumerate(stats_list): |
| 195 | + for statname, sval in sts.items(): |
| 196 | + sname = f"sampler_{s}__{statname}" |
| 197 | + stats_dict[sname] = sval |
| 198 | + return stats_dict |
| 199 | + |
| 200 | + def rmap(self, stats_dict: StatsDict) -> StatsType: |
| 201 | + """Split a global stats dict into a list of sampler-wise stats dicts.""" |
| 202 | + stats_list = [] |
| 203 | + for namemap in self._stat_groups: |
| 204 | + d = {statname: stats_dict[sname] for sname, statname in namemap} |
| 205 | + stats_list.append(d) |
| 206 | + return stats_list |
0 commit comments