@@ -191,6 +191,11 @@ def stop_tuning(self):
191
191
self .tune = False
192
192
193
193
194
+ def flat_statname (sampler_idx : int , sname : str ) -> str :
195
+ """Get the flat-stats name for a samplers stat."""
196
+ return f"sampler_{ sampler_idx } __{ sname } "
197
+
198
+
194
199
def get_stats_dtypes_shapes_from_steps (
195
200
steps : Iterable [BlockedStep ],
196
201
) -> Dict [str , Tuple [StatDtype , StatShape ]]:
@@ -201,7 +206,7 @@ def get_stats_dtypes_shapes_from_steps(
201
206
result = {}
202
207
for s , step in enumerate (steps ):
203
208
for sname , (dtype , shape ) in step .stats_dtypes_shapes .items ():
204
- result [f"sampler_ { s } __ { sname } " ] = (dtype , shape )
209
+ result [flat_statname ( s , sname ) ] = (dtype , shape )
205
210
return result
206
211
207
212
@@ -262,10 +267,21 @@ class StatsBijection:
262
267
263
268
def __init__ (self , sampler_stats_dtypes : Sequence [Mapping [str , type ]]) -> None :
264
269
# Keep a list of flat vs. original stat names
265
- self ._stat_groups : List [List [Tuple [str , str ]]] = [
266
- [(f"sampler_{ s } __{ statname } " , statname ) for statname , _ in names_dtypes .items ()]
267
- for s , names_dtypes in enumerate (sampler_stats_dtypes )
268
- ]
270
+ stat_groups = []
271
+ for s , names_dtypes in enumerate (sampler_stats_dtypes ):
272
+ group = []
273
+ for statname , dtype in names_dtypes .items ():
274
+ flatname = flat_statname (s , statname )
275
+ is_obj = np .dtype (dtype ) == np .dtype (object )
276
+ group .append ((flatname , statname , is_obj ))
277
+ stat_groups .append (group )
278
+ self ._stat_groups : List [List [Tuple [str , str , bool ]]] = stat_groups
279
+ self .object_stats = {
280
+ fname : (s , sname )
281
+ for s , group in enumerate (self ._stat_groups )
282
+ for fname , sname , is_obj in group
283
+ if is_obj
284
+ }
269
285
270
286
@property
271
287
def n_samplers (self ) -> int :
@@ -275,9 +291,10 @@ def map(self, stats_list: Sequence[Mapping[str, Any]]) -> StatsDict:
275
291
"""Combine stats dicts of multiple samplers into one dict."""
276
292
stats_dict = {}
277
293
for s , sts in enumerate (stats_list ):
278
- for statname , sval in sts .items ():
279
- sname = f"sampler_{ s } __{ statname } "
280
- stats_dict [sname ] = sval
294
+ for fname , sname , is_obj in self ._stat_groups [s ]:
295
+ if sname not in sts :
296
+ continue
297
+ stats_dict [fname ] = sts [sname ]
281
298
return stats_dict
282
299
283
300
def rmap (self , stats_dict : Mapping [str , Any ]) -> StatsType :
@@ -286,7 +303,11 @@ def rmap(self, stats_dict: Mapping[str, Any]) -> StatsType:
286
303
The ``stats_dict`` can be a subset of all sampler stats.
287
304
"""
288
305
stats_list = []
289
- for namemap in self ._stat_groups :
290
- d = {statname : stats_dict [sname ] for sname , statname in namemap if sname in stats_dict }
306
+ for group in self ._stat_groups :
307
+ d = {}
308
+ for fname , sname , is_obj in group :
309
+ if fname not in stats_dict :
310
+ continue
311
+ d [sname ] = stats_dict [fname ]
291
312
stats_list .append (d )
292
313
return stats_list
0 commit comments