Commit 4579de1

test: Make datadir a fixture that can skip a test
1 parent: 789dd4d

5 files changed, +36 −31 lines changed

niworkflows/tests/conftest.py

Lines changed: 9 additions & 2 deletions

@@ -30,9 +30,16 @@
 from templateflow.api import get as get_template
 
 from niworkflows.testing import data_env_canary, test_data_env
-from niworkflows.tests.data import load_test_data
 
-datadir = load_test_data()
+
+@pytest.fixture
+def datadir():
+    try:
+        from niworkflows.tests.data import load_test_data
+    except ImportError:
+        pytest.skip('niworkflows installed as wheel, data excluded')
+
+    return load_test_data()
 
 
 def _run_interface_mock(objekt, runtime):
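
For context: pytest injects a fixture's return value into any test that names it as a parameter, and a pytest.skip() raised inside the fixture marks each requesting test as skipped rather than failed, so tests that need the bundled data are skipped when niworkflows is installed from a wheel that excludes it. A minimal, self-contained sketch of the mechanism (the sample test is hypothetical, not part of this commit):

# Hypothetical sketch of the fixture-based skip pattern used above.
import os

import pytest


@pytest.fixture
def datadir():
    try:
        # In niworkflows this import fails when the package was installed
        # from a wheel that excludes the test data.
        from niworkflows.tests.data import load_test_data
    except ImportError:
        # Skipping inside a fixture skips every test that requests it,
        # instead of raising a collection- or call-time error.
        pytest.skip('niworkflows installed as wheel, data excluded')
    return load_test_data()


def test_confounds_file_present(datadir):
    # pytest resolves `datadir` by name and passes in the fixture's return value.
    assert os.path.isfile(os.path.join(datadir, 'confounds_test.tsv'))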

niworkflows/tests/test_confounds.py

Lines changed: 16 additions & 17 deletions

@@ -31,7 +31,6 @@
 
 from ..interfaces.confounds import ExpandModel, SpikeRegressors
 from ..interfaces.plotting import CompCorVariancePlot, ConfoundsCorrelationPlot
-from .conftest import datadir
 
 
 def _smoke_test_report(report_interface, artifact_name):
@@ -43,7 +42,7 @@ def _smoke_test_report(report_interface, artifact_name):
     assert os.path.isfile(out_report), f'Report "{out_report}" does not exist'
 
 
-def _expand_test(model_formula):
+def _expand_test(model_formula, datadir):
     orig_data_file = os.path.join(datadir, 'confounds_test.tsv')
     exp_data_file = (
         pe.Node(
@@ -56,7 +55,7 @@ def _expand_test(model_formula):
     return pd.read_csv(exp_data_file, sep='\t')
 
 
-def _spikes_test(lags=None, mincontig=None, fmt='mask'):
+def _spikes_test(lags=None, mincontig=None, fmt='mask', *, datadir):
     orig_data_file = os.path.join(datadir, 'spikes_test.tsv')
     lags = lags or [0]
     spk_data_file = (
@@ -78,7 +77,7 @@ def _spikes_test(lags=None, mincontig=None, fmt='mask'):
     return pd.read_csv(spk_data_file, sep='\t')
 
 
-def test_expansion_variable_selection():
+def test_expansion_variable_selection(datadir):
     """Test model expansion: simple variable selection"""
     model_formula = 'a + b + c + d'
     expected_data = pd.DataFrame(
@@ -89,11 +88,11 @@ def test_expansion_variable_selection():
             'd': [9, 7, 5, 3, 1],
         }
     )
-    exp_data = _expand_test(model_formula)
+    exp_data = _expand_test(model_formula, datadir)
     pd.testing.assert_frame_equal(exp_data, expected_data)
 
 
-def test_expansion_derivatives_and_powers():
+def test_expansion_derivatives_and_powers(datadir):
     """Temporal derivatives and quadratics"""
     model_formula = '(dd1(a) + d1(b))^^2 + d1-2((c)^2) + d + others'
     # b_derivative1_power2 is dropped as an exact duplicate of b_derivative1
@@ -112,13 +111,13 @@ def test_expansion_derivatives_and_powers():
             'f': [np.nan, 6, 4, 2, 0],
         }
     )
-    exp_data = _expand_test(model_formula)
+    exp_data = _expand_test(model_formula, datadir)
     assert set(exp_data.columns) == set(expected_data.columns)
     for col in expected_data.columns:
         pd.testing.assert_series_equal(expected_data[col], exp_data[col], check_dtype=False)
 
 
-def test_expansion_na_robustness():
+def test_expansion_na_robustness(datadir):
     """NA robustness"""
     model_formula = '(dd1(f))^^2'
     expected_data = pd.DataFrame(
@@ -129,16 +128,16 @@ def test_expansion_na_robustness():
             'f_derivative1_power2': [np.nan, np.nan, 4, 4, 4],
         }
     )
-    exp_data = _expand_test(model_formula)
+    exp_data = _expand_test(model_formula, datadir)
     assert set(exp_data.columns) == set(expected_data.columns)
     for col in expected_data.columns:
         pd.testing.assert_series_equal(expected_data[col], exp_data[col], check_dtype=False)
 
 
-def test_spikes():
+def test_spikes(datadir):
     """Test outlier flagging"""
     outliers = [1, 1, 0, 0, 1]
-    spk_data = _spikes_test()
+    spk_data = _spikes_test(datadir=datadir)
     assert np.all(np.isclose(outliers, spk_data['motion_outlier']))
 
     outliers_spikes = pd.DataFrame(
@@ -148,30 +147,30 @@ def test_spikes():
             'motion_outlier02': [0, 0, 0, 0, 1],
         }
     )
-    spk_data = _spikes_test(fmt='spikes')
+    spk_data = _spikes_test(fmt='spikes', datadir=datadir)
     assert set(spk_data.columns) == set(outliers_spikes.columns)
     for col in outliers_spikes.columns:
         assert np.all(np.isclose(outliers_spikes[col], spk_data[col]))
 
     lags = [0, 1]
     outliers_lags = [1, 1, 1, 0, 1]
-    spk_data = _spikes_test(lags=lags)
+    spk_data = _spikes_test(lags=lags, datadir=datadir)
     assert np.all(np.isclose(outliers_lags, spk_data['motion_outlier']))
 
     mincontig = 2
     outliers_mc = [1, 1, 1, 1, 1]
-    spk_data = _spikes_test(lags=lags, mincontig=mincontig)
+    spk_data = _spikes_test(lags=lags, mincontig=mincontig, datadir=datadir)
     assert np.all(np.isclose(outliers_mc, spk_data['motion_outlier']))
 
 
-def test_CompCorVariancePlot():
+def test_CompCorVariancePlot(datadir):
     """CompCor variance report test"""
     metadata_file = os.path.join(datadir, 'confounds_metadata_test.tsv')
     cc_rpt = CompCorVariancePlot(metadata_files=[metadata_file], metadata_sources=['aCompCor'])
     _smoke_test_report(cc_rpt, 'compcor_variance.svg')
 
 
-def test_ConfoundsCorrelationPlot():
+def test_ConfoundsCorrelationPlot(datadir):
     """confounds correlation report test"""
     confounds_file = os.path.join(datadir, 'confounds_test.tsv')
     cc_rpt = ConfoundsCorrelationPlot(
@@ -182,7 +181,7 @@ def test_ConfoundsCorrelationPlot():
     _smoke_test_report(cc_rpt, 'confounds_correlation.svg')
 
 
-def test_ConfoundsCorrelationPlotColumns():
+def test_ConfoundsCorrelationPlotColumns(datadir):
     """confounds correlation report test"""
     confounds_file = os.path.join(datadir, 'confounds_test.tsv')
     cc_rpt = ConfoundsCorrelationPlot(
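
One detail worth noting: _expand_test and _spikes_test are plain helper functions, not collected tests, so they cannot request the fixture themselves; each test declares datadir and forwards it, and _spikes_test takes it as a keyword-only argument so call sites stay explicit. A small hypothetical sketch of that pass-through pattern (all names invented for illustration):

# Hypothetical sketch: forwarding a fixture value from a test to a helper.
import os

import pytest


@pytest.fixture
def datadir(tmp_path):
    # Stand-in fixture: fabricate a small data file in a temporary directory.
    (tmp_path / 'spikes_test.tsv').write_text('a\tb\n1\t2\n')
    return str(tmp_path)


def _spikes_helper(fmt='mask', *, datadir):
    # Keyword-only `datadir` forces explicit call sites: _spikes_helper(datadir=datadir).
    return os.path.join(datadir, 'spikes_test.tsv'), fmt


def test_spikes_helper(datadir):
    # Only the test can receive the fixture; it forwards the path to the helper.
    path, fmt = _spikes_helper(datadir=datadir)
    assert os.path.isfile(path)
    assert fmt == 'mask'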

niworkflows/tests/test_registration.py

Lines changed: 6 additions & 6 deletions

@@ -39,7 +39,7 @@
     SpatialNormalizationRPT,
 )
 from ..testing import has_freesurfer, has_fsl
-from .conftest import _run_interface_mock, datadir
+from .conftest import _run_interface_mock
 
 
 def _smoke_test_report(report_interface, artifact_name):
@@ -61,7 +61,7 @@ def test_FLIRTRPT(reference, moving):
 
 
 @pytest.mark.skipif(not has_freesurfer, reason='No FreeSurfer')
-def test_MRICoregRPT(monkeypatch, reference, moving, nthreads):
+def test_MRICoregRPT(monkeypatch, reference, moving, nthreads, datadir):
     """the MRICoreg report capable test"""
 
     def _agg(objekt, runtime):
@@ -120,7 +120,7 @@ def test_FLIRTRPT_w_BBR(reference, reference_mask, moving):
 
 
 @pytest.mark.skipif(not has_freesurfer, reason='No FreeSurfer')
-def test_BBRegisterRPT(monkeypatch, moving):
+def test_BBRegisterRPT(monkeypatch, moving, datadir):
     """the BBRegister report capable test"""
 
     def _agg(objekt, runtime):
@@ -145,7 +145,7 @@ def _agg(objekt, runtime):
     _smoke_test_report(bbregister_rpt, 'testBBRegister.svg')
 
 
-def test_SpatialNormalizationRPT(monkeypatch, moving):
+def test_SpatialNormalizationRPT(monkeypatch, moving, datadir):
     """the SpatialNormalizationRPT report capable test"""
 
     def _agg(objekt, runtime):
@@ -164,7 +164,7 @@ def _agg(objekt, runtime):
     _smoke_test_report(ants_rpt, 'testSpatialNormalizationRPT.svg')
 
 
-def test_SpatialNormalizationRPT_masked(monkeypatch, moving, reference_mask):
+def test_SpatialNormalizationRPT_masked(monkeypatch, moving, reference_mask, datadir):
     """the SpatialNormalizationRPT report capable test with masking"""
 
     def _agg(objekt, runtime):
@@ -188,7 +188,7 @@ def _agg(objekt, runtime):
     _smoke_test_report(ants_rpt, 'testSpatialNormalizationRPT_masked.svg')
 
 
-def test_ANTSRegistrationRPT(monkeypatch, reference, moving):
+def test_ANTSRegistrationRPT(monkeypatch, reference, moving, datadir):
     """the SpatialNormalizationRPT report capable test"""
     from niworkflows import data
 
niworkflows/tests/test_segmentation.py

Lines changed: 3 additions & 3 deletions

@@ -38,7 +38,7 @@
 )
 from ..interfaces.reportlets.segmentation import FASTRPT, ReconAllRPT
 from ..testing import has_freesurfer, has_fsl
-from .conftest import _run_interface_mock, datadir
+from .conftest import _run_interface_mock
 
 
 def _smoke_test_report(report_interface, artifact_name):
@@ -160,7 +160,7 @@ def test_SimpleShowMaskRPT():
     _smoke_test_report(msk_rpt, 'testSimpleMask.svg')
 
 
-def test_BrainExtractionRPT(monkeypatch, moving, nthreads):
+def test_BrainExtractionRPT(monkeypatch, moving, nthreads, datadir):
     """test antsBrainExtraction with reports"""
 
     def _agg(objekt, runtime):
@@ -201,7 +201,7 @@ def _agg(objekt, runtime):
 
 @pytest.mark.skipif(not has_fsl, reason='No FSL')
 @pytest.mark.parametrize('segments', [True, False])
-def test_FASTRPT(monkeypatch, segments, reference, reference_mask):
+def test_FASTRPT(monkeypatch, segments, reference, reference_mask, datadir):
     """test FAST with the two options for segments"""
     from nipype.interfaces.fsl.maths import ApplyMask
 
niworkflows/tests/test_viz.py

Lines changed: 2 additions & 3 deletions

@@ -35,7 +35,6 @@
 from niworkflows.viz.plots import fMRIPlot
 
 from .. import viz
-from .conftest import datadir
 from .generate_data import _create_dtseries_cifti
 
 
@@ -135,7 +134,7 @@ def test_carpetplot(tr, sorting):
         ),
     ],
 )
-def test_fmriplot(input_files):
+def test_fmriplot(input_files, datadir):
    """Exercise the fMRIPlot class."""
     save_artifacts = os.getenv('SAVE_CIRCLE_ARTIFACTS')
     rng = np.random.default_rng(2010)
@@ -249,7 +248,7 @@ def test_plot_melodic_components(tmp_path):
     )
 
 
-def test_compcor_variance_plot(tmp_path):
+def test_compcor_variance_plot(tmp_path, datadir):
     """Test plotting CompCor variance"""
     out_dir = Path(os.getenv('SAVE_CIRCLE_ARTIFACTS', str(tmp_path)))
     out_file = str(out_dir / 'variance_plot_short.svg')
