Skip to content

Commit adb5c57

Browse files
committed
RF: Incorporate precomputed mask into coregistration workflow
1 parent 4dc12c5 commit adb5c57

File tree

1 file changed

+75
-40
lines changed

1 file changed

+75
-40
lines changed

nibabies/workflows/anatomical/registration.py

Lines changed: 75 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
22
# vi: set ft=python sts=4 ts=4 sw=4 et:
33
"""Within-baby registration of a T1w into a T2w image."""
4+
from typing import Optional
5+
46
from nipype.interfaces import utility as niu
57
from nipype.pipeline import engine as pe
68
from pkg_resources import resource_filename as pkgr_fn
79

810

911
def init_coregistration_wf(
1012
*,
11-
bspline_fitting_distance=200,
12-
mem_gb=3.0,
13-
name="coregistration_wf",
14-
omp_nthreads=None,
15-
sloppy=False,
16-
debug=False,
13+
bspline_fitting_distance: int = 200,
14+
mem_gb: float = 3.0,
15+
omp_nthreads: Optional[int] = None,
16+
sloppy: bool = False,
17+
debug: bool = False,
18+
precomputed_mask: bool = False,
19+
name: str = "coregistration_wf",
1720
):
1821
"""
1922
Set-up a T2w-to-T1w within-baby co-registration framework.
@@ -49,29 +52,35 @@ def init_coregistration_wf(
4952
Run in *sloppy* mode.
5053
debug : :obj:`bool`
5154
Produce intermediate registration files
55+
precomputed_mask : :obj:`bool`
56+
A precomputed mask for the T1w is available. In this case, generate a
57+
quick mask to assist in coregistration, but use the precomputed mask
58+
as the final output.
5259
5360
5461
Inputs
5562
------
5663
in_t1w : :obj:`str`
57-
The unprocessed input T1w image.
58-
in_t2w_preproc : :obj:`str`
59-
The preprocessed input T2w image, from the brain extraction workflow.
64+
The preprocessed input T1w image (Denoising/INU/Clipping)
65+
in_t2w : :obj:`str`
66+
The preprocessed input T2w image (Denoising/INU/Clipping)
6067
in_mask : :obj:`str`
61-
The brainmask, as obtained in T2w space.
68+
The brainmask.
69+
If `precomputed_mask` is False, will be in T2w space.
70+
If `precomputed_mask` is True, will be in T1w space.
6271
in_probmap : :obj:`str`
6372
The probabilistic brainmask, as obtained in T2w space.
6473
6574
Outputs
6675
-------
67-
t1w_preproc : :obj:`str`
68-
The preprocessed T1w image (INU and clipping).
69-
t2w_preproc : :obj:`str`
76+
t1w_coreg : :obj:`str`
77+
The preprocessed T1w image (INU and clipping), in its native space.
78+
t2w_coreg : :obj:`str`
7079
The preprocessed T2w image (INU and clipping), aligned into the T1w's space.
7180
t1w_brain : :obj:`str`
7281
The preprocessed, brain-extracted T1w image.
7382
t1w_mask : :obj:`str`
74-
The binary brainmask projected from the T2w.
83+
The binary brainmask in T1w space.
7584
t1w2t2w_xfm : :obj:`str`
7685
The T1w-to-T2w mapping.
7786
@@ -81,10 +90,12 @@ def init_coregistration_wf(
8190
from niworkflows.interfaces.fixes import FixHeaderRegistration as Registration
8291
from niworkflows.interfaces.nibabel import ApplyMask, Binarize, BinaryDilation
8392

93+
from nibabies.utils.misc import get_file
94+
8495
workflow = pe.Workflow(name)
8596

8697
inputnode = pe.Node(
87-
niu.IdentityInterface(fields=["in_t1w", "in_t2w_preproc", "in_mask", "in_probmap"]),
98+
niu.IdentityInterface(fields=["in_t1w", "in_t2w", "in_mask", "in_probmap"]),
8899
name="inputnode",
89100
)
90101
outputnode = pe.Node(
@@ -100,15 +111,14 @@ def init_coregistration_wf(
100111
name="outputnode",
101112
)
102113

103-
fixed_masks_arg = pe.Node(niu.Merge(3), name="fixed_masks_arg", run_without_submitting=True)
104-
105114
# Dilate t2w mask for easier t1->t2 registration
115+
fixed_masks_arg = pe.Node(niu.Merge(3), name="fixed_masks_arg", run_without_submitting=True)
106116
reg_mask = pe.Node(BinaryDilation(radius=8, iterations=3), name="reg_mask")
107117
refine_mask = pe.Node(BinaryDilation(radius=8, iterations=1), name="refine_mask")
108118

109-
# Set up T2w -> T1w within-subject registration
119+
# Set up T1w -> T2w within-subject registration
110120
coreg = pe.Node(
111-
Registration(from_file=pkgr_fn("nibabies.data", "within_subject_t1t2.json")),
121+
Registration(from_file=get_file("nibabies", "data/within_subject_t1t2.json")),
112122
name="coreg",
113123
n_procs=omp_nthreads,
114124
mem_gb=mem_gb,
@@ -119,10 +129,6 @@ def init_coregistration_wf(
119129
coreg.inputs.output_inverse_warped_image = sloppy
120130
coreg.inputs.output_warped_image = sloppy
121131

122-
map_mask = pe.Node(ApplyTransforms(interpolation="Gaussian"), name="map_mask", mem_gb=1)
123-
map_t2w = pe.Node(ApplyTransforms(interpolation="BSpline"), name="map_t2w", mem_gb=1)
124-
thr_mask = pe.Node(Binarize(thresh_low=0.80), name="thr_mask")
125-
126132
final_n4 = pe.Node(
127133
N4BiasFieldCorrection(
128134
dimension=3,
@@ -137,40 +143,69 @@ def init_coregistration_wf(
137143
n_procs=omp_nthreads,
138144
name="final_n4",
139145
)
146+
# Move the T2w into T1w space, and apply the mask to the T1w
147+
map_t2w = pe.Node(ApplyTransforms(interpolation="BSpline"), name="map_t2w", mem_gb=1)
140148
apply_mask = pe.Node(ApplyMask(), name="apply_mask")
141149

142150
# fmt: off
143151
workflow.connect([
144-
(inputnode, map_mask, [("in_t1w", "reference_image")]),
145152
(inputnode, final_n4, [("in_t1w", "input_image")]),
146153
(inputnode, coreg, [("in_t1w", "moving_image"),
147-
("in_t2w_preproc", "fixed_image")]),
148-
(inputnode, map_mask, [("in_probmap", "input_image")]),
149-
(inputnode, reg_mask, [("in_mask", "in_file")]),
150-
(inputnode, refine_mask, [("in_mask", "in_file")]),
151-
(reg_mask, fixed_masks_arg, [("out_file", "in1")]),
152-
(reg_mask, fixed_masks_arg, [("out_file", "in2")]),
154+
("in_t2w", "fixed_image")]),
155+
(reg_mask, fixed_masks_arg, [
156+
("out_file", "in1"),
157+
("out_file", "in2")]),
153158
(refine_mask, fixed_masks_arg, [("out_file", "in3")]),
154-
(inputnode, map_t2w, [("in_t1w", "reference_image")]),
155-
(inputnode, map_t2w, [("in_t2w_preproc", "input_image")]),
159+
(inputnode, map_t2w, [
160+
("in_t1w", "reference_image"),
161+
("in_t2w", "input_image")]),
156162
(fixed_masks_arg, coreg, [("out", "fixed_image_masks")]),
157-
(coreg, map_mask, [
158-
("reverse_transforms", "transforms"),
159-
("reverse_invert_flags", "invert_transform_flags"),
160-
]),
161163
(coreg, map_t2w, [
162164
("reverse_transforms", "transforms"),
163165
("reverse_invert_flags", "invert_transform_flags"),
164166
]),
165-
(map_mask, thr_mask, [("output_image", "in_file")]),
166-
(map_mask, final_n4, [("output_image", "weight_image")]),
167167
(final_n4, apply_mask, [("output_image", "in_file")]),
168-
(thr_mask, apply_mask, [("out_mask", "in_mask")]),
169168
(final_n4, outputnode, [("output_image", "t1w_preproc")]),
170169
(map_t2w, outputnode, [("output_image", "t2w_preproc")]),
171-
(thr_mask, outputnode, [("out_mask", "t1w_mask")]),
172170
(apply_mask, outputnode, [("out_file", "t1w_brain")]),
173171
(coreg, outputnode, [("forward_transforms", "t1w2t2w_xfm")]),
174172
])
175173
# fmt: on
174+
175+
if precomputed_mask:
176+
# The input mask is already in T1w space.
177+
# Generate a quick, rough mask of the T2w to be used to facilitate co-registration.
178+
from sdcflows.interfaces.brainmask import BrainExtraction
179+
180+
masker = pe.Node(BrainExtraction(), name="t2w-masker")
181+
# fmt:off
182+
workflow.connect([
183+
(masker, reg_mask, [("out_mask", "in_file")]),
184+
(masker, refine_mask, [("out_mask", "in_file")]),
185+
(inputnode, apply_mask, [("in_mask", "in_mask")]),
186+
(inputnode, outputnode, [("in_mask", "t1w_mask")]),
187+
])
188+
# fmt:on
189+
else:
190+
# The T2w mask from the brain extraction workflow will be mapped to T1w space
191+
map_mask = pe.Node(ApplyTransforms(interpolation="Gaussian"), name="map_mask", mem_gb=1)
192+
thr_mask = pe.Node(Binarize(thresh_low=0.80), name="thr_mask")
193+
# fmt:off
194+
workflow.connect([
195+
(inputnode, map_mask, [
196+
("in_t1w", "reference_image"),
197+
("in_probmap", "input_image")]),
198+
(inputnode, reg_mask, [("in_mask", "in_file")]),
199+
(inputnode, refine_mask, [("in_mask", "in_file")]),
200+
(coreg, map_mask, [
201+
("reverse_transforms", "transforms"),
202+
("reverse_invert_flags", "invert_transform_flags")]),
203+
(map_mask, thr_mask, [("output_image", "in_file")]),
204+
(map_mask, final_n4, [("output_image", "weight_image")]),
205+
(final_n4, apply_mask, [("output_image", "in_file")]),
206+
(final_n4, outputnode, [("output_image", "t1w_preproc")]),
207+
(thr_mask, outputnode, [("out_mask", "t1w_mask")]),
208+
(thr_mask, apply_mask, [("out_mask", "in_mask")]),
209+
])
210+
# fmt:on
176211
return workflow

0 commit comments

Comments
 (0)