Skip to content

Commit 6e0b9ae

Browse files
committed
rf: Allow callers to pass a pre-computed epi2anat transform for SyN
1 parent 5935342 commit 6e0b9ae

File tree

2 files changed

+66
-29
lines changed

2 files changed

+66
-29
lines changed

sdcflows/workflows/fit/syn.py

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ def init_syn_preprocessing_wf(
325325
debug=False,
326326
name="syn_preprocessing_wf",
327327
omp_nthreads=1,
328+
coregister=True,
328329
auto_bold_nss=False,
329330
t1w_inversion=None,
330331
sd_prior=True,
@@ -348,6 +349,9 @@ def init_syn_preprocessing_wf(
348349
Name for this workflow
349350
omp_nthreads : :obj:`int`
350351
Parallelize internal tasks across the number of CPUs given by this option.
352+
coregister: :class:`bool`
353+
Run BOLD-to-Anat coregistration. If set to ``False``, ``epi2anat_xfm`` must be
354+
provided.
351355
auto_bold_nss : :obj:`bool`
352356
Set up the reference workflow to automatically execute nonsteady states detection
353357
of BOLD images.
@@ -417,6 +421,7 @@ def init_syn_preprocessing_wf(
417421
"in_anat",
418422
"mask_anat",
419423
"std2anat_xfm",
424+
"epi2anat_xfm",
420425
]
421426
),
422427
name="inputnode",
@@ -464,28 +469,44 @@ def init_syn_preprocessing_wf(
464469
DenoiseImage(copy_header=True), name="ref_anat", n_procs=omp_nthreads
465470
)
466471

467-
epi2anat = pe.Node(
468-
Registration(from_file=data.load("affine.json")),
469-
name="epi2anat",
470-
n_procs=omp_nthreads,
471-
)
472-
epi2anat.inputs.output_warped_image = debug
473-
epi2anat.inputs.output_inverse_warped_image = debug
474-
if debug:
475-
epi2anat.inputs.args = "--write-interval-volumes 5"
476-
477-
def _remove_first_mask(in_file):
478-
if not isinstance(in_file, list):
479-
in_file = [in_file]
480-
481-
in_file.insert(0, "NULL")
482-
return in_file
483-
484472
anat_dilmsk = pe.Node(BinaryDilation(), name="anat_dilmsk")
485473
epi_dilmsk = pe.Node(BinaryDilation(), name="epi_dilmsk")
486474

487475
sampling_ref = pe.Node(GenerateSamplingReference(), name="sampling_ref")
488476

477+
if coregister:
478+
epi2anat = pe.Node(
479+
Registration(from_file=data.load("affine.json")),
480+
name="epi2anat",
481+
n_procs=omp_nthreads,
482+
)
483+
epi2anat.inputs.output_warped_image = debug
484+
epi2anat.inputs.output_inverse_warped_image = debug
485+
if debug:
486+
epi2anat.inputs.args = "--write-interval-volumes 5"
487+
488+
def _remove_first_mask(in_file):
489+
if not isinstance(in_file, list):
490+
in_file = [in_file]
491+
492+
in_file.insert(0, "NULL")
493+
return in_file
494+
495+
workflow.connect([
496+
(ref_anat, epi2anat, [("output_image", "fixed_image")]),
497+
(anat_dilmsk, epi2anat, [("out_file", "fixed_image_masks")]),
498+
(deob_epi, epi2anat, [("out_file", "moving_image")]),
499+
(epi_dilmsk, epi2anat, [
500+
(("out_file", _remove_first_mask), "moving_image_masks")]),
501+
(epi2anat, anat2epi, [("forward_transforms", "transforms")]),
502+
(epi2anat, mask2epi, [("forward_transforms", "transforms")]),
503+
]) # fmt:skip
504+
else:
505+
workflow.connect([
506+
(inputnode, anat2epi, [("epi2anat_xfm", "transforms")]),
507+
(inputnode, mask2epi, [("epi2anat_xfm", "transforms")]),
508+
])
509+
489510
if sd_prior:
490511
# Mapping & preparing prior knowledge
491512
# Concatenate transform files:
@@ -511,12 +532,20 @@ def _remove_first_mask(in_file):
511532

512533
workflow.connect([
513534
(inputnode, transform_list, [("std2anat_xfm", "in2")]),
514-
(epi2anat, transform_list, [("forward_transforms", "in1")]),
515535
(transform_list, prior2epi, [("out", "transforms")]),
516536
(sampling_ref, prior2epi, [("out_file", "reference_image")]),
517537
(prior2epi, outputnode, [("output_image", "sd_prior")]),
518538
]) # fmt:skip
519539

540+
if coregister:
541+
workflow.connect([
542+
(epi2anat, transform_list, [("forward_transforms", "in1")]),
543+
]) # fmt:skip
544+
else:
545+
workflow.connect([
546+
(inputnode, transform_list, [("epi2anat_xfm", "in1")]),
547+
])
548+
520549
else:
521550
# no prior to be used
522551
# MG: Future goal is to allow using alternative mappings
@@ -536,16 +565,9 @@ def _remove_first_mask(in_file):
536565
(clip_anat, ref_anat, [("out_file", "input_image")]),
537566
(deob_epi, epi_brain, [("out_file", "in_file")]),
538567
(epi_brain, epi_dilmsk, [("out_mask", "in_file")]),
539-
(ref_anat, epi2anat, [("output_image", "fixed_image")]),
540-
(anat_dilmsk, epi2anat, [("out_file", "fixed_image_masks")]),
541-
(deob_epi, epi2anat, [("out_file", "moving_image")]),
542-
(epi_dilmsk, epi2anat, [
543-
(("out_file", _remove_first_mask), "moving_image_masks")]),
544568
(deob_epi, sampling_ref, [("out_file", "fixed_image")]),
545569
(ref_anat, anat2epi, [("output_image", "input_image")]),
546-
(epi2anat, anat2epi, [("forward_transforms", "transforms")]),
547570
(sampling_ref, anat2epi, [("out_file", "reference_image")]),
548-
(epi2anat, mask2epi, [("forward_transforms", "transforms")]),
549571
(sampling_ref, mask2epi, [("out_file", "reference_image")]),
550572
(mask2epi, mask_dtype, [("output_image", "in_file")]),
551573
(anat2epi, outputnode, [("output_image", "anat_ref")]),

sdcflows/workflows/fit/tests/test_syn.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#
2323
"""Test fieldmap-less SDC-SyN."""
2424
import json
25+
26+
import acres
2527
import pytest
2628
from nipype.pipeline import engine as pe
2729

@@ -30,8 +32,15 @@
3032

3133
@pytest.mark.veryslow
3234
@pytest.mark.slow
33-
@pytest.mark.parametrize("sd_prior", [True, False])
34-
def test_syn_wf(tmpdir, datadir, workdir, outdir, sloppy_mode, sd_prior):
35+
@pytest.mark.parametrize(
36+
("n_bold", "coregister", "sd_prior"),
37+
[
38+
(1, True, True),
39+
# Switch to False once we have a transform in tests/data
40+
(2, True, False),
41+
]
42+
)
43+
def test_syn_wf(tmpdir, datadir, workdir, outdir, sloppy_mode, n_bold, coregister, sd_prior):
3544
"""Build and run an SDC-SyN workflow."""
3645
derivs_path = datadir / "ds000054" / "derivatives"
3746
smriprep = derivs_path / "smriprep-0.6" / "sub-100185" / "anat"
@@ -43,6 +52,7 @@ def test_syn_wf(tmpdir, datadir, workdir, outdir, sloppy_mode, sd_prior):
4352
debug=sloppy_mode,
4453
auto_bold_nss=True,
4554
sd_prior=sd_prior,
55+
coregister=coregister,
4656
)
4757
prep_wf.inputs.inputnode.in_epis = [
4858
str(
@@ -59,10 +69,10 @@ def test_syn_wf(tmpdir, datadir, workdir, outdir, sloppy_mode, sd_prior):
5969
/ "func"
6070
/ "sub-100185_task-machinegame_run-02_bold.nii.gz"
6171
),
62-
]
72+
][:n_bold]
6373
prep_wf.inputs.inputnode.in_meta = [
6474
json.loads((datadir / "ds000054" / "task-machinegame_bold.json").read_text()),
65-
] * 2
75+
] * n_bold
6676
prep_wf.inputs.inputnode.std2anat_xfm = str(
6777
smriprep / "sub-100185_from-MNI152NLin2009cAsym_to-T1w_mode-image_xfm.h5"
6878
)
@@ -72,6 +82,11 @@ def test_syn_wf(tmpdir, datadir, workdir, outdir, sloppy_mode, sd_prior):
7282
prep_wf.inputs.inputnode.mask_anat = str(
7383
smriprep / "sub-100185_desc-brain_mask.nii.gz"
7484
)
85+
if not coregister:
86+
test_data = acres.Loader('sdcflows.tests')
87+
prep_wf.inputs.inputnode.epi_ref = str(
88+
test_data('data/anat2epi_xfm.txt')
89+
)
7590

7691
syn_wf = init_syn_sdc_wf(
7792
debug=sloppy_mode,

0 commit comments

Comments
 (0)