Skip to content

Commit d0ac22c

Browse files
committed
enh: general refactor
- Connections to the particular ``dwi_file`` are done within the particular run's preproc workflow, using ``KeySelect`` to demux the available fieldmap options. This relies on implementing nipreps/sdcflows#147 and nipreps/sdcflows#148. - Minimizes the overhead in ``dmriprep/workflows/base.py`` - Run black
1 parent 2edc4b3 commit d0ac22c

File tree

2 files changed

+121
-81
lines changed

2 files changed

+121
-81
lines changed

dmriprep/workflows/base.py

Lines changed: 66 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ def init_single_subject_wf(subject_id):
287287
return workflow
288288

289289
from .dwi.base import init_dwi_preproc_wf
290+
290291
# Append the dMRI section to the existing anatomical excerpt
291292
# That way we do not need to stream down the number of DWI datasets
292293
anat_preproc_wf.__postdesc__ = (
@@ -300,77 +301,87 @@ def init_single_subject_wf(subject_id):
300301
and a *b=0* average for reference to the subsequent steps of preprocessing was calculated.
301302
"""
302303
)
303-
dwi_preproc_list = []
304-
for dwi_file in subject_data["dwi"]:
305-
dwi_preproc_wf = init_dwi_preproc_wf(dwi_file)
306-
dwi_preproc_list.append(dwi_preproc_wf)
307304

308-
dwi_preproc_list_wf = pe.Node(niu.IdentityInterface(fields=["dwi_workflows"]),
309-
name="dwi_preproc_list_wf")
310-
dwi_preproc_list_wf.iterables = [("dwi_workflows", dwi_preproc_list)]
305+
# SDC Step 0: Determine whether fieldmaps can/should be estimated
306+
fmap_estimators = None
307+
if "fieldmap" not in config.workflow.ignore:
308+
from sdcflows import fieldmaps as fm
309+
from sdcflows.utils.wrangler import find_estimators
310+
from sdcflows.workflows.base import init_fmap_preproc_wf
311311

312-
# fmt: off
313-
workflow.connect([
314-
(anat_preproc_wf, dwi_preproc_list_wf, [
315-
("outputnode.t1w_preproc", "inputnode.t1w_preproc"),
316-
("outputnode.t1w_mask", "inputnode.t1w_mask"),
317-
("outputnode.t1w_dseg", "inputnode.t1w_dseg"),
318-
("outputnode.t1w_aseg", "inputnode.t1w_aseg"),
319-
("outputnode.t1w_aparc", "inputnode.t1w_aparc"),
320-
("outputnode.t1w_tpms", "inputnode.t1w_tpms"),
321-
("outputnode.template", "inputnode.template"),
322-
("outputnode.anat2std_xfm", "inputnode.anat2std_xfm"),
323-
("outputnode.std2anat_xfm", "inputnode.std2anat_xfm"),
324-
# Undefined if --fs-no-reconall, but this is safe
325-
("outputnode.subjects_dir", "inputnode.subjects_dir"),
326-
("outputnode.t1w2fsnative_xfm", "inputnode.t1w2fsnative_xfm"),
327-
("outputnode.fsnative2t1w_xfm", "inputnode.fsnative2t1w_xfm"),
328-
]),
329-
(bids_info, dwi_preproc_list_wf, [("subject", "inputnode.subject_id")]),
330-
])
331-
# fmt: on
312+
# SDC Step 1: Run basic heuristics to identify available data for fieldmap estimation
313+
fmap_estimators = find_estimators(config.execution.layout)
332314

333-
if "fieldmap" in config.workflow.ignore:
334-
return workflow
315+
# Add fieldmap-less estimators
316+
if not fmap_estimators and config.workflow.use_syn:
317+
# estimators = [fm.FieldmapEstimation()]
318+
raise NotImplementedError
335319

336-
from sdcflows import fieldmaps as fm
337-
from sdcflows.utils.wrangler import find_estimators
338-
from sdcflows.workflows.base import init_fmap_preproc_wf
320+
# Nuts and bolts: initialize individual run's pipeline
321+
dwi_preproc_list = []
322+
for dwi_file in subject_data["dwi"]:
323+
dwi_preproc_wf = init_dwi_preproc_wf(
324+
dwi_file,
325+
has_fieldmap=bool(fmap_estimators),
326+
)
339327

340-
# SDCFlows connection
341-
# Step 1: Run basic heuristics to identify available data for fieldmap estimation
342-
estimators = find_estimators(config.execution.layout)
328+
# fmt: off
329+
workflow.connect([
330+
(anat_preproc_wf, dwi_preproc_wf, [
331+
("outputnode.t1w_preproc", "inputnode.t1w_preproc"),
332+
("outputnode.t1w_mask", "inputnode.t1w_mask"),
333+
("outputnode.t1w_dseg", "inputnode.t1w_dseg"),
334+
("outputnode.t1w_aseg", "inputnode.t1w_aseg"),
335+
("outputnode.t1w_aparc", "inputnode.t1w_aparc"),
336+
("outputnode.t1w_tpms", "inputnode.t1w_tpms"),
337+
("outputnode.template", "inputnode.template"),
338+
("outputnode.anat2std_xfm", "inputnode.anat2std_xfm"),
339+
("outputnode.std2anat_xfm", "inputnode.std2anat_xfm"),
340+
# Undefined if --fs-no-reconall, but this is safe
341+
("outputnode.subjects_dir", "inputnode.subjects_dir"),
342+
("outputnode.t1w2fsnative_xfm", "inputnode.t1w2fsnative_xfm"),
343+
("outputnode.fsnative2t1w_xfm", "inputnode.fsnative2t1w_xfm"),
344+
]),
345+
(bids_info, dwi_preproc_wf, [("subject", "inputnode.subject_id")]),
346+
])
347+
# fmt: on
348+
349+
# Keep a handle to each workflow
350+
dwi_preproc_list.append(dwi_preproc_wf)
343351

344-
if not estimators and config.workflow.use_syn: # Add fieldmap-less estimators
345-
# estimators = [fm.FieldmapEstimation()]
346-
raise NotImplementedError
352+
if not fmap_estimators:
353+
return workflow
347354

348-
# Step 2: Manually add further estimators (e.g., fieldmap-less)
355+
# SDC Step 2: Manually add further estimators (e.g., fieldmap-less)
349356
fmap_wf = init_fmap_preproc_wf(
350357
debug=config.execution.debug,
351-
estimators=estimators,
358+
estimators=fmap_estimators,
352359
omp_nthreads=config.nipype.omp_nthreads,
353360
output_dir=str(output_dir),
354361
subject=subject_id,
355362
)
356-
# fmt: off
357-
workflow.connect([
358-
(fmap_wf, dwi_preproc_list_wf, [
359-
("outputnode.fmap", "inputnode.fmap"),
360-
("outputnode.fmap_ref", "inputnode.fmap_ref"),
361-
("outputnode.fmap_coeff", "inputnode.fmap_coeff"),
362-
("outputnode.fmap_mask", "inputnode.fmap_mask"),
363-
]),
364-
])
365-
# fmt: on
363+
364+
# TODO: Requires nipreps/sdcflows#147
365+
for dwi_preproc_wf in dwi_preproc_list:
366+
# fmt: off
367+
workflow.connect([
368+
(fmap_wf, dwi_preproc_wf, [
369+
("outputnode.fmap", "inputnode.fmap"),
370+
("outputnode.fmap_ref", "inputnode.fmap_ref"),
371+
("outputnode.fmap_coeff", "inputnode.fmap_coeff"),
372+
("outputnode.fmap_mask", "inputnode.fmap_mask"),
373+
("outputnode.fmap_id", "inputnode.fmap_id"),
374+
]),
375+
])
376+
# fmt: on
377+
366378
# Overwrite ``out_path_base`` of sdcflows's DataSinks
367379
for node in fmap_wf.list_node_names():
368380
if node.split(".")[-1].startswith("ds_"):
369381
fmap_wf.get_node(node).interface.out_path_base = "dmriprep"
370-
workflow.add_nodes([fmap_wf])
371382

372383
# Step 3: Manually connect PEPOLAR
373-
for estimator in estimators:
384+
for estimator in fmap_estimators:
374385
if estimator.method != fm.EstimatorType.PEPOLAR:
375386
continue
376387

@@ -387,18 +398,18 @@ def init_single_subject_wf(subject_id):
387398
raise NotImplementedError
388399
# from niworkflows.interfaces.utility import KeySelect
389400
# est_id = estimator.bids_id
390-
# fmap_select = pe.MapNode(
401+
# estim_select = pe.MapNode(
391402
# KeySelect(fields=["metadata", "dwi_reference", "dwi_mask", "gradients_rasb",]),
392403
# name=f"fmap_select_{est_id}",
393404
# run_without_submitting=True,
394405
# iterfields=["key"]
395406
# )
396-
# fmap_select.inputs.key = [
407+
# estim_select.inputs.key = [
397408
# str(s.path) for s in estimator.sources if s.suffix in ("epi", "dwi", "sbref")
398409
# ]
399410
# # fmt:off
400411
# workflow.connect([
401-
# (referencenode, fmap_select, [("dwi_file", "keys"),
412+
# (referencenode, estim_select, [("dwi_file", "keys"),
402413
# ("metadata", "metadata"),
403414
# ("dwi_reference", "dwi_reference"),
404415
# ("gradients_rasb", "gradients_rasb")]),

dmriprep/workflows/dwi/base.py

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def init_dwi_preproc_wf(dwi_file, has_fieldmap=False):
8383
"dwi_file",
8484
"in_bvec",
8585
"in_bval",
86-
# From fmap
86+
# From SDCFlows
8787
"fmap",
8888
"fmap_ref",
8989
"fmap_coeff",
@@ -129,10 +129,6 @@ def init_dwi_preproc_wf(dwi_file, has_fieldmap=False):
129129
("in_bval", "in_bval")]),
130130
(inputnode, dwi_reference_wf, [("dwi_file", "inputnode.dwi_file")]),
131131
(gradient_table, dwi_reference_wf, [("b0_ixs", "inputnode.b0_ixs")]),
132-
#outputnode, [
133-
# ("outputnode.ref_image", "dwi_reference"),
134-
# ("outputnode.dwi_mask", "dwi_mask"),
135-
#]),
136132
(gradient_table, outputnode, [("out_rasb", "gradients_rasb")]),
137133
])
138134
# fmt: on
@@ -184,35 +180,68 @@ def _bold_reg_suffix(fallback):
184180
])
185181
# fmt: on
186182

187-
if has_fieldmap:
188-
from sdcflows.workflows.apply.registration import init_coeff2epi_wf
189-
from sdcflows.workflows.apply.correction import init_unwarp_wf
190-
191-
coeff2epi_wf = init_coeff2epi_wf(
192-
omp_nthreads=config.nipype.omp_nthreads, write_coeff=True
193-
)
194-
unwarp_wf = init_unwarp_wf(omp_nthreads=config.nipype.omp_nthreads)
195-
196-
# fmt: off
197-
workflow.connect([
198-
(dwi_reference_wf, coeff2epi_wf, [
199-
("outputnode.ref_image", "inputnode.target_ref"),
200-
("outputnode.dwi_mask", "inputnode.target_mask")]),
201-
(coeff2epi_wf, unwarp_wf, [
202-
("outputnode.fmap_coeff", "inputnode.fmap_coeff")])
203-
])
204-
# fmt: on
205-
206183
# REPORTING ############################################################
207184
reportlets_wf = init_reportlets_wf(str(config.execution.output_dir))
208185
# fmt: off
209186
workflow.connect([
210187
(inputnode, reportlets_wf, [("dwi_file", "inputnode.source_file")]),
211188
(dwi_reference_wf, reportlets_wf, [
212-
("outputnode.ref_image", "inputnode.dwi_ref"),
213-
("outputnode.dwi_mask", "inputnode.dwi_mask"),
214189
("outputnode.validation_report", "inputnode.validation_report"),
215190
]),
191+
(outputnode, reportlets_wf, [
192+
("dwi_reference", "inputnode.dwi_ref"),
193+
("dwi_mask", "inputnode.dwi_mask"),
194+
]),
195+
])
196+
# fmt: on
197+
198+
if not has_fieldmap:
199+
# fmt: off
200+
workflow.connect([
201+
(dwi_reference_wf, outputnode, [("outputnode.ref_image", "dwi_reference"),
202+
("outputnode.dwi_mask", "dwi_mask")]),
203+
])
204+
# fmt: on
205+
return workflow
206+
207+
from niworkflows.interfaces.utility import KeySelect
208+
from sdcflows.workflows.apply.registration import init_coeff2epi_wf
209+
from sdcflows.workflows.apply.correction import init_unwarp_wf
210+
211+
# TODO: Requires nipreps/sdcflows#148
212+
# from sdcflows.utils.fieldmap import get_identifier
213+
214+
coeff2epi_wf = init_coeff2epi_wf(
215+
omp_nthreads=config.nipype.omp_nthreads, write_coeff=True
216+
)
217+
unwarp_wf = init_unwarp_wf(omp_nthreads=config.nipype.omp_nthreads)
218+
unwarp_wf.inputs.inputnode.metadata = layout.get_metadata(dwi_file)
219+
220+
output_select = pe.Node(
221+
KeySelect(fields=["fmap", "fmap_ref", "fmap_coeff", "fmap_mask"]),
222+
name="output_select",
223+
run_without_submitting=True,
224+
)
225+
# output_select.inputs.key = get_identifier(dwi_file)
226+
227+
# fmt: off
228+
workflow.connect([
229+
(inputnode, output_select, [("fmap", "fmap"),
230+
("fmap_ref", "fmap_ref"),
231+
("fmap_coeff", "fmap_coeff"),
232+
("fmap_mask", "fmap_mask"),
233+
("fmap_id", "keys")]),
234+
(output_select, coeff2epi_wf, [
235+
("fmap_ref", "inputnode.fmap_ref"),
236+
("fmap_coeff", "inputnode.fmap_coeff"),
237+
("fmap_mask", "inputnode.fmap_mask")]),
238+
(dwi_reference_wf, coeff2epi_wf, [
239+
("outputnode.ref_image", "inputnode.target_ref"),
240+
("outputnode.dwi_mask", "inputnode.target_mask")]),
241+
(dwi_reference_wf, unwarp_wf, [("outputnode.ref_image", "distorted")]),
242+
(coeff2epi_wf, unwarp_wf, [
243+
("outputnode.fmap_coeff", "inputnode.fmap_coeff")]),
244+
(unwarp_wf, outputnode, [("outputnode.corrected", "dwi_reference")]),
216245
])
217246
# fmt: on
218247

0 commit comments

Comments
 (0)