Skip to content

Commit b7c4e3c

Browse files
authored
Merge pull request #268 from mgxd/enh/runwise-reference-image
ENH: Runwise bold reference generation
2 parents 737e62c + 7718b99 commit b7c4e3c

File tree

4 files changed

+177
-112
lines changed

4 files changed

+177
-112
lines changed

nibabies/workflows/base.py

Lines changed: 32 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -408,80 +408,44 @@ def init_single_subject_wf(subject_id, session_id=None):
408408

409409
# Append the functional section to the existing anatomical exerpt
410410
# That way we do not need to stream down the number of bold datasets
411-
anat_preproc_wf.__postdesc__ = (
412-
(anat_preproc_wf.__postdesc__ if hasattr(anat_preproc_wf, "__postdesc__") else "")
413-
+ f"""
411+
anat_preproc_wf.__postdesc__ = getattr(anat_preproc_wf, '__postdesc__') or ''
412+
func_pre_desc = f"""
414413
415414
Functional data preprocessing
416415
417416
: For each of the {len(subject_data['bold'])} BOLD runs found per subject (across all
418-
tasks and sessions), the following preprocessing was performed.
419-
"""
420-
)
421-
422-
# calculate reference image(s) for BOLD images
423-
# group all BOLD files based on same:
424-
# 1) session
425-
# 2) PE direction
426-
# 3) total readout time
427-
from niworkflows.workflows.epi.refmap import init_epi_reference_wf
428-
429-
bold_groupings = group_bolds_ref(
430-
layout=config.execution.layout,
431-
subject=subject_id,
432-
sessions=[session_id],
433-
)
417+
tasks and sessions), the following preprocessing was performed."""
434418

435419
func_preproc_wfs = []
436420
has_fieldmap = bool(fmap_estimators)
437-
for idx, grouping in enumerate(bold_groupings.values()):
438-
bold_ref_wf = init_epi_reference_wf(
439-
auto_bold_nss=True,
440-
name=f"bold_reference_wf{idx}",
441-
omp_nthreads=config.nipype.omp_nthreads,
442-
)
443-
bold_files = grouping.files
444-
bold_ref_wf.inputs.inputnode.in_files = grouping.files
445-
446-
if grouping.multiecho_id is not None:
447-
bold_files = [bold_files]
448-
for idx, bold_file in enumerate(bold_files):
449-
func_preproc_wf = init_func_preproc_wf(
450-
bold_file,
451-
has_fieldmap=has_fieldmap,
452-
existing_derivatives=derivatives,
453-
)
454-
# fmt: off
455-
workflow.connect([
456-
(bold_ref_wf, func_preproc_wf, [
457-
('outputnode.epi_ref_file', 'inputnode.bold_ref'),
458-
(
459-
('outputnode.xfm_files', _select_iter_idx, idx),
460-
'inputnode.bold_ref_xfm'),
461-
(
462-
('outputnode.n_dummy', _select_iter_idx, idx),
463-
'inputnode.n_dummy_scans'),
464-
]),
465-
(anat_preproc_wf, func_preproc_wf, [
466-
('outputnode.anat_preproc', 'inputnode.anat_preproc'),
467-
('outputnode.anat_mask', 'inputnode.anat_mask'),
468-
('outputnode.anat_brain', 'inputnode.anat_brain'),
469-
('outputnode.anat_dseg', 'inputnode.anat_dseg'),
470-
('outputnode.anat_aseg', 'inputnode.anat_aseg'),
471-
('outputnode.anat_aparc', 'inputnode.anat_aparc'),
472-
('outputnode.anat_tpms', 'inputnode.anat_tpms'),
473-
('outputnode.template', 'inputnode.template'),
474-
('outputnode.anat2std_xfm', 'inputnode.anat2std_xfm'),
475-
('outputnode.std2anat_xfm', 'inputnode.std2anat_xfm'),
476-
# Undefined if --fs-no-reconall, but this is safe
477-
('outputnode.subjects_dir', 'inputnode.subjects_dir'),
478-
('outputnode.subject_id', 'inputnode.subject_id'),
479-
('outputnode.t1w2fsnative_xfm', 'inputnode.t1w2fsnative_xfm'),
480-
('outputnode.fsnative2t1w_xfm', 'inputnode.fsnative2t1w_xfm'),
481-
]),
482-
])
483-
# fmt: on
484-
func_preproc_wfs.append(func_preproc_wf)
421+
for bold_file in subject_data['bold']:
422+
func_preproc_wf = init_func_preproc_wf(bold_file, has_fieldmap=has_fieldmap)
423+
if func_preproc_wf is None:
424+
continue
425+
426+
func_preproc_wf.__desc__ = func_pre_desc + (getattr(func_preproc_wf, '__desc__') or '')
427+
# fmt:off
428+
workflow.connect([
429+
(anat_preproc_wf, func_preproc_wf, [
430+
('outputnode.anat_preproc', 'inputnode.anat_preproc'),
431+
('outputnode.anat_mask', 'inputnode.anat_mask'),
432+
('outputnode.anat_brain', 'inputnode.anat_brain'),
433+
('outputnode.anat_dseg', 'inputnode.anat_dseg'),
434+
('outputnode.anat_aseg', 'inputnode.anat_aseg'),
435+
('outputnode.anat_aparc', 'inputnode.anat_aparc'),
436+
('outputnode.anat_tpms', 'inputnode.anat_tpms'),
437+
('outputnode.template', 'inputnode.template'),
438+
('outputnode.anat2std_xfm', 'inputnode.anat2std_xfm'),
439+
('outputnode.std2anat_xfm', 'inputnode.std2anat_xfm'),
440+
# Undefined if --fs-no-reconall, but this is safe
441+
('outputnode.subjects_dir', 'inputnode.subjects_dir'),
442+
('outputnode.subject_id', 'inputnode.subject_id'),
443+
('outputnode.t1w2fsnative_xfm', 'inputnode.t1w2fsnative_xfm'),
444+
('outputnode.fsnative2t1w_xfm', 'inputnode.fsnative2t1w_xfm'),
445+
]),
446+
])
447+
# fmt:on
448+
func_preproc_wfs.append(func_preproc_wf)
485449

486450
if not has_fieldmap:
487451
config.loggers.workflow.warning(
@@ -506,6 +470,7 @@ def init_single_subject_wf(subject_id, session_id=None):
506470
subject=subject_id,
507471
)
508472
fmap_wf.__desc__ = f"""
473+
509474
Preprocessing of B<sub>0</sub> inhomogeneity mappings
510475
511476
: A total of {len(fmap_estimators)} fieldmaps were found available within the input

nibabies/workflows/bold/base.py

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from ...interfaces.reports import FunctionalSummary
5858
from ...utils.bids import extract_entities
5959
from ...utils.misc import combine_meepi_source
60+
from .boldref import init_infant_epi_reference_wf
6061

6162
# BOLD workflows
6263
from .confounds import init_bold_confs_wf, init_carpetplot_wf
@@ -127,12 +128,6 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
127128
LTA-style affine matrix translating from T1w to FreeSurfer-conformed subject space
128129
fsnative2t1w_xfm
129130
LTA-style affine matrix translating from FreeSurfer-conformed subject space to T1w
130-
bold_ref
131-
BOLD reference file
132-
bold_ref_xfm
133-
Transform file in LTA format from bold to reference
134-
n_dummy_scans
135-
Number of nonsteady states at the beginning of the BOLD run
136131
137132
Outputs
138133
-------
@@ -177,6 +172,7 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
177172
178173
"""
179174
from niworkflows.engine.workflows import LiterateWorkflow as Workflow
175+
from niworkflows.interfaces.bold import NonsteadyStatesDetector
180176
from niworkflows.interfaces.nibabel import ApplyMask
181177
from niworkflows.interfaces.utility import DictMerge, KeySelect
182178
from niworkflows.workflows.epi.refmap import init_epi_reference_wf
@@ -244,9 +240,14 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
244240
)
245241

