Skip to content

Commit 1fa466a

Browse files
authored
Merge pull request #169 from nipreps/enh/standard-ref-workflow
ENH: Use *NiWorkflows*' EPI-reference workflow
2 parents c5ae564 + 48ac04a commit 1fa466a

File tree

4 files changed

+52
-332
lines changed

4 files changed

+52
-332
lines changed

dmriprep/workflows/dwi/base.py

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,17 @@ def init_dwi_preproc_wf(dwi_file, has_fieldmap=False):
6464
6565
See Also
6666
--------
67-
* :py:func:`~dmriprep.workflows.dwi.util.init_dwi_reference_wf`
6867
* :py:func:`~dmriprep.workflows.dwi.outputs.init_dwi_derivatives_wf`
6968
* :py:func:`~dmriprep.workflows.dwi.outputs.init_reportlets_wf`
7069
7170
"""
7271
from niworkflows.interfaces.reportlets.registration import (
7372
SimpleBeforeAfterRPT as SimpleBeforeAfter,
7473
)
74+
from niworkflows.workflows.epi.refmap import init_epi_reference_wf
75+
from sdcflows.workflows.ancillary import init_brainextraction_wf
76+
7577
from ...interfaces.vectors import CheckGradientTable
76-
from .util import init_dwi_reference_wf
7778
from .outputs import init_dwi_derivatives_wf, init_reportlets_wf
7879
from .eddy import init_eddy_wf
7980

@@ -143,26 +144,40 @@ def init_dwi_preproc_wf(dwi_file, has_fieldmap=False):
143144

144145
gradient_table = pe.Node(CheckGradientTable(), name="gradient_table")
145146

146-
dwi_reference_wf = init_dwi_reference_wf(
147-
mem_gb=config.DEFAULT_MEMORY_MIN_GB, omp_nthreads=config.nipype.omp_nthreads
147+
dwi_reference_wf = init_epi_reference_wf(
148+
omp_nthreads=config.nipype.omp_nthreads,
149+
name="dwi_reference_wf",
150+
)
151+
152+
brainextraction_wf = init_brainextraction_wf()
153+
dwi_derivatives_wf = init_dwi_derivatives_wf(
154+
output_dir=str(config.execution.output_dir)
148155
)
149156

150-
dwi_derivatives_wf = init_dwi_derivatives_wf(output_dir=str(config.execution.output_dir))
157+
# If has_fieldmaps this will hold the corrected reference, original otherwise
158+
buffernode = pe.Node(
159+
niu.IdentityInterface(fields=["dwi_reference", "dwi_mask"]),
160+
name="buffernode",
161+
)
151162

152163
# MAIN WORKFLOW STRUCTURE
153164
# fmt: off
154165
workflow.connect([
166+
(inputnode, dwi_derivatives_wf, [("dwi_file", "inputnode.source_file")]),
155167
(inputnode, gradient_table, [("dwi_file", "dwi_file"),
156168
("in_bvec", "in_bvec"),
157169
("in_bval", "in_bval")]),
158-
(inputnode, dwi_reference_wf, [("dwi_file", "inputnode.dwi_file")]),
159-
(inputnode, dwi_derivatives_wf, [("dwi_file", "inputnode.source_file")]),
160-
(gradient_table, dwi_reference_wf, [("b0_ixs", "inputnode.b0_ixs")]),
161-
(gradient_table, outputnode, [("out_rasb", "gradients_rasb")]),
162-
(outputnode, dwi_derivatives_wf, [
170+
(inputnode, dwi_reference_wf, [(("dwi_file", _aslist), "inputnode.in_files")]),
171+
(dwi_reference_wf, brainextraction_wf, [
172+
("outputnode.epi_ref_file", "inputnode.in_file")]),
173+
(gradient_table, dwi_reference_wf, [(("b0_mask", _aslist), "inputnode.t_masks")]),
174+
(buffernode, dwi_derivatives_wf, [
163175
("dwi_reference", "inputnode.dwi_ref"),
164176
("dwi_mask", "inputnode.dwi_mask"),
165177
]),
178+
(buffernode, outputnode, [("dwi_reference", "dwi_reference"),
179+
("dwi_mask", "dwi_mask")]),
180+
(gradient_table, outputnode, [("out_rasb", "gradients_rasb")]),
166181
])
167182
# fmt: on
168183

@@ -204,12 +219,10 @@ def _bold_reg_suffix(fallback):
204219
("t1w_mask", "in_mask")]),
205220
(inputnode, ds_report_reg, [("dwi_file", "source_file")]),
206221
# BBRegister
207-
(dwi_reference_wf, bbr_wf, [
208-
("outputnode.ref_image", "inputnode.in_file")
209-
]),
222+
(buffernode, bbr_wf, [("dwi_reference", "inputnode.in_file")]),
210223
(bbr_wf, ds_report_reg, [
211-
('outputnode.out_report', 'in_file'),
212-
(('outputnode.fallback', _bold_reg_suffix), 'desc')]),
224+
("outputnode.out_report", "in_file"),
225+
(("outputnode.fallback", _bold_reg_suffix), "desc")]),
213226
])
214227
# fmt: on
215228

@@ -239,17 +252,13 @@ def _bold_reg_suffix(fallback):
239252

240253
# fmt:off
241254
workflow.connect([
242-
(dwi_reference_wf, eddy_wf, [
243-
("outputnode.dwi_file", "inputnode.dwi_file"),
244-
("outputnode.dwi_mask", "inputnode.dwi_mask"),
245-
]),
246-
(inputnode, eddy_wf, [
247-
("in_bvec", "inputnode.in_bvec"),
248-
("in_bval", "inputnode.in_bval")
249-
]),
250-
(dwi_reference_wf, eddy_report, [("outputnode.ref_image", "before")]),
251-
(eddy_wf, eddy_report, [('outputnode.eddy_ref_image', 'after')]),
252-
(dwi_reference_wf, ds_report_eddy, [("outputnode.dwi_file", "source_file")]),
255+
(inputnode, eddy_wf, [("dwi_file", "inputnode.dwi_file"),
256+
("in_bvec", "inputnode.in_bvec"),
257+
("in_bval", "inputnode.in_bval")]),
258+
(inputnode, ds_report_eddy, [("dwi_file", "source_file")]),
259+
(brainextraction_wf, eddy_wf, [("outputnode.out_mask", "inputnode.dwi_mask")]),
260+
(brainextraction_wf, eddy_report, [("outputnode.out_file", "before")]),
261+
(eddy_wf, eddy_report, [("outputnode.eddy_ref_image", "after")]),
253262
(eddy_report, ds_report_eddy, [("out_report", "in_file")]),
254263
])
255264
# fmt:on
@@ -275,8 +284,10 @@ def _bold_reg_suffix(fallback):
275284
if not has_fieldmap:
276285
# fmt: off
277286
workflow.connect([
278-
(dwi_reference_wf, outputnode, [("outputnode.ref_image", "dwi_reference"),
279-
("outputnode.dwi_mask", "dwi_mask")]),
287+
(brainextraction_wf, buffernode, [
288+
("outputnode.out_file", "dwi_reference"),
289+
("outputnode.out_mask", "dwi_mask"),
290+
]),
280291
])
281292
# fmt: on
282293
return workflow
@@ -328,16 +339,15 @@ def _bold_reg_suffix(fallback):
328339
("fmap_coeff", "inputnode.fmap_coeff"),
329340
("fmap_mask", "inputnode.fmap_mask")]),
330341
(dwi_reference_wf, coeff2epi_wf, [
331-
("outputnode.ref_image", "inputnode.target_ref"),
332-
("outputnode.dwi_mask", "inputnode.target_mask")]),
333-
(dwi_reference_wf, unwarp_wf, [("outputnode.ref_image", "inputnode.distorted")]),
342+
("outputnode.epi_ref_file", "inputnode.target_ref")]),
343+
(dwi_reference_wf, unwarp_wf, [("outputnode.epi_ref_file", "inputnode.distorted")]),
334344
(coeff2epi_wf, unwarp_wf, [
335345
("outputnode.fmap_coeff", "inputnode.fmap_coeff")]),
336-
(dwi_reference_wf, sdc_report, [("outputnode.ref_image", "before")]),
346+
(brainextraction_wf, sdc_report, [("outputnode.out_file", "before")]),
337347
(unwarp_wf, sdc_report, [("outputnode.corrected", "after"),
338348
("outputnode.corrected_mask", "wm_seg")]),
339349
(sdc_report, reportlets_wf, [("out_report", "inputnode.sdc_report")]),
340-
(unwarp_wf, outputnode, [("outputnode.corrected", "dwi_reference"),
350+
(unwarp_wf, buffernode, [("outputnode.corrected", "dwi_reference"),
341351
("outputnode.corrected_mask", "dwi_mask")]),
342352
])
343353
# fmt: on
@@ -351,10 +361,10 @@ def _get_wf_name(filename):
351361
352362
Examples
353363
--------
354-
>>> _get_wf_name('/completely/made/up/path/sub-01_dir-AP_acq-64grad_dwi.nii.gz')
364+
>>> _get_wf_name("/completely/made/up/path/sub-01_dir-AP_acq-64grad_dwi.nii.gz")
355365
'dwi_preproc_dir_AP_acq_64grad_wf'
356366
357-
>>> _get_wf_name('/completely/made/up/path/sub-01_dir-RL_run-01_echo-1_dwi.nii.gz')
367+
>>> _get_wf_name("/completely/made/up/path/sub-01_dir-RL_run-01_echo-1_dwi.nii.gz")
358368
'dwi_preproc_dir_RL_run_01_echo_1_wf'
359369
360370
"""
@@ -363,3 +373,7 @@ def _get_wf_name(filename):
363373
fname = Path(filename).name.rpartition(".nii")[0].replace("_dwi", "_wf")
364374
fname_nosub = "_".join(fname.split("_")[1:])
365375
return f"dwi_preproc_{fname_nosub.replace('.', '_').replace(' ', '').replace('-', '_')}"
376+
377+
378+
def _aslist(value):
379+
return [value]

0 commit comments

Comments
 (0)