Skip to content

Commit 2ba3d18

Browse files
committed
ENH: Support either T1w/T2w anatomical derivatives
1 parent 44edcb9 commit 2ba3d18

File tree

1 file changed

+104
-69
lines changed
  • nibabies/workflows/anatomical

1 file changed

+104
-69
lines changed

nibabies/workflows/anatomical/base.py

Lines changed: 104 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Base anatomical preprocessing."""
2-
import warnings
2+
from __future__ import annotations
3+
4+
import typing as ty
35
from pathlib import Path
4-
from typing import Literal, Optional, Union
56

67
from nipype.interfaces import utility as niu
78
from nipype.pipeline import engine as pe
@@ -10,27 +11,30 @@
1011

1112
from ... import config
1213

14+
if ty.TYPE_CHECKING:
15+
from nibabies.utils.bids import Derivatives
16+
1317

1418
def init_infant_anat_wf(
1519
*,
16-
age_months: Optional[int],
20+
age_months: int,
1721
ants_affine_init: bool,
1822
t1w: list,
1923
t2w: list,
2024
anat_modality: str,
21-
bids_root: Optional[Union[str, Path]],
22-
existing_derivatives: dict,
25+
bids_root: str | Path,
26+
derivatives: Derivatives,
2327
freesurfer: bool,
24-
hires: Optional[bool],
28+
hires: bool | None,
2529
longitudinal: bool,
2630
omp_nthreads: int,
27-
output_dir: Union[str, Path],
28-
segmentation_atlases: Optional[Union[str, Path]],
31+
output_dir: str | Path,
32+
segmentation_atlases: str | Path | None,
2933
skull_strip_mode: str,
3034
skull_strip_template: Reference,
3135
sloppy: bool,
32-
spaces: Optional[SpatialReferences],
33-
cifti_output: Optional[Literal['91k', '170k']],
36+
spaces: SpatialReferences | None,
37+
cifti_output: ty.Literal['91k', '170k'] | None,
3438
name: str = "infant_anat_wf",
3539
) -> LiterateWorkflow:
3640
"""
@@ -93,7 +97,7 @@ def init_infant_anat_wf(
9397
init_coreg_report_wf,
9498
)
9599
from .preproc import init_anat_preproc_wf
96-
from .registration import init_coregistration_wf
100+
from .registration import init_coregister_derivatives_wf, init_coregistration_wf
97101
from .segmentation import init_anat_segmentations_wf
98102
from .surfaces import init_anat_ribbon_wf
99103
from .template import init_anat_template_wf
@@ -102,28 +106,9 @@ def init_infant_anat_wf(
102106
num_t1w = len(t1w) if t1w else 0
103107
num_t2w = len(t2w) if t2w else 0
104108

105-
precomp_mask = existing_derivatives.get("anat_mask")
106-
precomp_aseg = existing_derivatives.get("anat_aseg")
107-
108-
# verify derivatives are relatively similar to T1w
109-
if precomp_mask or precomp_aseg:
110-
if num_t1w > 1:
111-
precomp_mask = None
112-
precomp_aseg = None
113-
warnings.warn(
114-
"Multiple T1w files were found; precomputed derivatives will not be used."
115-
)
116-
117-
else:
118-
from ...utils.validation import validate_t1w_derivatives
119-
120-
validated_derivatives = (
121-
validate_t1w_derivatives( # compare derivatives to the first T1w
122-
t1w[0], anat_mask=precomp_mask, anat_aseg=precomp_aseg
123-
)
124-
)
125-
precomp_mask = validated_derivatives.get("anat_mask")
126-
precomp_aseg = validated_derivatives.get("anat_aseg")
109+
# Expected derivatives: Prioritize T1w space if available, otherwise fall back to T2w
110+
deriv_mask = derivatives.mask
111+
deriv_aseg = derivatives.aseg
127112

128113
wf = LiterateWorkflow(name=name)
129114
desc = f"""\n
@@ -186,7 +171,7 @@ def init_infant_anat_wf(
186171

187172
desc += (
188173
"A previously computed mask was used to skull-strip the anatomical image."
189-
if precomp_mask
174+
if deriv_mask
190175
else """\
191176
The T1w-reference was then skull-stripped with a modified implementation of
192177
the `antsBrainExtraction.sh` workflow (from ANTs), using {skullstrip_tpl}
@@ -215,12 +200,19 @@ def init_infant_anat_wf(
215200
)
216201

217202
# Multiple anatomical files -> generate average reference
203+
t1w_mask = bool(derivatives.t1w_mask)
204+
t1w_aseg = bool(derivatives.t1w_aseg)
205+
t2w_mask = bool(derivatives.t2w_mask)
206+
t2w_aseg = bool(derivatives.t2w_aseg)
207+
218208
t1w_template_wf = init_anat_template_wf(
219209
contrast="T1w",
220210
num_files=num_t1w,
221211
longitudinal=longitudinal,
222212
omp_nthreads=omp_nthreads,
223213
sloppy=sloppy,
214+
has_mask=t1w_mask,
215+
has_aseg=t1w_aseg,
224216
name="t1w_template_wf",
225217
)
226218

@@ -230,16 +222,14 @@ def init_infant_anat_wf(
230222
longitudinal=longitudinal,
231223
omp_nthreads=omp_nthreads,
232224
sloppy=sloppy,
225+
has_mask=t2w_mask,
226+
has_aseg=t2w_aseg,
233227
name="t2w_template_wf",
234228
)
235229

236230
# Clean up each anatomical template
237231
# Denoise, INU, + Clipping
238-
t1w_preproc_wf = init_anat_preproc_wf(
239-
precomputed_mask=bool(precomp_mask),
240-
precomputed_aseg=bool(precomp_aseg),
241-
name="t1w_preproc_wf",
242-
)
232+
t1w_preproc_wf = init_anat_preproc_wf(name="t1w_preproc_wf")
243233
t2w_preproc_wf = init_anat_preproc_wf(name="t2w_preproc_wf")
244234

245235
if skull_strip_mode != "force":
@@ -249,7 +239,8 @@ def init_infant_anat_wf(
249239
omp_nthreads=omp_nthreads,
250240
sloppy=sloppy,
251241
debug="registration" in config.execution.debug,
252-
precomputed_mask=bool(precomp_mask),
242+
t1w_mask=t1w_mask,
243+
probmap=not t2w_mask,
253244
)
254245
coreg_report_wf = init_coreg_report_wf(
255246
output_dir=output_dir,
@@ -261,7 +252,7 @@ def init_infant_anat_wf(
261252
template_dir=segmentation_atlases,
262253
sloppy=sloppy,
263254
omp_nthreads=omp_nthreads,
264-
precomp_aseg=precomp_aseg,
255+
precomp_aseg=bool(derivatives.aseg),
265256
)
266257

267258
# Spatial normalization (requires segmentation)
@@ -347,15 +338,41 @@ def init_infant_anat_wf(
347338
]),
348339
])
349340

