Skip to content

Commit 64a092c

Browse files
Don't store total divergences in NUTS stats
1 parent d5d9516 commit 64a092c

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

pymc/step_methods/hmc/base_hmc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
275275
stats: dict[str, Any] = {
276276
"tune": self.tune,
277277
"diverging": diverging,
278-
"divergences": self.divergences,
279278
"perf_counter_diff": perf_end - perf_start,
280279
"process_time_diff": process_end - process_start,
281280
"perf_counter_start": perf_start,

pymc/step_methods/hmc/nuts.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ class NUTS(BaseHMC):
115115
"step_size_bar": (np.float64, []),
116116
"tree_size": (np.float64, []),
117117
"diverging": (bool, []),
118-
"divergences": (np.int64, []),
119118
"energy_error": (np.float64, []),
120119
"energy": (np.float64, []),
121120
"max_energy_error": (np.float64, []),
@@ -248,11 +247,12 @@ def _progressbar_config(n_chains=1):
248247

249248
return columns, stats
250249

251-
@staticmethod
252-
def _make_progressbar_update_functions():
250+
def _make_progressbar_update_functions(self):
253251
def update_stats(stats):
254-
return {key: stats[key] for key in ("divergences", "step_size", "tree_size")} | {
255-
"failing": stats["divergences"] > 0
252+
divergences = self.divergences
253+
return {key: stats[key] for key in ("step_size", "tree_size")} | {
254+
"failing": divergences > 0,
255+
"divergences": divergences,
256256
}
257257

258258
return (update_stats,)

0 commit comments

Comments
 (0)