Skip to content

Commit 41d70b8

Browse files
authored
Fix bug in nuts stats (#2467)
1 parent 259613f commit 41d70b8

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pymc3/step_methods/hmc/nuts.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,12 @@ def astep(self, q0):
197197

198198
for _ in range(max_treedepth):
199199
direction = logbern(np.log(0.5)) * 2 - 1
200-
diverging, turning = tree.extend(direction)
200+
diverging_info, turning = tree.extend(direction)
201201
q, q_grad = tree.proposal.q, tree.proposal.q_grad
202202

203-
if diverging or turning:
204-
if diverging:
205-
self.report._add_divergence(self.tune, *diverging)
203+
if diverging_info or turning:
204+
if diverging_info:
205+
self.report._add_divergence(self.tune, *diverging_info)
206206
break
207207

208208
w = 1. / (self.m + self.t0)
@@ -223,7 +223,7 @@ def astep(self, q0):
223223
'step_size': step_size,
224224
'tune': self.tune,
225225
'step_size_bar': np.exp(self.log_step_size_bar),
226-
'diverging': diverging,
226+
'diverging': bool(diverging_info),
227227
}
228228

229229
stats.update(tree.stats())

0 commit comments

Comments
 (0)