350-
if precomp_mask:
351-
# Ensure the mask is conformed along with the T1w
352-
t1w_preproc_wf.inputs.inputnode.in_mask = precomp_mask
353-
# fmt:off
354-
wf.connect([
355-
(t1w_preproc_wf, coregistration_wf, [("outputnode.anat_mask", "inputnode.in_mask")]),
356-
(t2w_preproc_wf, coregistration_wf, [("outputnode.anat_preproc", "inputnode.in_t2w")])
357-
])
358-
# fmt:on
341+
# Workflow to move derivatives between T1w/T2w spaces
342+
# May not be used, but define in case necessary.
343+
coreg_deriv_wf = init_coregister_derivatives_wf(
344+
t1w_mask=t1w_mask, t1w_aseg=t1w_aseg, t2w_aseg=t2w_aseg
345+
)
346+
deriv_buffer = pe.Node(
347+
niu.IdentityInterface(fields=['t2w_mask', 't1w_aseg', 't2w_aseg']),
348+
name='deriv_buffer',
349+
)
350+
if derivatives:
351+
wf.connect(
352+
coregistration_wf, 'outputnode.t1w2t2w_xfm', coreg_deriv_wf, 'inputnode.t1w2t2w_xfm'
353+
)
354+
355+
# Derivative mask is present
356+
if derivatives.mask:
357+
if t1w_mask:
358+
t1w_template_wf.inputs.inputnode.anat_mask = derivatives.t1w_mask
359+
t1w_template_wf.inputs.inputnode.mask_reference = derivatives.references['t1w_mask']
360+
# fmt:off
361+
wf.connect([
362+
(t1w_template_wf, coregistration_wf, [('outputnode.anat_mask', 'inputnode.in_mask')]),
363+
(t2w_preproc_wf, coregistration_wf, [('outputnode.anat_preproc', 'inputnode.in_t2w')]),
364+
(t1w_template_wf, coreg_deriv_wf, [('outputnode.anat_mask', 'inputnode.t1w_mask')]),
365+
(coreg_deriv_wf, deriv_buffer, [('outputnode.t2w_mask', 't2w_mask')])
366+
])
367+
# fmt:on
368+
elif t2w_mask:
369+
t2w_template_wf.inputs.inputnode.anat_mask = derivatives.t2w_mask
370+
t2w_template_wf.inputs.inputnode.mask_reference = derivatives.references['t2w_mask']
371+
372+
wf.connect([
373+
(t2w_template_wf, coregistration_wf, [('outputnode.anat_mask', 'inputnode.in_mask')]),
374+
(t2w_template_wf, deriv_buffer, [('outputnode.anat_mask', 't2w_mask')]),
375+
])
359376
else:
360377
# Run brain extraction on the T2w
361378
brain_extraction_wf = init_infant_brain_extraction_wf(
@@ -378,10 +395,30 @@ def init_infant_anat_wf(
378395
])
379396
# fmt:on
380397

