Skip to content

Commit cdf87b7

Browse files
committed
TEST: Add smoke tests for full workflow and most branching flags
1 parent dfa59db commit cdf87b7

File tree

1 file changed

+166
-5
lines changed

1 file changed

+166
-5
lines changed

fmriprep/workflows/tests/test_base.py

Lines changed: 166 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
from copy import deepcopy
2+
from pathlib import Path
3+
from unittest.mock import patch
24

35
import bids
6+
import nibabel as nb
7+
import numpy as np
8+
import pytest
9+
from nipype.pipeline.engine.utils import generate_expanded_graph
410
from niworkflows.utils.testing import generate_bids_skeleton
511
from sdcflows.fieldmaps import clear_registry
612
from sdcflows.utils.wrangler import find_estimators
713

8-
from ..base import get_estimator
14+
from ... import config
15+
from ..base import get_estimator, init_fmriprep_wf
16+
from ..tests import mock_config
917

1018
BASE_LAYOUT = {
1119
"01": {
@@ -19,7 +27,7 @@
1927
{
2028
"task": "rest",
2129
"run": i,
22-
"suffix": "bold",
30+
"suffix": suffix,
2331
"metadata": {
2432
"RepetitionTime": 2.0,
2533
"PhaseEncodingDirection": "j",
@@ -28,6 +36,7 @@
2836
"SliceTiming": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8],
2937
},
3038
}
39+
for suffix in ("bold", "sbref")
3140
for i in range(1, 3)
3241
),
3342
*(
@@ -64,6 +73,161 @@
6473
}
6574

6675

