Skip to content

Commit 37bea41

Browse files
committed
RF: Ensure precomputed mask, aseg match anatomical template
1 parent 83eae58 commit 37bea41

File tree

1 file changed

+104
-2
lines changed

1 file changed

+104
-2
lines changed

nibabies/workflows/anatomical/template.py

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
"""Prepare anatomical images for processing."""
2+
from __future__ import annotations
3+
24
from nipype.interfaces import utility as niu
35
from nipype.pipeline import engine as pe
46
from niworkflows.engine.workflows import LiterateWorkflow
7+
from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms
58

69

710
def init_anat_template_wf(
@@ -12,6 +15,8 @@ def init_anat_template_wf(
1215
longitudinal: bool = False,
1316
bspline_fitting_distance: int = 200,
1417
sloppy: bool = False,
18+
has_mask: bool = False,
19+
has_aseg: bool = False,
1520
name: str = "anat_template_wf",
1621
) -> LiterateWorkflow:
1722
"""
@@ -45,6 +50,11 @@ def init_anat_template_wf(
4550
------
4651
anat_files
4752
List of structural images
53+
anat_mask
54+
mask_reference
55+
anat_aseg
56+
aseg_reference
57+
4858
Outputs
4959
-------
5060
anat_ref
@@ -55,12 +65,15 @@ def init_anat_template_wf(
5565
List of affine transforms to realign input images to final reference
5666
out_report
5767
Conformation report
68+
anat_mask
69+
Mask (if provided), resampled to the anatomical reference
70+
anat_aseg
71+
Aseg (if provided), resampled to the anatomical reference
5872
"""
5973
from nipype.interfaces.ants import N4BiasFieldCorrection
6074
from nipype.interfaces.image import Reorient
6175
from niworkflows.interfaces.freesurfer import PatchedLTAConvert as LTAConvert
6276
from niworkflows.interfaces.freesurfer import StructuralReference
63-
from niworkflows.interfaces.header import ValidateImage
6477
from niworkflows.interfaces.images import Conform, TemplateDimensions
6578
from niworkflows.interfaces.nibabel import IntensityClip
6679
from niworkflows.interfaces.nitransforms import ConcatenateXFMs
@@ -80,7 +93,18 @@ def init_anat_template_wf(
8093
"""
8194

8295
inputnode = pe.Node(
83-
niu.IdentityInterface(fields=["anat_files", "anat_mask", "anat_aseg"]), name="inputnode"
96+
niu.IdentityInterface(
97+
fields=[
98+
"anat_files",
99+
# Each derivative requires a reference file, which will be used to find which
100+
# transform to apply in the case when multiple runs are present
101+
"anat_mask",
102+
"mask_reference",
103+
"anat_aseg",
104+
"aseg_reference",
105+
]
106+
),
107+
name="inputnode",
84108
)
85109
outputnode = pe.Node(
86110
niu.IdentityInterface(
@@ -89,6 +113,8 @@ def init_anat_template_wf(
89113
"anat_valid_list",
90114
"anat_realign_xfm",
91115
"out_report",
116+
"anat_mask",
117+
"anat_aseg",
92118
],
93119
),
94120
name="outputnode",
@@ -110,6 +136,28 @@ def init_anat_template_wf(
110136
])
111137
# fmt:on
112138

139+
if has_mask:
140+
mask_conform = pe.Node(Conform(), name='mask_conform')
141+
# fmt:off
142+
wf.connect([
143+
(inputnode, mask_conform, [('anat_mask', 'in_file')]),
144+
(anat_ref_dimensions, mask_conform, [
145+
('target_zooms', 'target_zooms'),
146+
('target_shape', 'target_shape')]),
147+
])
148+
# fmt:on
149+
150+
if has_aseg:
151+
aseg_conform = pe.Node(Conform(), name='aseg_conform')
152+
# fmt:off
153+
wf.connect([
154+
(inputnode, aseg_conform, [('anat_aseg', 'in_file')]),
155+
(anat_ref_dimensions, aseg_conform, [
156+
('target_zooms', 'target_zooms'),
157+
('target_shape', 'target_shape')]),
158+
])
159+
# fmt:on
160+
113161
if num_files == 1:
114162
get1st = pe.Node(niu.Select(index=[0]), name="get1st")
115163
outputnode.inputs.anat_realign_xfm = [
@@ -122,6 +170,10 @@ def init_anat_template_wf(
122170
(get1st, outputnode, [('out', 'anat_ref')]),
123171
])
124172
# fmt:on
173+
if has_mask:
174+
wf.connect(mask_conform, 'out_file', outputnode, 'anat_mask')
175+
if has_aseg:
176+
wf.connect(aseg_conform, 'out_file', outputnode, 'anat_aseg')
125177
return wf
126178

127179
anat_conform_xfm = pe.MapNode(
@@ -180,6 +232,52 @@ def init_anat_template_wf(
180232
run_without_submitting=True,
181233
)
182234

235+
if has_mask:
236+
mask_ref_idx = pe.Node(
237+
niu.Function(function=get_reference), name='mask_ref_idx', run_without_submitting=True
238+
)
239+
mask_xfm = pe.Node(niu.Select(), name='mask_xfm', run_without_submitting=True)
240+
applyxfm_mask = pe.Node(
241+
ApplyTransforms(interpolation='MultiLabel'), name='applyxfm_mask', mem_gb=1
242+
)
243+
mask_reorient = pe.Node(Reorient(), name="mask_reorient")
244+
# fmt:off
245+
wf.connect([
246+
(inputnode, mask_ref_idx, [('mask_reference', 'anat_reference')]),
247+
(anat_ref_dimensions, mask_ref_idx, [('t1w_valid_list', 'anatomicals')]),
248+
(concat_xfms, mask_xfm, [('out_xfm', 'inlist')]),
249+
(mask_ref_idx, mask_xfm, [('out', 'index')]),
250+
(mask_conform, applyxfm_mask, [('out_file', 'input_image')]),
251+
(anat_reorient, applyxfm_mask, [('out_file', 'reference_image')]),
252+
(mask_xfm, applyxfm_mask, [('out', 'transforms')]),
253+
(applyxfm_mask, mask_reorient, [('output_image', 'in_file')]),
254+
(mask_reorient, outputnode, [('out_file', 'anat_mask')]),
255+
])
256+
# fmt:on
257+
258+
if has_aseg:
259+
aseg_ref_idx = pe.Node(
260+
niu.Function(function=get_reference), name='aseg_ref_idx', run_without_submitting=True
261+
)
262+
aseg_xfm = pe.Node(niu.Select(), name='aseg_xfm', run_without_submitting=True)
263+
applyxfm_aseg = pe.Node(
264+
ApplyTransforms(interpolation='MultiLabel'), name='applyxfm_aseg', mem_gb=1
265+
)
266+
aseg_reorient = pe.Node(Reorient(), name="aseg_reorient")
267+
# fmt:off
268+
wf.connect([
269+
(inputnode, aseg_ref_idx, [('aseg_reference', 'anat_reference')]),
270+
(anat_ref_dimensions, aseg_ref_idx, [('t1w_valid_list', 'anatomicals')]),
271+
(concat_xfms, aseg_xfm, [('out_xfm', 'inlist')]),
272+
(aseg_ref_idx, aseg_xfm, [('out', 'index')]),
273+
(aseg_conform, applyxfm_aseg, [('out_file', 'input_image')]),
274+
(anat_reorient, applyxfm_aseg, [('out_file', 'reference_image')]),
275+
(aseg_xfm, applyxfm_aseg, [('out', 'transforms')]),
276+
(applyxfm_aseg, aseg_reorient, [('output_image', 'in_file')]),
277+
(applyxfm_aseg, outputnode, [('out_file', 'anat_aseg')]),
278+
])
279+
# fmt:on
280+
183281
def _set_threads(in_list, maximum):
184282
return min(len(in_list), maximum)
185283

@@ -204,3 +302,7 @@ def _set_threads(in_list, maximum):
204302
])
205303
# fmt:on
206304
return wf
305+
306+
307+
def get_reference(anatomicals: list, anat_reference: str) -> int:
308+
return anatomicals.index(anat_reference)

0 commit comments

Comments
 (0)