246242
# Find associated sbref, if possible
247-
entities["suffix"] = "sbref"
248-
entities["extension"] = [".nii", ".nii.gz"] # Overwrite extensions
249-
sbref_files = layout.get(scope="raw", return_type="file", **entities)
243+
overrides = {
244+
"suffix": "sbref",
245+
"extension": [".nii", ".nii.gz"],
246+
}
247+
if config.execution.bids_filters:
248+
overrides.update(config.execution.bids_filters.get('sbref', {}))
249+
sb_ents = {**entities, **overrides}
250+
sbref_files = layout.get(return_type="file", **sb_ents)
250251

251252
sbref_msg = f"No single-band-reference found for {os.path.basename(ref_file)}."
252253
if sbref_files and "sbref" in config.workflow.ignore:
@@ -319,10 +320,6 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
319320
"anat2std_xfm",
320321
"std2anat_xfm",
321322
"template",
322-
# from bold reference workflow
323-
"bold_ref",
324-
"bold_ref_xfm",
325-
"n_dummy_scans",
326323
# from sdcflows (optional)
327324
"fmap",
328325
"fmap_ref",
@@ -514,12 +511,21 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
514511
)
515512
bold_confounds_wf.get_node("inputnode").inputs.t1_transform_flags = [False]
516513