381-
if precomp_aseg:
382-
# Ensure the segmentation is conformed along with the T1w
383-
t1w_preproc_wf.inputs.inputnode.in_aseg = precomp_aseg
384-
wf.connect(t1w_preproc_wf, "outputnode.anat_aseg", anat_seg_wf, "inputnode.anat_aseg")
398+
# Derivative segmentation is present
399+
if derivatives.aseg:
400+
wf.connect(deriv_buffer, 't1w_aseg', anat_seg_wf, 'inputnode.anat_aseg')
401+
402+
if t1w_aseg:
403+
t1w_template_wf.inputs.inputnode.anat_aseg = derivatives.t1w_aseg
404+
t1w_template_wf.inputs.inputnode.aseg_reference = derivatives.references['t1w_aseg']
405+
# fmt:off
406+
wf.connect([
407+
(t1w_template_wf, deriv_buffer, [('outputnode.anat_aseg', 't1w_aseg')]),
408+
(t1w_template_wf, coreg_deriv_wf, [('outputnode.anat_aseg', 'inputnode.t1w_aseg')]),
409+
(coreg_deriv_wf, deriv_buffer, [('outputnode.t2w_aseg', 't2w_aseg')]),
410+
])
411+
# fmt:on
412+
elif t2w_aseg:
413+
t2w_template_wf.inputs.inputnode.anat_aseg = derivatives.t2w_aseg
414+
t2w_template_wf.inputs.inputnode.aseg_reference = derivatives.references['t2w_aseg']
415+
# fmt:off
416+
wf.connect([
417+
(t2w_template_wf, deriv_buffer, [('outputnode.anat_aseg', 't2w_aseg')]),
418+
(t2w_template_wf, coreg_deriv_wf, [('outputnode.anat_aseg', 'inputnode.t2w_aseg')]),
419+
(coreg_deriv_wf, deriv_buffer, [('outputnode.t1w_aseg', 't1w_aseg')]),
420+
])
421+
# fmt:on
385422

386423
if not freesurfer:
387424
return wf
@@ -394,7 +431,7 @@ def init_infant_anat_wf(
394431
from .surfaces import init_infantfs_surface_recon_wf
395432

396433
# if running with precomputed aseg, or JLF, pass the aseg along to FreeSurfer
397-
use_aseg = bool(precomp_aseg or segmentation_atlases)
434+
use_aseg = bool(derivatives.aseg or segmentation_atlases)
398435
surface_recon_wf = init_infantfs_surface_recon_wf(
399436
age_months=age_months,
400437
use_aseg=use_aseg,
@@ -405,32 +442,30 @@ def init_infant_anat_wf(
405442

406443
from .surfaces import init_mcribs_sphere_reg_wf, init_mcribs_surface_recon_wf
407444

408-
# Denoise raw T2w, since using the template / preproc resulted in intersection errors
409-
denoise_raw_t2w = pe.Node(
410-
DenoiseImage(dimension=3, noise_model="Rician"), name='denoise_raw_t2w'
445+
# Denoise template T2w, since using the template / preproc resulted in intersection errors
446+
denoise_t2w = pe.Node(
447+
DenoiseImage(dimension=3, noise_model="Rician"), name='denoise_t2w'
411448
)
412-
449+
# t2w mask, t2w aseg
413450
surface_recon_wf = init_mcribs_surface_recon_wf(
414451
omp_nthreads=omp_nthreads,
415-
use_aseg=bool(precomp_aseg),
416-
use_mask=bool(precomp_mask),
452+
use_aseg=bool(derivatives.aseg), # TODO: Incorporate mcribs segmentation
453+
use_mask=bool(derivatives.mask), # TODO: Pass in mask regardless of derivatives
417454
mcribs_dir=str(config.execution.mcribs_dir), # Needed to preserve runs
418455
)
419-
420456
# M-CRIB-S to dHCP42week (32k)
421457
sphere_reg_wf = init_mcribs_sphere_reg_wf()
422458

423-
# Transformed gives
424-
if precomp_aseg:
425-
surface_recon_wf.inputs.inputnode.ants_segs = precomp_aseg
426-
if precomp_mask:
427-
surface_recon_wf.inputs.inputnode.anat_mask = precomp_mask
428459
# fmt:off
429460
wf.connect([
430-
(inputnode, denoise_raw_t2w, [('t2w', 'input_image')]),
431-
(denoise_raw_t2w, surface_recon_wf, [('output_image', 'inputnode.t2w')]),
461+
(t2w_template_wf, denoise_t2w, [('outputnode.anat_ref', 'input_image')]),
462+
(denoise_t2w, surface_recon_wf, [('output_image', 'inputnode.t2w')]),
432463
])
433464
# fmt:on
465+
if derivatives.aseg:
466+
wf.connect(deriv_buffer, 't2w_aseg', surface_recon_wf, 'inputnode.ants_segs')
467+
if derivatives.mask:
468+
wf.connect(deriv_buffer, 't2w_mask', surface_recon_wf, 'inputnode.anat_mask')
434469
else:
435470
raise NotImplementedError
436471

0 commit comments

Comments
 (0)