Skip to content

Commit 00760bb

Browse files
committed
Refactor and correct divergence warning
1 parent fc2bb56 commit 00760bb

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,18 +171,19 @@ def warnings(self):
171171
# Generate a global warning for divergences
172172
n_divs = self._num_divs_sample
173173
if n_divs and self._samples_after_tune == n_divs:
174-
msg = ('The chain contains only diverging samples. The model is '
175-
'probably misspecified.')
176-
warning = SamplerWarning(
177-
WarningType.DIVERGENCES, msg, 'error', None, None, None)
178-
warnings.append(warning)
174+
message = ('The chain contains only diverging samples. The model '
175+
'is probably misspecified.')
176+
elif n_divs == 1:
177+
message = ('There was 1 divergence after tuning. Increase '
178+
'`target_accept` or reparameterize.')
179179
elif n_divs > 0:
180180
message = ('There were %s divergences after tuning. Increase '
181181
'`target_accept` or reparameterize.'
182182
% n_divs)
183-
warning = SamplerWarning(
184-
WarningType.DIVERGENCES, message, 'error', None, None, None)
185-
warnings.append(warning)
186183

184+
warning = SamplerWarning(
185+
WarningType.DIVERGENCES, message, 'error', None, None, None)
186+
warnings.append(warning)
187187
warnings.extend(self.step_adapt.warnings())
188+
188189
return warnings

0 commit comments

Comments
 (0)