Skip to content

Commit c476e31

Browse files
authored
rf: Separate fieldmap registration from coreg ref generation (#3467)
The goal of this PR was initially to get SDC reports added to the BOLD section of the reports when reusing minimal derivatives. It ended up splitting fieldmap-to-boldref registration into a coherent section of its own.
2 parents 1f3d29b + b275096 commit c476e31

File tree

1 file changed

+89
-76
lines changed

1 file changed

+89
-76
lines changed

fmriprep/workflows/bold/fit.py

Lines changed: 89 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,7 @@ def init_bold_fit_wf(
349349
)
350350

351351
func_fit_reports_wf = init_func_fit_reports_wf(
352-
# TODO: Enable sdc report even if we find coregref
353-
sdc_correction=not (coreg_boldref or fieldmap_id is None),
352+
sdc_correction=fieldmap_id is not None,
354353
freesurfer=config.workflow.run_reconall,
355354
output_dir=config.execution.fmriprep_dir,
356355
)
@@ -360,6 +359,7 @@ def init_bold_fit_wf(
360359
('boldref', 'hmc_boldref'),
361360
('dummy_scans', 'dummy_scans'),
362361
]),
362+
(hmcref_buffer, fmapref_buffer, [('boldref', 'boldref_files')]),
363363
(regref_buffer, outputnode, [
364364
('boldref', 'coreg_boldref'),
365365
('boldmask', 'bold_mask'),
@@ -462,10 +462,82 @@ def init_bold_fit_wf(
462462
else:
463463
config.loggers.workflow.info('Found motion correction transforms - skipping Stage 2')
464464

465-
# Stage 3: Create coregistration reference
466-
# Fieldmap correction only happens during fit if this stage is needed
465+
# Stage 3: Register fieldmap to boldref and reconstruct in BOLD space
466+
if fieldmap_id:
467+
config.loggers.workflow.info('Stage 3: Adding fieldmap reconstruction workflow')
468+
fmap_select = pe.Node(
469+
KeySelect(
470+
fields=['fmap_ref', 'fmap_coeff', 'fmap_mask', 'sdc_method'],
471+
key=fieldmap_id,
472+
),
473+
name='fmap_select',
474+
run_without_submitting=True,
475+
)
476+
477+
boldref_fmap = pe.Node(ReconstructFieldmap(inverse=[True]), name='boldref_fmap', mem_gb=1)
478+
479+
workflow.connect([
480+
(inputnode, fmap_select, [
481+
('fmap_ref', 'fmap_ref'),
482+
('fmap_coeff', 'fmap_coeff'),
483+
('fmap_mask', 'fmap_mask'),
484+
('sdc_method', 'sdc_method'),
485+
('fmap_id', 'keys'),
486+
]),
487+
(fmapref_buffer, boldref_fmap, [('out', 'target_ref_file')]),
488+
(fmapreg_buffer, boldref_fmap, [('boldref2fmap_xfm', 'transforms')]),
489+
(fmap_select, boldref_fmap, [
490+
('fmap_coeff', 'in_coeffs'),
491+
('fmap_ref', 'fmap_ref_file'),
492+
]),
493+
(fmap_select, func_fit_reports_wf, [('fmap_ref', 'inputnode.fmap_ref')]),
494+
(fmap_select, summary, [('sdc_method', 'distortion_correction')]),
495+
(fmapref_buffer, func_fit_reports_wf, [('out', 'inputnode.sdc_boldref')]),
496+
(fmapreg_buffer, func_fit_reports_wf, [
497+
('boldref2fmap_xfm', 'inputnode.boldref2fmap_xfm'),
498+
]),
499+
(boldref_fmap, func_fit_reports_wf, [('out_file', 'inputnode.fieldmap')]),
500+
]) # fmt:skip
501+
502+
if not boldref2fmap_xform:
503+
config.loggers.workflow.info('Stage 3: Registering fieldmap to boldref')
504+
fmapreg_wf = init_coeff2epi_wf(
505+
debug='fieldmaps' in config.execution.debug,
506+
omp_nthreads=config.nipype.omp_nthreads,
507+
sloppy=config.execution.sloppy,
508+
name='fmapreg_wf',
509+
)
510+
511+
itk_mat2txt = pe.Node(ConcatenateXFMs(out_fmt='itk'), name='itk_mat2txt')
512+
513+
ds_fmapreg_wf = init_ds_registration_wf(
514+
bids_root=layout.root,
515+
output_dir=config.execution.fmriprep_dir,
516+
source='boldref',
517+
dest=fieldmap_id.replace('_', ''),
518+
name='ds_fmapreg_wf',
519+
)
520+
ds_fmapreg_wf.inputs.inputnode.source_files = [bold_file]
521+
522+
workflow.connect([
523+
(fmap_select, fmapreg_wf, [
524+
('fmap_ref', 'inputnode.fmap_ref'),
525+
('fmap_mask', 'inputnode.fmap_mask'),
526+
]),
527+
(fmapreg_wf, itk_mat2txt, [('outputnode.target2fmap_xfm', 'in_xfms')]),
528+
(itk_mat2txt, ds_fmapreg_wf, [('out_xfm', 'inputnode.xform')]),
529+
(ds_fmapreg_wf, fmapreg_buffer, [('outputnode.xform', 'boldref2fmap_xfm')]),
530+
]) # fmt:skip
531+
else:
532+
config.loggers.workflow.info(
533+
'Stage 3: Found fieldmap transform - skipping registration'
534+
)
535+
else:
536+
config.loggers.workflow.info('No fieldmap correction - skipping Stage 3')
537+
538+
# Stage 4: Create coregistration reference
467539
if not coreg_boldref:
468-
config.loggers.workflow.info('Stage 3: Adding coregistration boldref workflow')
540+
config.loggers.workflow.info('Stage 4: Adding coregistration boldref workflow')
469541

470542
# Select initial boldref, enhance contrast, and generate mask
471543
if sbref_files and nb.load(sbref_files[0]).ndim > 3:
@@ -492,63 +564,15 @@ def init_bold_fit_wf(
492564
ds_boldmask_wf.inputs.inputnode.source_files = [bold_file]
493565

494566
workflow.connect([
495-
(hmcref_buffer, fmapref_buffer, [('boldref', 'boldref_files')]),
496567
(fmapref_buffer, enhance_boldref_wf, [('out', 'inputnode.in_file')]),
497568
(hmc_boldref_source_buffer, ds_coreg_boldref_wf, [
498569
('in_file', 'inputnode.source_files'),
499570
]),
500571
(ds_coreg_boldref_wf, regref_buffer, [('outputnode.boldref', 'boldref')]),
501572
(ds_boldmask_wf, regref_buffer, [('outputnode.boldmask', 'boldmask')]),
502-
(fmapref_buffer, func_fit_reports_wf, [('out', 'inputnode.sdc_boldref')]),
503573
]) # fmt:skip
504574

505575
if fieldmap_id:
506-
fmap_select = pe.Node(
507-
KeySelect(
508-
fields=['fmap_ref', 'fmap_coeff', 'fmap_mask', 'sdc_method'],
509-
key=fieldmap_id,
510-
),
511-
name='fmap_select',
512-
run_without_submitting=True,
513-
)
514-
515-
if not boldref2fmap_xform:
516-
fmapreg_wf = init_coeff2epi_wf(
517-
debug='fieldmaps' in config.execution.debug,
518-
omp_nthreads=config.nipype.omp_nthreads,
519-
sloppy=config.execution.sloppy,
520-
name='fmapreg_wf',
521-
)
522-
523-
itk_mat2txt = pe.Node(ConcatenateXFMs(out_fmt='itk'), name='itk_mat2txt')
524-
525-
ds_fmapreg_wf = init_ds_registration_wf(
526-
bids_root=layout.root,
527-
output_dir=config.execution.fmriprep_dir,
528-
source='boldref',
529-
dest=fieldmap_id.replace('_', ''),
530-
name='ds_fmapreg_wf',
531-
)
532-
ds_fmapreg_wf.inputs.inputnode.source_files = [bold_file]
533-
534-
workflow.connect([
535-
(enhance_boldref_wf, fmapreg_wf, [
536-
('outputnode.bias_corrected_file', 'inputnode.target_ref'),
537-
('outputnode.mask_file', 'inputnode.target_mask'),
538-
]),
539-
(fmap_select, fmapreg_wf, [
540-
('fmap_ref', 'inputnode.fmap_ref'),
541-
('fmap_mask', 'inputnode.fmap_mask'),
542-
]),
543-
(fmapreg_wf, itk_mat2txt, [('outputnode.target2fmap_xfm', 'in_xfms')]),
544-
(itk_mat2txt, ds_fmapreg_wf, [('out_xfm', 'inputnode.xform')]),
545-
(ds_fmapreg_wf, fmapreg_buffer, [('outputnode.xform', 'boldref2fmap_xfm')]),
546-
]) # fmt:skip
547-
548-
boldref_fmap = pe.Node(
549-
ReconstructFieldmap(inverse=[True]), name='boldref_fmap', mem_gb=1
550-
)
551-
552576
distortion_params = pe.Node(
553577
DistortionParameters(
554578
metadata=metadata,
@@ -569,19 +593,6 @@ def init_bold_fit_wf(
569593
skullstrip_bold_wf = init_skullstrip_bold_wf()
570594

571595
workflow.connect([
572-
(inputnode, fmap_select, [
573-
('fmap_ref', 'fmap_ref'),
574-
('fmap_coeff', 'fmap_coeff'),
575-
('fmap_mask', 'fmap_mask'),
576-
('sdc_method', 'sdc_method'),
577-
('fmap_id', 'keys'),
578-
]),
579-
(fmapref_buffer, boldref_fmap, [('out', 'target_ref_file')]),
580-
(fmapreg_buffer, boldref_fmap, [('boldref2fmap_xfm', 'transforms')]),
581-
(fmap_select, boldref_fmap, [
582-
('fmap_coeff', 'in_coeffs'),
583-
('fmap_ref', 'fmap_ref_file'),
584-
]),
585596
(fmapref_buffer, unwarp_boldref, [('out', 'ref_file')]),
586597
(enhance_boldref_wf, unwarp_boldref, [
587598
('outputnode.bias_corrected_file', 'in_file'),
@@ -600,13 +611,15 @@ def init_bold_fit_wf(
600611
(skullstrip_bold_wf, ds_boldmask_wf, [
601612
('outputnode.mask_file', 'inputnode.boldmask'),
602613
]),
603-
(fmap_select, func_fit_reports_wf, [('fmap_ref', 'inputnode.fmap_ref')]),
604-
(fmap_select, summary, [('sdc_method', 'distortion_correction')]),
605-
(fmapreg_buffer, func_fit_reports_wf, [
606-
('boldref2fmap_xfm', 'inputnode.boldref2fmap_xfm'),
607-
]),
608-
(boldref_fmap, func_fit_reports_wf, [('out_file', 'inputnode.fieldmap')]),
609614
]) # fmt:skip
615+
616+
if not boldref2fmap_xform:
617+
workflow.connect([
618+
(enhance_boldref_wf, fmapreg_wf, [
619+
('outputnode.bias_corrected_file', 'inputnode.target_ref'),
620+
('outputnode.mask_file', 'inputnode.target_mask'),
621+
]),
622+
]) # fmt:skip
610623
else:
611624
workflow.connect([
612625
(enhance_boldref_wf, ds_coreg_boldref_wf, [
@@ -617,7 +630,7 @@ def init_bold_fit_wf(
617630
]),
618631
]) # fmt:skip
619632
else:
620-
config.loggers.workflow.info('Found coregistration reference - skipping Stage 3')
633+
config.loggers.workflow.info('Found coregistration reference - skipping Stage 4')
621634

622635
# TODO: Allow precomputed bold masks to be passed
623636
# Also needs consideration for how it interacts above
@@ -628,6 +641,7 @@ def init_bold_fit_wf(
628641
]) # fmt:skip
629642

630643
if not boldref2anat_xform:
644+
config.loggers.workflow.info('Stage 5: Adding coregistration workflow')
631645
use_bbr = (
632646
True
633647
if 'bbr' in config.workflow.force
@@ -654,7 +668,6 @@ def init_bold_fit_wf(
654668
name='ds_boldreg_wf',
655669
)
656670

657-
# fmt:off
658671
workflow.connect([
659672
(inputnode, bold_reg_wf, [
660673
('t1w_preproc', 'inputnode.t1w_preproc'),
@@ -671,9 +684,9 @@ def init_bold_fit_wf(
671684
(bold_reg_wf, ds_boldreg_wf, [('outputnode.itk_bold_to_t1', 'inputnode.xform')]),
672685
(ds_boldreg_wf, outputnode, [('outputnode.xform', 'boldref2anat_xfm')]),
673686
(bold_reg_wf, summary, [('outputnode.fallback', 'fallback')]),
674-
])
675-
# fmt:on
687+
]) # fmt:skip
676688
else:
689+
config.loggers.workflow.info('Found coregistration transform - skipping Stage 5')
677690
outputnode.inputs.boldref2anat_xfm = boldref2anat_xform
678691

679692
return workflow

0 commit comments

Comments
 (0)