Skip to content

Commit ed74a61

Browse files
aseyboldtJunpeng Lao
authored andcommitted
Fix max_treedepth warning (#2808)
* Fix max_treedepth warning * Less sensitive neff warning
1 parent 1e2a0c3 commit ed74a61

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

pymc3/backends/report.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,17 @@ def _run_convergence_checks(self, trace):
102102
warn = SamplerWarning(
103103
WarningType.CONVERGENCE, msg, 'error', None, None, effective_n)
104104
warnings.append(warn)
105+
elif eff_min / n_samples < 0.1:
106+
msg = ("The number of effective samples is smaller than "
107+
"10% for some parameters.")
108+
warn = SamplerWarning(
109+
WarningType.CONVERGENCE, msg, 'warn', None, None, effective_n)
110+
warnings.append(warn)
105111
elif eff_min / n_samples < 0.25:
106112
msg = ("The number of effective samples is smaller than "
107113
"25% for some parameters.")
108114
warn = SamplerWarning(
109-
WarningType.CONVERGENCE, msg, 'warn', None, None, effective_n)
115+
WarningType.CONVERGENCE, msg, 'info', None, None, effective_n)
110116
warnings.append(warn)
111117

112118
self._add_warnings(warnings)

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,9 @@ def warnings(self, strace):
177177
WarningType.DIVERGENCES, msg, 'error', None, None, None)
178178
warnings.append(warning)
179179
elif n_divs > 0:
180-
message = ('Divergences after tuning. Increase `target_accept` or '
181-
'reparameterize.')
180+
message = ('There were %s divergences after tuning. Increase '
181+
'`target_accept` or reparameterize.'
182+
% n_divs)
182183
warning = SamplerWarning(
183184
WarningType.DIVERGENCES, message, 'error', None, None, None)
184185
warnings.append(warning)

pymc3/step_methods/hmc/nuts.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ def _hamiltonian_step(self, start, p0, step_size):
170170
if divergence_info or turning:
171171
break
172172
else:
173-
self._reached_max_treedepth += 1
173+
if not self.tune:
174+
self._reached_max_treedepth += 1
174175

175176
stats = tree.stats()
176177
accept_stat = stats['mean_tree_accept']
@@ -185,8 +186,10 @@ def competence(var, has_grad):
185186

186187
def warnings(self, strace):
187188
warnings = super(NUTS, self).warnings(strace)
189+
n_samples = self._samples_after_tune
190+
n_treedepth = self._reached_max_treedepth
188191

189-
if np.mean(self._reached_max_treedepth) > 0.05:
192+
if n_samples > 0 and n_treedepth / float(n_samples) > 0.05:
190193
msg = ('The chain reached the maximum tree depth. Increase '
191194
'max_treedepth, increase target_accept or reparameterize.')
192195
warn = SamplerWarning(WarningType.TREEDEPTH, msg, 'warn',

0 commit comments

Comments
 (0)