Skip to content

Commit b513e98

Browse files
committed
rf: Allow callers to pass a pre-computed epi2anat transform for SyN
1 parent 6627730 commit b513e98

File tree

2 files changed

+66
-29
lines changed

2 files changed

+66
-29
lines changed

sdcflows/workflows/fit/syn.py

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ def init_syn_preprocessing_wf(
342342
debug=False,
343343
name="syn_preprocessing_wf",
344344
omp_nthreads=1,
345+
coregister=True,
345346
auto_bold_nss=False,
346347
t1w_inversion=None,
347348
sd_prior=True,
@@ -365,6 +366,9 @@ def init_syn_preprocessing_wf(
365366
Name for this workflow
366367
omp_nthreads : :obj:`int`
367368
Parallelize internal tasks across the number of CPUs given by this option.
369+
coregister: :class:`bool`
370+
Run BOLD-to-Anat coregistration. If set to ``False``, ``epi2anat_xfm`` must be
371+
provided.
368372
auto_bold_nss : :obj:`bool`
369373
Set up the reference workflow to automatically execute nonsteady states detection
370374
of BOLD images.
@@ -434,6 +438,7 @@ def init_syn_preprocessing_wf(
434438
"in_anat",
435439
"mask_anat",
436440
"std2anat_xfm",
441+
"epi2anat_xfm",
437442
]
438443
),
439444
name="inputnode",
@@ -481,28 +486,44 @@ def init_syn_preprocessing_wf(
481486
DenoiseImage(copy_header=True), name="ref_anat", n_procs=omp_nthreads
482487
)
483488

484-
epi2anat = pe.Node(
485-
Registration(from_file=data.load("affine.json")),
486-
name="epi2anat",
487-
n_procs=omp_nthreads,
488-
)
489-
epi2anat.inputs.output_warped_image = debug
490-
epi2anat.inputs.output_inverse_warped_image = debug
491-
if debug:
492-
epi2anat.inputs.args = "--write-interval-volumes 5"
493-
494-
def _remove_first_mask(in_file):
495-
if not isinstance(in_file, list):
496-
in_file = [in_file]
497-
498-
in_file.insert(0, "NULL")
499-
return in_file
500-
501489
anat_dilmsk = pe.Node(BinaryDilation(), name="anat_dilmsk")
502490
epi_dilmsk = pe.Node(BinaryDilation(), name="epi_dilmsk")
503491

504492
sampling_ref = pe.Node(GenerateSamplingReference(), name="sampling_ref")
505493

494+
if coregister:
495+
epi2anat = pe.Node(
496+
Registration(from_file=data.load("affine.json")),
497+
name="epi2anat",
498+
n_procs=omp_nthreads,
499+
)
500+
epi2anat.inputs.output_warped_image = debug
501+
epi2anat.inputs.output_inverse_warped_image = debug
502+
if debug:
503+
epi2anat.inputs.args = "--write-interval-volumes 5"
504+
505+
def _remove_first_mask(in_file):
506+
if not isinstance(in_file, list):
507+
in_file = [in_file]
508+
509+
in_file.insert(0, "NULL")
510+
return in_file
511+
512+
workflow.connect([
513+
(ref_anat, epi2anat, [("output_image", "fixed_image")]),
514+
(anat_dilmsk, epi2anat, [("out_file", "fixed_image_masks")]),
515+
(deob_epi, epi2anat, [("out_file", "moving_image")]),
516+
(epi_dilmsk, epi2anat, [
517+
(("out_file", _remove_first_mask), "moving_image_masks")]),
518+
(epi2anat, anat2epi, [("forward_transforms", "transforms")]),
519+
(epi2anat, mask2epi, [("forward_transforms", "transforms")]),
520+
]) # fmt:skip
521+
else:
522+
workflow.connect([
523+
(inputnode, anat2epi, [("epi2anat_xfm", "transforms")]),
524+
(inputnode, mask2epi, [("epi2anat_xfm", "transforms")]),
525+
])
526+
506527
if sd_prior:
507528
# Mapping & preparing prior knowledge
508529
# Concatenate transform files:
@@ -528,12 +549,20 @@ def _remove_first_mask(in_file):
528549

529550
workflow.connect([
530551
(inputnode, transform_list, [("std2anat_xfm", "in2")]),
531-
(epi2anat, transform_list, [("forward_transforms", "in1")]),
532552
(transform_list, prior2epi, [("out", "transforms")]),
533553
(sampling_ref, prior2epi, [("out_file", "reference_image")]),
534554
(prior2epi, outputnode, [("output_image", "sd_prior")]),
535555
]) # fmt:skip
536556

557+
if coregister:
558+
workflow.connect([
559+
(epi2anat, transform_list, [("forward_transforms", "in1")]),
560+
]) # fmt:skip
561+
else:
562+
workflow.connect([
563+
(inputnode, transform_list, [("epi2anat_xfm", "in1")]),
564+
])
565+
537566
else:
538567
# no prior to be used
539568
# MG: Future goal is to allow using alternative mappings
@@ -553,16 +582,9 @@ def _remove_first_mask(in_file):
553582
(clip_anat, ref_anat, [("out_file", "input_image")]),
554583
(deob_epi, epi_brain, [("out_file", "in_file")]),
555584
(epi_brain, epi_dilmsk, [("out_mask", "in_file")]),
556-
(ref_anat, epi2anat, [("output_image", "fixed_image")]),
557-
(anat_dilmsk, epi2anat, [("out_file", "fixed_image_masks")]),
558-
(deob_epi, epi2anat, [("out_file", "moving_image")]),
559-
(epi_dilmsk, epi2anat, [
560-
(("out_file", _remove_first_mask), "moving_image_masks")]),
561585
(deob_epi, sampling_ref, [("out_file", "fixed_image")]),
562586
(ref_anat, anat2epi, [("output_image", "input_image")]),
563-
(epi2anat, anat2epi, [("forward_transforms", "transforms")]),
564587
(sampling_ref, anat2epi, [("out_file", "reference_image")]),
565-
(epi2anat, mask2epi, [("forward_transforms", "transforms")]),
566588
(sampling_ref, mask2epi, [("out_file", "reference_image")]),
567589
(mask2epi, mask_dtype, [("output_image", "in_file")]),
568590
(anat2epi, outputnode, [("output_image", "anat_ref")]),

sdcflows/workflows/fit/tests/test_syn.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#
2323
"""Test fieldmap-less SDC-SyN."""
2424
import json
25+
26+
import acres
2527
import pytest
2628
from nipype.pipeline import engine as pe
2729

@@ -30,8 +32,15 @@
3032

3133
@pytest.mark.veryslow
3234
@pytest.mark.slow
33-
@pytest.mark.parametrize("sd_prior", [True, False])
34-
def test_syn_wf(tmpdir, datadir, workdir, outdir, sloppy_mode, sd_prior):
35+
@pytest.mark.parametrize(
36+
("n_bold", "coregister", "sd_prior"),
37+
[
38+
(1, True, True),
39+
# Switch to False once we have a transform in tests/data
40+
(2, True, False),
41+
]
42+
)
43+
def test_syn_wf(tmpdir, datadir, workdir, outdir, sloppy_mode, n_bold, coregister, sd_prior):
3544
"""Build and run an SDC-SyN workflow."""
3645
derivs_path = datadir / "ds000054" / "derivatives"
3746
smriprep = derivs_path / "smriprep-0.6" / "sub-100185" / "anat"
@@ -43,6 +52,7 @@ def test_syn_wf(tmpdir, datadir, workdir, outdir, sloppy_mode, sd_prior):
4352
debug=sloppy_mode,
4453
auto_bold_nss=True,
4554
sd_prior=sd_prior,
55+
coregister=coregister,
4656
)
4757
prep_wf.inputs.inputnode.in_epis = [
4858
str(
@@ -59,10 +69,10 @@ def test_syn_wf(tmpdir, datadir, workdir, outdir, sloppy_mode, sd_prior):
5969
/ "func"
6070
/ "sub-100185_task-machinegame_run-02_bold.nii.gz"
6171
),
62-
]
72+
][:n_bold]
6373
prep_wf.inputs.inputnode.in_meta = [
6474
json.loads((datadir / "ds000054" / "task-machinegame_bold.json").read_text()),
65-
] * 2
75+
] * n_bold
6676
prep_wf.inputs.inputnode.std2anat_xfm = str(
6777
smriprep / "sub-100185_from-MNI152NLin2009cAsym_to-T1w_mode-image_xfm.h5"
6878
)
@@ -72,6 +82,11 @@ def test_syn_wf(tmpdir, datadir, workdir, outdir, sloppy_mode, sd_prior):
7282
prep_wf.inputs.inputnode.mask_anat = str(
7383
smriprep / "sub-100185_desc-brain_mask.nii.gz"
7484
)
85+
if not coregister:
86+
test_data = acres.Loader('sdcflows.tests')
87+
prep_wf.inputs.inputnode.epi_ref = str(
88+
test_data('data/anat2epi_xfm.txt')
89+
)
7590

7691
syn_wf = init_syn_sdc_wf(
7792
debug=sloppy_mode,

0 commit comments

Comments
 (0)