514+
dummy_buffer = pe.Node(niu.IdentityInterface(fields=['n_dummy']), name='dummy_buffer')
515+
if (dummy := config.workflow.dummy_scans) is not None:
516+
dummy_buffer.inputs.n_dummy = dummy
517+
else:
518+
# Detect dummy scans
519+
nss_detector = pe.Node(NonsteadyStatesDetector(), name='nss_detector')
520+
nss_detector.inputs.in_file = ref_file
521+
workflow.connect(nss_detector, 'n_dummy', dummy_buffer, 'n_dummy')
522+
517523
# SLICE-TIME CORRECTION (or bypass) #############################################
518524
if run_stc:
519525
bold_stc_wf = init_bold_stc_wf(name="bold_stc_wf", metadata=metadata)
520526
# fmt:off
521527
workflow.connect([
522-
(inputnode, bold_stc_wf, [('n_dummy_scans', 'inputnode.skip_vols')]),
528+
(dummy_buffer, bold_stc_wf, [('n_dummy', 'inputnode.skip_vols')]),
523529
(select_bold, bold_stc_wf, [("out", 'inputnode.bold_file')]),
524530
(bold_stc_wf, boldbuffer, [('outputnode.stc_file', 'bold_file')]),
525531
])
@@ -577,8 +583,11 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
577583
name="bold_final",
578584
)
579585

580-
# Mask input BOLD reference image
581-
initial_boldref_mask = pe.Node(BrainExtraction(), name="initial_boldref_mask")
586+
# Create a reference image for the bold run
587+
initial_boldref_wf = init_infant_epi_reference_wf(omp_nthreads, is_sbref=bool(sbref_files))
588+
initial_boldref_wf.inputs.inputnode.epi_file = (
589+
pop_file(sbref_files) if sbref_files else ref_file
590+
)
582591

