Skip to content

Commit 9b18cd3

Browse files
committed
FIX: Better testing
1 parent 8b3aaa9 commit 9b18cd3

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

nibabel/tests/test_viewers.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,6 @@
1010
import numpy as np
1111
from collections import namedtuple as nt
1212

13-
try:
14-
import matplotlib
15-
matplotlib.use('agg')
16-
except Exception:
17-
pass
1813

1914
from ..optpkg import optional_package
2015
from ..viewers import OrthoSlicer3D
@@ -24,13 +19,16 @@
2419

2520
from nose.tools import assert_raises
2621

27-
plt, has_mpl = optional_package('matplotlib.pyplot')[:2]
22+
matplotlib, has_mpl = optional_package('matplotlib')[:2]
2823
needs_mpl = skipif(not has_mpl, 'These tests need matplotlib')
24+
if has_mpl:
25+
matplotlib.use('Agg')
2926

3027

3128
@needs_mpl
3229
def test_viewer():
3330
# Test viewer
31+
plt = optional_package('matplotlib.pyplot')[0]
3432
a = np.sin(np.linspace(0, np.pi, 20))
3533
b = np.sin(np.linspace(0, np.pi*5, 30))
3634
data = (np.outer(a, b)[..., np.newaxis] * a)[:, :, :, np.newaxis]

nibabel/viewers.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@
1111
from .optpkg import optional_package
1212
from .orientations import aff2axcodes, axcodes2ornt
1313

14-
plt, _, _ = optional_package('matplotlib.pyplot')
15-
mpl_img, _, _ = optional_package('matplotlib.image')
16-
mpl_patch, _, _ = optional_package('matplotlib.patches')
17-
1814

1915
class OrthoSlicer3D(object):
2016
"""Orthogonal-plane slicer.
@@ -59,6 +55,12 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
5955
figsize : tuple
6056
Figure size (in inches) to use if axes are None.
6157
"""
58+
# Nest imports so that matplotlib.use() has the appropriate
59+
# effect in testing
60+
plt, _, _ = optional_package('matplotlib.pyplot')
61+
mpl_img, _, _ = optional_package('matplotlib.image')
62+
mpl_patch, _, _ = optional_package('matplotlib.patches')
63+
6264
data = np.asanyarray(data)
6365
if data.ndim < 3:
6466
raise ValueError('data must have at least 3 dimensions')
@@ -200,11 +202,13 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
200202
def show(self):
201203
"""Show the slicer in blocking mode; convenience for ``plt.show()``
202204
"""
205+
plt, _, _ = optional_package('matplotlib.pyplot')
203206
plt.show()
204207

205208
def close(self):
206209
"""Close the viewer figures
207210
"""
211+
plt, _, _ = optional_package('matplotlib.pyplot')
208212
for f in self._figs:
209213
plt.close(f)
210214
for link in self._links:

0 commit comments

Comments
 (0)