76+
@pytest.fixture(scope="module", autouse=True)
77+
def _quiet_logger():
78+
import logging
79+
80+
logger = logging.getLogger("nipype.workflow")
81+
old_level = logger.getEffectiveLevel()
82+
logger.setLevel(logging.ERROR)
83+
yield
84+
logger.setLevel(old_level)
85+
86+
87+
@pytest.fixture(autouse=True)
88+
def _reset_sdcflows_registry():
89+
yield
90+
clear_registry()
91+
92+
93+
@pytest.fixture(scope="module")
94+
def bids_root(tmp_path_factory):
95+
base = tmp_path_factory.mktemp("base")
96+
bids_dir = base / "bids"
97+
generate_bids_skeleton(bids_dir, BASE_LAYOUT)
98+
99+
img = nb.Nifti1Image(np.zeros((10, 10, 10, 10)), np.eye(4))
100+
101+
for bold_path in bids_dir.glob('sub-01/*/*.nii.gz'):
102+
img.to_filename(bold_path)
103+
104+
yield bids_dir
105+
106+
107+
def _make_params(
108+
bold2t1w_init: str = "register",
109+
use_bbr: bool | None = None,
110+
dummy_scans: int | None = None,
111+
me_output_echos: bool = False,
112+
medial_surface_nan: bool = False,
113+
project_goodvoxels: bool = False,
114+
cifti_output: bool | str = False,
115+
run_msmsulc: bool = True,
116+
skull_strip_t1w: str = "auto",
117+
use_syn_sdc: str | bool = False,
118+
force_syn: bool = False,
119+
freesurfer: bool = True,
120+
ignore: list[str] = None,
121+
bids_filters: dict = None,
122+
):
123+
if ignore is None:
124+
ignore = []
125+
if bids_filters is None:
126+
bids_filters = {}
127+
return (
128+
bold2t1w_init,
129+
use_bbr,
130+
dummy_scans,
131+
me_output_echos,
132+
medial_surface_nan,
133+
project_goodvoxels,
134+
cifti_output,
135+
run_msmsulc,
136+
skull_strip_t1w,
137+
use_syn_sdc,
138+
force_syn,
139+
freesurfer,
140+
ignore,
141+
bids_filters,
142+
)
143+
144+
145+
@pytest.mark.parametrize("level", ["minimal", "resampling", "full"])
146+
@pytest.mark.parametrize("anat_only", [False, True])
147+
@pytest.mark.parametrize(
148+
(
149+
"bold2t1w_init",
150+
"use_bbr",
151+
"dummy_scans",
152+
"me_output_echos",
153+
"medial_surface_nan",
154+
"project_goodvoxels",
155+
"cifti_output",
156+
"run_msmsulc",
157+
"skull_strip_t1w",
158+
"use_syn_sdc",
159+
"force_syn",
160+
"freesurfer",
161+
"ignore",
162+
"bids_filters",
163+
),
164+
[
165+
_make_params(),
166+
_make_params(bold2t1w_init="header"),
167+
_make_params(use_bbr=True),
168+
_make_params(use_bbr=False),
169+
_make_params(bold2t1w_init="header", use_bbr=True),
170+
# Currently disabled
171+
# _make_params(bold2t1w_init="header", use_bbr=False),
172+
_make_params(dummy_scans=2),
173+
_make_params(me_output_echos=True),
174+
_make_params(medial_surface_nan=True),
175+
_make_params(cifti_output='91k'),
176+
_make_params(cifti_output='91k', project_goodvoxels=True),
177+
_make_params(cifti_output='91k', project_goodvoxels=True, run_msmsulc=False),
178+
_make_params(cifti_output='91k', run_msmsulc=False),
179+
_make_params(skull_strip_t1w='force'),
180+
_make_params(skull_strip_t1w='skip'),
181+
_make_params(use_syn_sdc='warn', force_syn=True, ignore=['fieldmaps']),
182+
_make_params(freesurfer=False),
183+
_make_params(freesurfer=False, use_bbr=True),
184+
_make_params(freesurfer=False, use_bbr=False),
185+
# Currently unsupported:
186+
# _make_params(freesurfer=False, bold2t1w_init="header"),
187+
# _make_params(freesurfer=False, bold2t1w_init="header", use_bbr=True),
188+
# _make_params(freesurfer=False, bold2t1w_init="header", use_bbr=False),
189+
],
190+
)
191+
def test_init_fmriprep_wf(
192+
bids_root: Path,
193+
tmp_path: Path,
194+
level: str,
195+
anat_only: bool,
196+
bold2t1w_init: str,
197+
use_bbr: bool | None,
198+
dummy_scans: int | None,
199+
me_output_echos: bool,
200+
medial_surface_nan: bool,
201+
project_goodvoxels: bool,
202+
cifti_output: bool | str,
203+
run_msmsulc: bool,
204+
skull_strip_t1w: str,
205+
use_syn_sdc: str | bool,
206+
force_syn: bool,
207+
freesurfer: bool,
208+
ignore: list[str],
209+
bids_filters: dict,
210+
):
211+
with mock_config(bids_dir=bids_root):
212+
config.workflow.level = level
213+
config.workflow.anat_only = anat_only
214+
config.workflow.bold2t1w_init = bold2t1w_init
215+
config.workflow.use_bbr = use_bbr
216+
config.workflow.dummy_scans = dummy_scans
217+
config.execution.me_output_echos = me_output_echos
218+
config.workflow.medial_surface_nan = medial_surface_nan
219+
config.workflow.project_goodvoxels = project_goodvoxels
220+
config.workflow.run_msmsulc = run_msmsulc
221+
config.workflow.skull_strip_t1w = skull_strip_t1w
222+
config.workflow.cifti_output = cifti_output
223+
config.workflow.run_reconall = freesurfer
224+
config.workflow.ignore = ignore
225+
with patch.dict('fmriprep.config.execution.bids_filters', bids_filters):
226+
wf = init_fmriprep_wf()
227+
228+
generate_expanded_graph(wf._create_flat_graph())
229+
230+
67231
def test_get_estimator_none(tmp_path):
68232
bids_dir = tmp_path / "bids"
69233

@@ -100,7 +264,6 @@ def test_get_estimator_b0field_and_intendedfor(tmp_path):
100264

101265
assert get_estimator(layout, bold_files[0]) == ('epi',)
102266
assert get_estimator(layout, bold_files[1]) == ('auto_00000',)
103-
clear_registry()
104267

105268

106269
def test_get_estimator_overlapping_specs(tmp_path):
@@ -130,7 +293,6 @@ def test_get_estimator_overlapping_specs(tmp_path):
130293
# B0Fields take precedence
131294
assert get_estimator(layout, bold_files[0]) == ('epi',)
132295
assert get_estimator(layout, bold_files[1]) == ('epi',)
133-
clear_registry()
134296

135297

136298
def test_get_estimator_multiple_b0fields(tmp_path):
@@ -156,4 +318,3 @@ def test_get_estimator_multiple_b0fields(tmp_path):
156318
# Always get an iterable; don't care if it's a list or tuple
157319
assert get_estimator(layout, bold_files[0]) == ['epi', 'phasediff']
158320
assert get_estimator(layout, bold_files[1]) == ('epi',)
159-
clear_registry()

0 commit comments

Comments
 (0)