583592
# This final boldref will be calculated after bold_bold_trans_wf, which includes one or more:
584593
# HMC (head motion correction)
@@ -602,8 +611,8 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
602611
# BOLD buffer has slice-time corrected if it was run, original otherwise
603612
(boldbuffer, bold_split, [('bold_file', 'in_file')]),
604613
# HMC
605-
(inputnode, bold_hmc_wf, [
606-
('bold_ref', 'inputnode.raw_ref_image')]),
614+
(initial_boldref_wf, bold_hmc_wf, [
615+
('outputnode.boldref_file', 'inputnode.raw_ref_image')]),
607616
(validate_bolds, bold_hmc_wf, [
608617
(("out_file", pop_file), 'inputnode.bold_file')]),
609618
(bold_hmc_wf, outputnode, [
@@ -659,8 +668,8 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
659668
('outputnode.rmsd_file', 'inputnode.rmsd_file')]),
660669
(bold_reg_wf, bold_confounds_wf, [
661670
('outputnode.itk_t1_to_bold', 'inputnode.t1_bold_xform')]),
662-
(inputnode, bold_confounds_wf, [
663-
('n_dummy_scans', 'inputnode.skip_vols')]),
671+
(dummy_buffer, bold_confounds_wf, [
672+
('n_dummy', 'inputnode.skip_vols')]),
664673
(bold_final, bold_confounds_wf, [
665674
('bold', 'inputnode.bold'),
666675
('mask', 'inputnode.bold_mask'),
@@ -672,7 +681,7 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
672681
('outputnode.tcompcor_mask', 'tcompcor_mask'),
673682
]),
674683
# Summary
675-
(inputnode, summary, [('n_dummy_scans', 'algo_dummy_scans')]),
684+
(dummy_buffer, summary, [('n_dummy', 'algo_dummy_scans')]),
676685
(bold_reg_wf, summary, [('outputnode.fallback', 'fallback')]),
677686
(outputnode, summary, [('confounds', 'confounds_file')]),
678687
# Select echo indices for original/validated BOLD files
@@ -874,8 +883,8 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
874883
('bold_file', 'inputnode.name_source')]),
875884
(bold_hmc_wf, ica_aroma_wf, [
876885
('outputnode.movpar_file', 'inputnode.movpar_file')]),
877-
(inputnode, ica_aroma_wf, [
878-
('n_dummy_scans', 'inputnode.skip_vols')]),
886+
(dummy_buffer, ica_aroma_wf, [
887+
('n_dummy', 'inputnode.skip_vols')]),
879888
(bold_confounds_wf, join, [
880889
('outputnode.confounds_file', 'in_file')]),
881890
(bold_confounds_wf, mrg_conf_metadata,
@@ -1051,9 +1060,8 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
10511060
("outputnode.bold", "inputnode.in_files"),
10521061
]),
10531062
] if not multiecho else [
1054-
(inputnode, initial_boldref_mask, [('bold_ref', 'in_file')]),
1055-
(initial_boldref_mask, bold_t2s_wf, [
1056-
("out_mask", "inputnode.bold_mask"),
1063+
(initial_boldref_wf, bold_t2s_wf, [
1064+
("outputnode.boldref_mask", "inputnode.bold_mask"),
10571065
]),
10581066
(bold_bold_trans_wf, join_echos, [
10591067
("outputnode.bold", "bold_files"),
@@ -1125,14 +1133,13 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False, existing_derivatives=Non
11251133
("fmap_coeff", "inputnode.fmap_coeff"),
11261134
("fmap_mask", "inputnode.fmap_mask")]),
11271135
(output_select, summary, [("sdc_method", "distortion_correction")]),
1128-
(inputnode, initial_boldref_mask, [('bold_ref', 'in_file')]),
1129-
(inputnode, coeff2epi_wf, [
1130-
("bold_ref", "inputnode.target_ref")]),
1131-
(initial_boldref_mask, coeff2epi_wf, [
1132-
("out_mask", "inputnode.target_mask")]), # skull-stripped brain
1136+
(initial_boldref_wf, coeff2epi_wf, [
1137+
("outputnode.boldref_file", "inputnode.target_ref")]),
1138+
(initial_boldref_wf, coeff2epi_wf, [
1139+
("outputnode.boldref_mask", "inputnode.target_mask")]), # skull-stripped brain
11331140
(coeff2epi_wf, unwarp_wf, [
11341141
("outputnode.fmap_coeff", "inputnode.fmap_coeff")]),
1135-
(inputnode, sdc_report, [("bold_ref", "before")]),
1142+
(initial_boldref_wf, sdc_report, [("outputnode.boldref_file", "before")]),
11361143
(bold_hmc_wf, unwarp_wf, [
11371144
("outputnode.xforms", "inputnode.hmc_xforms")]),
11381145
(bold_split, unwarp_wf, [

nibabies/workflows/bold/boldref.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import nipype.interfaces.utility as niu
2+
import nipype.pipeline.engine as pe
3+
4+
5+
def init_infant_epi_reference_wf(
6+
omp_nthreads: int,
7+
is_sbref: bool = False,
8+
start_frame: int = 17,
9+
name: str = 'infant_epi_reference_wf',
10+
) -> pe.Workflow:
11+
"""
12+
Workflow to generate a reference map from one or more infant EPI images.
13+
14+
If any single-band references are provided, the reference map will be calculated from those.
15+
16+
If no single-band references are provided, the BOLD files are used.
17+
To account for potential increased motion on the start of image acquisition, this
18+
workflow discards a bigger chunk of the initial frames.
19+
20+
Parameters
21+
----------
22+
omp_nthreads
23+
Maximum number of threads an individual process may use
24+
has_sbref
25+
A single-band reference is provided.
26+
start_frame
27+
BOLD frame to start creating the reference map from. Any earlier frames are discarded.
28+
29+
Inputs
30+
------
31+
bold_file
32+
BOLD EPI file
33+
sbref_file
34+
single-band reference EPI
35+
36+
Outputs
37+
-------
38+
boldref_file
39+
The generated reference map
40+
boldref_mask
41+
Binary brain mask of the ``boldref_file``
42+
boldref_xfm
43+
Rigid-body transforms in LTA format
44+
45+
"""
46+
from niworkflows.workflows.epi.refmap import init_epi_reference_wf
47+
from sdcflows.interfaces.brainmask import BrainExtraction
48+
49+
wf = pe.Workflow(name=name)
50+
51+
inputnode = pe.Node(
52+
niu.IdentityInterface(fields=['epi_file']),
53+
name='inputnode',
54+
)
55+
outputnode = pe.Node(
56+
niu.IdentityInterface(fields=['boldref_file', 'boldref_mask']),
57+
name='outputnode',
58+
)
59+
60+
epi_reference_wf = init_epi_reference_wf(omp_nthreads)
61+
62+
boldref_mask = pe.Node(BrainExtraction(), name='boldref_mask')
63+
64+
# fmt:off
65+
wf.connect([
66+
(inputnode, epi_reference_wf, [('epi_file', 'inputnode.in_files')]),
67+
(epi_reference_wf, boldref_mask, [('outputnode.epi_ref_file', 'in_file')]),
68+
(epi_reference_wf, outputnode, [('outputnode.epi_ref_file', 'boldref_file')]),
69+
(boldref_mask, outputnode, [('out_mask', 'boldref_mask')]),
70+
])
71+
# fmt:on
72+
if not is_sbref:
73+
select_frames = pe.Node(
74+
niu.Function(function=_select_frames, output_names=['t_masks']),
75+
name='select_frames',
76+
)
77+
select_frames.inputs.start_frame = start_frame
78+
# fmt:off
79+
wf.connect([
80+
(inputnode, select_frames, [('epi_file', 'in_file')]),
81+
(select_frames, epi_reference_wf, [('t_masks', 'inputnode.t_masks')]),
82+
])
83+
# fmt:on
84+
return wf
85+
86+
87+
def _select_frames(in_file: str, start_frame: int) -> list:
88+
import nibabel as nb
89+
import numpy as np
90+
91+
img = nb.load(in_file)
92+
img_len = img.shape[3]
93+
if start_frame >= img_len:
94+
start_frame = img_len - 1
95+
t_mask = np.array([False] * img_len, dtype=bool)
96+
t_mask[start_frame:] = True
97+
return list(t_mask)

0 commit comments

Comments
 (0)