diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index 23f05488b..fadc5b7e5 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -114,7 +114,13 @@ def record(self, point, sampler_stats=None) -> None: raise ValueError("Unknown sampler_stats") if sampler_stats is not None: for data, vars in zip(self._stats, sampler_stats): - for key, val in vars.items(): + compressed_vars = {} + for k, v in vars.items(): + if isinstance(v, np.ndarray) and v.shape[0] == 1 and len(v.shape) == 1: + compressed_vars[k] = v.reshape(1, 1) + else: + compressed_vars[k] = v + for key, val in compressed_vars.items(): data[key][self.draw_idx] = val self.draw_idx += 1