Skip to content

Commit 6a0e74d

Browse files
michaelosthege authored and twiecki committed
Refactor StatsBijection to expose information about object stats
1 parent 763a3ea commit 6a0e74d

File tree

2 files changed

+36
-13
lines changed

2 files changed

+36
-13
lines changed

pymc/step_methods/compound.py

Lines changed: 31 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -191,6 +191,11 @@ def stop_tuning(self):
191191
self.tune = False
192192

193193

194+
def flat_statname(sampler_idx: int, sname: str) -> str:
195+
"""Get the flat-stats name for a samplers stat."""
196+
return f"sampler_{sampler_idx}__{sname}"
197+
198+
194199
def get_stats_dtypes_shapes_from_steps(
195200
steps: Iterable[BlockedStep],
196201
) -> Dict[str, Tuple[StatDtype, StatShape]]:
@@ -201,7 +206,7 @@ def get_stats_dtypes_shapes_from_steps(
201206
result = {}
202207
for s, step in enumerate(steps):
203208
for sname, (dtype, shape) in step.stats_dtypes_shapes.items():
204-
result[f"sampler_{s}__{sname}"] = (dtype, shape)
209+
result[flat_statname(s, sname)] = (dtype, shape)
205210
return result
206211

207212

@@ -262,10 +267,21 @@ class StatsBijection:
262267

263268
def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None:
264269
# Keep a list of flat vs. original stat names
265-
self._stat_groups: List[List[Tuple[str, str]]] = [
266-
[(f"sampler_{s}__{statname}", statname) for statname, _ in names_dtypes.items()]
267-
for s, names_dtypes in enumerate(sampler_stats_dtypes)
268-
]
270+
stat_groups = []
271+
for s, names_dtypes in enumerate(sampler_stats_dtypes):
272+
group = []
273+
for statname, dtype in names_dtypes.items():
274+
flatname = flat_statname(s, statname)
275+
is_obj = np.dtype(dtype) == np.dtype(object)
276+
group.append((flatname, statname, is_obj))
277+
stat_groups.append(group)
278+
self._stat_groups: List[List[Tuple[str, str, bool]]] = stat_groups
279+
self.object_stats = {
280+
fname: (s, sname)
281+
for s, group in enumerate(self._stat_groups)
282+
for fname, sname, is_obj in group
283+
if is_obj
284+
}
269285

270286
@property
271287
def n_samplers(self) -> int:
@@ -275,9 +291,10 @@ def map(self, stats_list: Sequence[Mapping[str, Any]]) -> StatsDict:
275291
"""Combine stats dicts of multiple samplers into one dict."""
276292
stats_dict = {}
277293
for s, sts in enumerate(stats_list):
278-
for statname, sval in sts.items():
279-
sname = f"sampler_{s}__{statname}"
280-
stats_dict[sname] = sval
294+
for fname, sname, is_obj in self._stat_groups[s]:
295+
if sname not in sts:
296+
continue
297+
stats_dict[fname] = sts[sname]
281298
return stats_dict
282299

283300
def rmap(self, stats_dict: Mapping[str, Any]) -> StatsType:
@@ -286,7 +303,11 @@ def rmap(self, stats_dict: Mapping[str, Any]) -> StatsType:
286303
The ``stats_dict`` can be a subset of all sampler stats.
287304
"""
288305
stats_list = []
289-
for namemap in self._stat_groups:
290-
d = {statname: stats_dict[sname] for sname, statname in namemap if sname in stats_dict}
306+
for group in self._stat_groups:
307+
d = {}
308+
for fname, sname, is_obj in group:
309+
if fname not in stats_dict:
310+
continue
311+
d[sname] = stats_dict[fname]
291312
stats_list.append(d)
292313
return stats_list

tests/step_methods/test_compound.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -164,20 +164,22 @@ def test_flatten_steps(self):
164164
def test_stats_bijection(self):
165165
step_stats_dtypes = [
166166
{"a": float, "b": int},
167-
{"a": float, "c": int},
167+
{"a": float, "c": Warning},
168168
]
169169
bij = StatsBijection(step_stats_dtypes)
170+
assert bij.object_stats == {"sampler_1__c": (1, "c")}
170171
assert bij.n_samplers == 2
172+
w = Warning("hmm")
171173
stats_l = [
172174
dict(a=1.5, b=3),
173-
dict(a=2.5, c=4),
175+
dict(a=2.5, c=w),
174176
]
175177
stats_d = bij.map(stats_l)
176178
assert isinstance(stats_d, dict)
177179
assert stats_d["sampler_0__a"] == 1.5
178180
assert stats_d["sampler_0__b"] == 3
179181
assert stats_d["sampler_1__a"] == 2.5
180-
assert stats_d["sampler_1__c"] == 4
182+
assert stats_d["sampler_1__c"] == w
181183
rev = bij.rmap(stats_d)
182184
assert isinstance(rev, list)
183185
assert len(rev) == len(stats_l)

0 commit comments

Comments
 (0)