1
1
"""Base anatomical preprocessing."""
2
- import warnings
2
+ from __future__ import annotations
3
+
4
+ import typing as ty
3
5
from pathlib import Path
4
- from typing import Literal , Optional , Union
5
6
6
7
from nipype .interfaces import utility as niu
7
8
from nipype .pipeline import engine as pe
10
11
11
12
from ... import config
12
13
14
+ if ty .TYPE_CHECKING :
15
+ from nibabies .utils .bids import Derivatives
16
+
13
17
14
18
def init_infant_anat_wf (
15
19
* ,
16
- age_months : Optional [ int ] ,
20
+ age_months : int ,
17
21
ants_affine_init : bool ,
18
22
t1w : list ,
19
23
t2w : list ,
20
24
anat_modality : str ,
21
- bids_root : Optional [ Union [ str , Path ]] ,
22
- existing_derivatives : dict ,
25
+ bids_root : str | Path ,
26
+ derivatives : Derivatives ,
23
27
freesurfer : bool ,
24
- hires : Optional [ bool ] ,
28
+ hires : bool | None ,
25
29
longitudinal : bool ,
26
30
omp_nthreads : int ,
27
- output_dir : Union [ str , Path ] ,
28
- segmentation_atlases : Optional [ Union [ str , Path ]] ,
31
+ output_dir : str | Path ,
32
+ segmentation_atlases : str | Path | None ,
29
33
skull_strip_mode : str ,
30
34
skull_strip_template : Reference ,
31
35
sloppy : bool ,
32
- spaces : Optional [ SpatialReferences ] ,
33
- cifti_output : Optional [ Literal ['91k' , '170k' ]] ,
36
+ spaces : SpatialReferences | None ,
37
+ cifti_output : ty . Literal ['91k' , '170k' ] | None ,
34
38
name : str = "infant_anat_wf" ,
35
39
) -> LiterateWorkflow :
36
40
"""
@@ -93,7 +97,7 @@ def init_infant_anat_wf(
93
97
init_coreg_report_wf ,
94
98
)
95
99
from .preproc import init_anat_preproc_wf
96
- from .registration import init_coregistration_wf
100
+ from .registration import init_coregister_derivatives_wf , init_coregistration_wf
97
101
from .segmentation import init_anat_segmentations_wf
98
102
from .surfaces import init_anat_ribbon_wf
99
103
from .template import init_anat_template_wf
@@ -102,28 +106,9 @@ def init_infant_anat_wf(
102
106
num_t1w = len (t1w ) if t1w else 0
103
107
num_t2w = len (t2w ) if t2w else 0
104
108
105
- precomp_mask = existing_derivatives .get ("anat_mask" )
106
- precomp_aseg = existing_derivatives .get ("anat_aseg" )
107
-
108
- # verify derivatives are relatively similar to T1w
109
- if precomp_mask or precomp_aseg :
110
- if num_t1w > 1 :
111
- precomp_mask = None
112
- precomp_aseg = None
113
- warnings .warn (
114
- "Multiple T1w files were found; precomputed derivatives will not be used."
115
- )
116
-
117
- else :
118
- from ...utils .validation import validate_t1w_derivatives
119
-
120
- validated_derivatives = (
121
- validate_t1w_derivatives ( # compare derivatives to the first T1w
122
- t1w [0 ], anat_mask = precomp_mask , anat_aseg = precomp_aseg
123
- )
124
- )
125
- precomp_mask = validated_derivatives .get ("anat_mask" )
126
- precomp_aseg = validated_derivatives .get ("anat_aseg" )
109
+ # Expected derivatives: Prioritize T1w space if available, otherwise fall back to T2w
110
+ deriv_mask = derivatives .mask
111
+ deriv_aseg = derivatives .aseg
127
112
128
113
wf = LiterateWorkflow (name = name )
129
114
desc = f"""\n
@@ -186,7 +171,7 @@ def init_infant_anat_wf(
186
171
187
172
desc += (
188
173
"A previously computed mask was used to skull-strip the anatomical image."
189
- if precomp_mask
174
+ if deriv_mask
190
175
else """\
191
176
The T1w-reference was then skull-stripped with a modified implementation of
192
177
the `antsBrainExtraction.sh` workflow (from ANTs), using {skullstrip_tpl}
@@ -215,12 +200,19 @@ def init_infant_anat_wf(
215
200
)
216
201
217
202
# Multiple anatomical files -> generate average reference
203
+ t1w_mask = bool (derivatives .t1w_mask )
204
+ t1w_aseg = bool (derivatives .t1w_aseg )
205
+ t2w_mask = bool (derivatives .t2w_mask )
206
+ t2w_aseg = bool (derivatives .t2w_aseg )
207
+
218
208
t1w_template_wf = init_anat_template_wf (
219
209
contrast = "T1w" ,
220
210
num_files = num_t1w ,
221
211
longitudinal = longitudinal ,
222
212
omp_nthreads = omp_nthreads ,
223
213
sloppy = sloppy ,
214
+ has_mask = t1w_mask ,
215
+ has_aseg = t1w_aseg ,
224
216
name = "t1w_template_wf" ,
225
217
)
226
218
@@ -230,16 +222,14 @@ def init_infant_anat_wf(
230
222
longitudinal = longitudinal ,
231
223
omp_nthreads = omp_nthreads ,
232
224
sloppy = sloppy ,
225
+ has_mask = t2w_mask ,
226
+ has_aseg = t2w_aseg ,
233
227
name = "t2w_template_wf" ,
234
228
)
235
229
236
230
# Clean up each anatomical template
237
231
# Denoise, INU, + Clipping
238
- t1w_preproc_wf = init_anat_preproc_wf (
239
- precomputed_mask = bool (precomp_mask ),
240
- precomputed_aseg = bool (precomp_aseg ),
241
- name = "t1w_preproc_wf" ,
242
- )
232
+ t1w_preproc_wf = init_anat_preproc_wf (name = "t1w_preproc_wf" )
243
233
t2w_preproc_wf = init_anat_preproc_wf (name = "t2w_preproc_wf" )
244
234
245
235
if skull_strip_mode != "force" :
@@ -249,7 +239,8 @@ def init_infant_anat_wf(
249
239
omp_nthreads = omp_nthreads ,
250
240
sloppy = sloppy ,
251
241
debug = "registration" in config .execution .debug ,
252
- precomputed_mask = bool (precomp_mask ),
242
+ t1w_mask = t1w_mask ,
243
+ probmap = not t2w_mask ,
253
244
)
254
245
coreg_report_wf = init_coreg_report_wf (
255
246
output_dir = output_dir ,
@@ -261,7 +252,7 @@ def init_infant_anat_wf(
261
252
template_dir = segmentation_atlases ,
262
253
sloppy = sloppy ,
263
254
omp_nthreads = omp_nthreads ,
264
- precomp_aseg = precomp_aseg ,
255
+ precomp_aseg = bool ( derivatives . aseg ) ,
265
256
)
266
257
267
258
# Spatial normalization (requires segmentation)
@@ -347,15 +338,41 @@ def init_infant_anat_wf(
347
338
]),
348
339
])
349
340
350
- if precomp_mask :
351
- # Ensure the mask is conformed along with the T1w
352
- t1w_preproc_wf .inputs .inputnode .in_mask = precomp_mask
353
- # fmt:off
354
- wf .connect ([
355
- (t1w_preproc_wf , coregistration_wf , [("outputnode.anat_mask" , "inputnode.in_mask" )]),
356
- (t2w_preproc_wf , coregistration_wf , [("outputnode.anat_preproc" , "inputnode.in_t2w" )])
357
- ])
358
- # fmt:on
341
+ # Workflow to move derivatives between T1w/T2w spaces
342
+ # May not be used, but define in case necessary.
343
+ coreg_deriv_wf = init_coregister_derivatives_wf (
344
+ t1w_mask = t1w_mask , t1w_aseg = t1w_aseg , t2w_aseg = t2w_aseg
345
+ )
346
+ deriv_buffer = pe .Node (
347
+ niu .IdentityInterface (fields = ['t2w_mask' , 't1w_aseg' , 't2w_aseg' ]),
348
+ name = 'deriv_buffer' ,
349
+ )
350
+ if derivatives :
351
+ wf .connect (
352
+ coregistration_wf , 'outputnode.t1w2t2w_xfm' , coreg_deriv_wf , 'inputnode.t1w2t2w_xfm'
353
+ )
354
+
355
+ # Derivative mask is present
356
+ if derivatives .mask :
357
+ if t1w_mask :
358
+ t1w_template_wf .inputs .inputnode .anat_mask = derivatives .t1w_mask
359
+ t1w_template_wf .inputs .inputnode .mask_reference = derivatives .references ['t1w_mask' ]
360
+ # fmt:off
361
+ wf .connect ([
362
+ (t1w_template_wf , coregistration_wf , [('outputnode.anat_mask' , 'inputnode.in_mask' )]),
363
+ (t2w_preproc_wf , coregistration_wf , [('outputnode.anat_preproc' , 'inputnode.in_t2w' )]),
364
+ (t1w_template_wf , coreg_deriv_wf , [('outputnode.anat_mask' , 'inputnode.t1w_mask' )]),
365
+ (coreg_deriv_wf , deriv_buffer , [('outputnode.t2w_mask' , 't2w_mask' )])
366
+ ])
367
+ # fmt:on
368
+ elif t2w_mask :
369
+ t2w_template_wf .inputs .inputnode .anat_mask = derivatives .t2w_mask
370
+ t2w_template_wf .inputs .inputnode .mask_reference = derivatives .references ['t2w_mask' ]
371
+
372
+ wf .connect ([
373
+ (t2w_template_wf , coregistration_wf , [('outputnode.anat_mask' , 'inputnode.in_mask' )]),
374
+ (t2w_template_wf , deriv_buffer , [('outputnode.anat_mask' , 't2w_mask' )]),
375
+ ])
359
376
else :
360
377
# Run brain extraction on the T2w
361
378
brain_extraction_wf = init_infant_brain_extraction_wf (
@@ -378,10 +395,30 @@ def init_infant_anat_wf(
378
395
])
379
396
# fmt:on
380
397
381
- if precomp_aseg :
382
- # Ensure the segmentation is conformed along with the T1w
383
- t1w_preproc_wf .inputs .inputnode .in_aseg = precomp_aseg
384
- wf .connect (t1w_preproc_wf , "outputnode.anat_aseg" , anat_seg_wf , "inputnode.anat_aseg" )
398
+ # Derivative segmentation is present
399
+ if derivatives .aseg :
400
+ wf .connect (deriv_buffer , 't1w_aseg' , anat_seg_wf , 'inputnode.anat_aseg' )
401
+
402
+ if t1w_aseg :
403
+ t1w_template_wf .inputs .inputnode .anat_aseg = derivatives .t1w_aseg
404
+ t1w_template_wf .inputs .inputnode .aseg_reference = derivatives .references ['t1w_aseg' ]
405
+ # fmt:off
406
+ wf .connect ([
407
+ (t1w_template_wf , deriv_buffer , [('outputnode.anat_aseg' , 't1w_aseg' )]),
408
+ (t1w_template_wf , coreg_deriv_wf , [('outputnode.anat_aseg' , 'inputnode.t1w_aseg' )]),
409
+ (coreg_deriv_wf , deriv_buffer , [('outputnode.t2w_aseg' , 't2w_aseg' )]),
410
+ ])
411
+ # fmt:on
412
+ elif t2w_aseg :
413
+ t2w_template_wf .inputs .inputnode .anat_aseg = derivatives .t2w_aseg
414
+ t2w_template_wf .inputs .inputnode .aseg_reference = derivatives .references ['t2w_aseg' ]
415
+ # fmt:off
416
+ wf .connect ([
417
+ (t2w_template_wf , deriv_buffer , [('outputnode.anat_aseg' , 't2w_aseg' )]),
418
+ (t2w_template_wf , coreg_deriv_wf , [('outputnode.anat_aseg' , 'inputnode.t2w_aseg' )]),
419
+ (coreg_deriv_wf , deriv_buffer , [('outputnode.t1w_aseg' , 't1w_aseg' )]),
420
+ ])
421
+ # fmt:on
385
422
386
423
if not freesurfer :
387
424
return wf
@@ -394,7 +431,7 @@ def init_infant_anat_wf(
394
431
from .surfaces import init_infantfs_surface_recon_wf
395
432
396
433
# if running with precomputed aseg, or JLF, pass the aseg along to FreeSurfer
397
- use_aseg = bool (precomp_aseg or segmentation_atlases )
434
+ use_aseg = bool (derivatives . aseg or segmentation_atlases )
398
435
surface_recon_wf = init_infantfs_surface_recon_wf (
399
436
age_months = age_months ,
400
437
use_aseg = use_aseg ,
@@ -405,32 +442,30 @@ def init_infant_anat_wf(
405
442
406
443
from .surfaces import init_mcribs_sphere_reg_wf , init_mcribs_surface_recon_wf
407
444
408
- # Denoise raw T2w, since using the template / preproc resulted in intersection errors
409
- denoise_raw_t2w = pe .Node (
410
- DenoiseImage (dimension = 3 , noise_model = "Rician" ), name = 'denoise_raw_t2w '
445
+ # Denoise template T2w, since using the template / preproc resulted in intersection errors
446
+ denoise_t2w = pe .Node (
447
+ DenoiseImage (dimension = 3 , noise_model = "Rician" ), name = 'denoise_t2w '
411
448
)
412
-
449
+ # t2w mask, t2w aseg
413
450
surface_recon_wf = init_mcribs_surface_recon_wf (
414
451
omp_nthreads = omp_nthreads ,
415
- use_aseg = bool (precomp_aseg ),
416
- use_mask = bool (precomp_mask ),
452
+ use_aseg = bool (derivatives . aseg ), # TODO: Incorporate mcribs segmentation
453
+ use_mask = bool (derivatives . mask ), # TODO: Pass in mask regardless of derivatives
417
454
mcribs_dir = str (config .execution .mcribs_dir ), # Needed to preserve runs
418
455
)
419
-
420
456
# M-CRIB-S to dHCP42week (32k)
421
457
sphere_reg_wf = init_mcribs_sphere_reg_wf ()
422
458
423
- # Transformed gives
424
- if precomp_aseg :
425
- surface_recon_wf .inputs .inputnode .ants_segs = precomp_aseg
426
- if precomp_mask :
427
- surface_recon_wf .inputs .inputnode .anat_mask = precomp_mask
428
459
# fmt:off
429
460
wf .connect ([
430
- (inputnode , denoise_raw_t2w , [('t2w ' , 'input_image' )]),
431
- (denoise_raw_t2w , surface_recon_wf , [('output_image' , 'inputnode.t2w' )]),
461
+ (t2w_template_wf , denoise_t2w , [('outputnode.anat_ref ' , 'input_image' )]),
462
+ (denoise_t2w , surface_recon_wf , [('output_image' , 'inputnode.t2w' )]),
432
463
])
433
464
# fmt:on
465
+ if derivatives .aseg :
466
+ wf .connect (deriv_buffer , 't2w_aseg' , surface_recon_wf , 'inputnode.ants_segs' )
467
+ if derivatives .mask :
468
+ wf .connect (deriv_buffer , 't2w_mask' , surface_recon_wf , 'inputnode.anat_mask' )
434
469
else :
435
470
raise NotImplementedError
436
471
0 commit comments