Skip to content

Commit 8e0d7ac

Browse files
authored
Merge pull request #3106 from eigenfoo/fix_div_warning
Fix grammar in divergence warning; refactor code
2 parents fc2bb56 + 89558c2 commit 8e0d7ac

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
### Fixes
3333

34+
- Fixed grammar in divergence warning, previously `There were 1 divergences ...` could be raised.
3435
- Fixed `KeyError` raised when only subset of variables are specified to be recorded in the trace.
3536
- Removed unused `repeat=None` arguments from all `random()` methods in distributions.
3637
- Deprecated the `sigma` argument in `MarginalSparse.marginal_likelihood` in favor of `noise`

pymc3/step_methods/hmc/base_hmc.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,17 +169,19 @@ def warnings(self):
169169
warnings = self._warnings[:]
170170

171171
# Generate a global warning for divergences
172+
message = ''
172173
n_divs = self._num_divs_sample
173174
if n_divs and self._samples_after_tune == n_divs:
174-
msg = ('The chain contains only diverging samples. The model is '
175-
'probably misspecified.')
176-
warning = SamplerWarning(
177-
WarningType.DIVERGENCES, msg, 'error', None, None, None)
178-
warnings.append(warning)
179-
elif n_divs > 0:
175+
message = ('The chain contains only diverging samples. The model '
176+
'is probably misspecified.')
177+
elif n_divs == 1:
178+
message = ('There was 1 divergence after tuning. Increase '
179+
'`target_accept` or reparameterize.')
180+
elif n_divs > 1:
180181
message = ('There were %s divergences after tuning. Increase '
181-
'`target_accept` or reparameterize.'
182-
% n_divs)
182+
'`target_accept` or reparameterize.' % n_divs)
183+
184+
if message:
183185
warning = SamplerWarning(
184186
WarningType.DIVERGENCES, message, 'error', None, None, None)
185187
warnings.append(warning)

pymc3/tests/test_step.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,9 @@ def test_linalg(self, caplog):
446446
warns = [msg.msg for msg in caplog.records]
447447
assert np.any(trace['diverging'])
448448
assert (
449+
any('divergence after tuning' in warn
450+
for warn in warns)
451+
or
449452
any('divergences after tuning' in warn
450453
for warn in warns)
451454
or

0 commit comments

Comments
 (0)