Skip to content

Commit ce1ab4a

Browse files
committed
ENH: Reflect downstream changes in anatomical preprocessing workflow
1 parent adb5c57 commit ce1ab4a

File tree

1 file changed

+120
-136
lines changed
  • nibabies/workflows/anatomical

1 file changed

+120
-136
lines changed

nibabies/workflows/anatomical/base.py

Lines changed: 120 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""Base anatomical preprocessing."""
2+
import warnings
3+
24
from nipype.interfaces import utility as niu
35
from nipype.pipeline import engine as pe
46

@@ -93,6 +95,7 @@ def init_infant_anat_wf(
9395
init_anat_reports_wf,
9496
init_coreg_report_wf,
9597
)
98+
from .preproc import init_anat_preproc_wf
9699
from .registration import init_coregistration_wf
97100
from .segmentation import init_anat_segmentations_wf
98101
from .surfaces import init_anat_ribbon_wf
@@ -107,13 +110,23 @@ def init_infant_anat_wf(
107110

108111
# verify derivatives are relatively similar to T1w
109112
if precomp_mask or precomp_aseg:
110-
from ...utils.validation import validate_t1w_derivatives
111-
112-
validated_derivatives = validate_t1w_derivatives( # compare derivatives to the first T1w
113-
t1w[0], anat_mask=precomp_mask, anat_aseg=precomp_aseg
114-
)
115-
precomp_mask = validated_derivatives.get("anat_mask")
116-
precomp_aseg = validated_derivatives.get("anat_aseg")
113+
if num_t1w > 1:
114+
precomp_mask = None
115+
precomp_aseg = None
116+
warnings.warn(
117+
"Multiple T1w files were found; precomputed derivatives will not be used."
118+
)
119+
120+
else:
121+
from ...utils.validation import validate_t1w_derivatives
122+
123+
validated_derivatives = (
124+
validate_t1w_derivatives( # compare derivatives to the first T1w
125+
t1w[0], anat_mask=precomp_mask, anat_aseg=precomp_aseg
126+
)
127+
)
128+
precomp_mask = validated_derivatives.get("anat_mask")
129+
precomp_aseg = validated_derivatives.get("anat_aseg")
117130

118131
wf = Workflow(name=name)
119132
desc = f"""\n
@@ -159,10 +172,10 @@ def init_infant_anat_wf(
159172

160173
desc += (
161174
"""\
162-
All of the T1-weighted images were corrected for intensity non-uniformity (INU)"""
175+
All of the T1-weighted images were denoised and corrected for intensity non-uniformity (INU)"""
163176
if num_t1w > 1
164177
else """\
165-
The T1-weighted (T1w) image was corrected for intensity non-uniformity (INU)"""
178+
The T1-weighted (T1w) image was denoised and corrected for intensity non-uniformity (INU)"""
166179
)
167180

168181
desc += """\
@@ -200,13 +213,15 @@ def init_infant_anat_wf(
200213
cifti_output=cifti_output,
201214
)
202215

203-
# Multiple T1w files -> generate average reference
216+
# Multiple anatomical files -> generate average reference
204217
t1w_template_wf = init_anat_template_wf(
205218
contrast="T1w",
206219
num_files=num_t1w,
207220
longitudinal=longitudinal,
208221
omp_nthreads=omp_nthreads,
209222
sloppy=sloppy,
223+
precomputed_mask=bool(precomp_mask),
224+
precomputed_aseg=bool(precomp_aseg),
210225
name="t1w_template_wf",
211226
)
212227

@@ -219,38 +234,27 @@ def init_infant_anat_wf(
219234
name="t2w_template_wf",
220235
)
221236

222-
# INU + Brain Extraction
237+
# Clean up each anatomical template
238+
# Denoise, INU, + Clipping
239+
t1w_preproc_wf = init_anat_preproc_wf(name="t1w_preproc_wf")
240+
t2w_preproc_wf = init_anat_preproc_wf(name="t2w_preproc_wf")
241+
223242
if skull_strip_mode != "force":
224243
raise NotImplementedError("Skull stripping is currently required.")
225244

226-
if precomp_mask:
227-
precomp_mask_wf = init_precomputed_mask_wf(omp_nthreads=omp_nthreads)
228-
precomp_mask_wf.inputs.inputnode.t1w_mask = precomp_mask
229-
sdc_brain_extraction_wf = init_sdc_brain_extraction_wf(
230-
name="sdc_brain_extraction_wf",
231-
)
232-
brain_extraction_wf = init_infant_brain_extraction_wf(
233-
age_months=age_months,
234-
ants_affine_init=ants_affine_init,
235-
skull_strip_template=skull_strip_template.space,
236-
template_specs=skull_strip_template.spec,
237-
omp_nthreads=omp_nthreads,
238-
sloppy=sloppy,
239-
debug="registration" in config.execution.debug,
240-
)
241245
coregistration_wf = init_coregistration_wf(
242246
omp_nthreads=omp_nthreads,
243247
sloppy=sloppy,
244248
debug="registration" in config.execution.debug,
249+
precomputed_mask=bool(precomp_mask),
245250
)
246251
coreg_report_wf = init_coreg_report_wf(
247252
output_dir=output_dir,
248253
)
249-
t1w_preproc_wf = precomp_mask_wf if precomp_mask else coregistration_wf
250254

251255
# Segmentation - initial implementation should be simple: JLF
252256
anat_seg_wf = init_anat_segmentations_wf(
253-
anat_modality=anat_modality.capitalize(),
257+
anat_modality=anat_modality.capitalize(), # TODO: Revisit this option
254258
template_dir=segmentation_atlases,
255259
sloppy=sloppy,
256260
omp_nthreads=omp_nthreads,
@@ -264,143 +268,117 @@ def init_infant_anat_wf(
264268
templates=spaces.get_spaces(nonstandard=False, dim=(3,)),
265269
)
266270

267-
# Anatomical ribbon file using HCP signed-distance volume method
268-
# if config.workflow.project_goodvoxels:
269-
anat_ribbon_wf = init_anat_ribbon_wf()
270-
271271
# fmt:off
272272
wf.connect([
273273
(inputnode, t1w_template_wf, [("t1w", "inputnode.in_files")]),
274+
(inputnode, t2w_template_wf, [("t2w", "inputnode.in_files")]),
275+
(inputnode, anat_reports_wf, [("t1w", "inputnode.source_file")]),
276+
(inputnode, coreg_report_wf, [("t1w", "inputnode.source_file")]),
277+
(inputnode, anat_norm_wf, [(("t1w", fix_multi_source_name), "inputnode.orig_t1w")]),
278+
274279
(t1w_template_wf, outputnode, [
275-
("outputnode.realign_xfms", "anat_ref_xfms"),
276-
]),
280+
("outputnode.realign_xfms", "anat_ref_xfms")]),
281+
(t1w_template_wf, t1w_preproc_wf, [("outputnode.out_file", "inputnode.in_anat")]),
282+
(t2w_template_wf, t2w_preproc_wf, [("outputnode.out_file", "inputnode.in_anat")]),
283+
(t1w_template_wf, anat_derivatives_wf, [
284+
("outputnode.valid_list", "inputnode.source_files"),
285+
("outputnode.realign_xfms", "inputnode.t1w_ref_xfms")]),
286+
(t2w_template_wf, anat_derivatives_wf, [
287+
("outputnode.valid_list", "inputnode.t2w_source_files")]),
288+
289+
(t1w_preproc_wf, coregistration_wf, [("outputnode.anat_preproc", "inputnode.in_t1w")]),
290+
(t1w_preproc_wf, coreg_report_wf, [("outputnode.anat_preproc", "inputnode.t1w_preproc")]),
291+
(t1w_preproc_wf, anat_norm_wf, [
292+
("outputnode.t1w_preproc", "inputnode.moving_image"),
293+
("outputnode.t1w_mask", "inputnode.moving_mask")]),
294+
295+
(coregistration_wf, coreg_report_wf, [
296+
("outputnode.t1w_mask", "inputnode.in_mask"),
297+
("outputnode.t2w_preproc", "inputnode.t2w_preproc")]),
298+
277299
(anat_seg_wf, outputnode, [
278300
("outputnode.anat_dseg", "anat_dseg"),
279-
("outputnode.anat_tpms", "anat_tpms"),
301+
("outputnode.anat_tpms", "anat_tpms")]),
302+
(anat_seg_wf, anat_derivatives_wf, [
303+
("outputnode.anat_dseg", "inputnode.t1w_dseg"),
304+
("outputnode.anat_tpms", "inputnode.t1w_tpms"),
280305
]),
281306
(anat_seg_wf, anat_norm_wf, [
282307
("outputnode.anat_dseg", "inputnode.moving_segmentation"),
283-
("outputnode.anat_tpms", "inputnode.moving_tpms"),
284-
]),
308+
("outputnode.anat_tpms", "inputnode.moving_tpms")]),
309+
310+
(anat_norm_wf, anat_reports_wf, [("poutputnode.template", "inputnode.template")]),
285311
(anat_norm_wf, outputnode, [
286312
("poutputnode.standardized", "std_preproc"),
287313
("poutputnode.std_mask", "std_mask"),
288314
("poutputnode.std_dseg", "std_dseg"),
289315
("poutputnode.std_tpms", "std_tpms"),
290316
("outputnode.template", "template"),
291317
("outputnode.anat2std_xfm", "anat2std_xfm"),
292-
("outputnode.std2anat_xfm", "std2anat_xfm"),
293-
]),
294-
(inputnode, anat_norm_wf, [
295-
(("t1w", fix_multi_source_name), "inputnode.orig_t1w"), # anat_validate? not used...
296-
]),
297-
(t1w_preproc_wf, anat_norm_wf, [
298-
("outputnode.t1w_preproc", "inputnode.moving_image"),
299-
("outputnode.t1w_mask", "inputnode.moving_mask"),
300-
]),
301-
(t1w_preproc_wf, anat_derivatives_wf, [
318+
("outputnode.std2anat_xfm", "std2anat_xfm")]),
319+
(anat_norm_wf, anat_derivatives_wf, [
320+
("outputnode.template", "inputnode.template"),
321+
("outputnode.anat2std_xfm", "inputnode.anat2std_xfm"),
322+
("outputnode.std2anat_xfm", "inputnode.std2anat_xfm")]),
323+
324+
(coregistration_wf, anat_seg_wf, [("outputnode.t1w_brain", "inputnode.anat_brain")]),
325+
(coregistration_wf, anat_derivatives_wf, [
302326
("outputnode.t1w_mask", "inputnode.t1w_mask"),
303327
("outputnode.t1w_preproc", "inputnode.t1w_preproc"),
304-
]),
305-
(coregistration_wf, anat_derivatives_wf, [
306-
("outputnode.t2w_preproc", "inputnode.t2w_preproc")
328+
("outputnode.t2w_preproc", "inputnode.t2w_preproc"),
307329
]),
308-
(t1w_preproc_wf, outputnode, [
330+
(coregistration_wf, outputnode, [
309331
("outputnode.t1w_preproc", "anat_preproc"),
310332
("outputnode.t1w_brain", "anat_brain"),
311333
("outputnode.t1w_mask", "anat_mask"),
312334
]),
313-
])
314335

315-
if not precomp_aseg:
316-
wf.connect([
317-
(t1w_preproc_wf, anat_seg_wf, [("outputnode.t1w_brain", "inputnode.anat_brain")]),
318-
])
319-
wf.connect([
320-
(inputnode, t2w_template_wf, [("t2w", "inputnode.in_files")]),
336+
(t1w_template_wf, anat_reports_wf, [
337+
("outputnode.out_report", "inputnode.t1w_conform_report")]),
338+
(outputnode, anat_reports_wf, [
339+
("anat_preproc", "inputnode.t1w_preproc"),
340+
("anat_mask", "inputnode.t1w_mask"),
341+
("anat_dseg", "inputnode.t1w_dseg"),
342+
("std_preproc", "inputnode.std_t1w"),
343+
("std_mask", "inputnode.std_mask"),
344+
]),
321345
])
346+
322347
if precomp_mask:
348+
# Ensure the mask is conformed along with the T1w
349+
t1w_template_wf.inputs.inputnode.anat_mask = precomp_mask
323350
wf.connect([
324-
(t1w_template_wf, precomp_mask_wf, [
325-
("outputnode.out_file", "inputnode.t1w"),
326-
]),
327-
(t2w_template_wf, sdc_brain_extraction_wf, [
328-
("outputnode.out_file", "inputnode.in_file"),
329-
]),
330-
(sdc_brain_extraction_wf, coregistration_wf, [
331-
("outputnode.out_file", "inputnode.in_t2w_preproc"),
332-
("outputnode.out_mask", "inputnode.in_mask"),
333-
("outputnode.out_probseg", "inputnode.in_probmap"),
334-
]),
351+
(t1w_template_wf, coregistration_wf, [("outputnode.anat_mask", "inputnode.in_mask")]),
335352
])
336353
else:
354+
# Run brain extraction on the T2w
355+
brain_extraction_wf = init_infant_brain_extraction_wf(
356+
age_months=age_months,
357+
ants_affine_init=ants_affine_init,
358+
skull_strip_template=skull_strip_template.space,
359+
template_specs=skull_strip_template.spec,
360+
omp_nthreads=omp_nthreads,
361+
sloppy=sloppy,
362+
debug="registration" in config.execution.debug,
363+
)
364+
337365
wf.connect([
338-
(t2w_template_wf, brain_extraction_wf, [
339-
("outputnode.out_file", "inputnode.in_t2w"),
340-
]),
366+
(t1w_preproc_wf, brain_extraction_wf, [
367+
("outputnode.anat_preproc", "inputnode.in_t1w")]),
368+
(t2w_preproc_wf, brain_extraction_wf, [
369+
("outputnode.anat_preproc", "inputnode.in_t2w")]),
341370
(brain_extraction_wf, coregistration_wf, [
342-
("outputnode.t2w_preproc", "inputnode.in_t2w_preproc"),
371+
("outputnode.t2w_preproc", "inputnode.in_t2w"),
343372
("outputnode.out_mask", "inputnode.in_mask"),
344-
("outputnode.out_probmap", "inputnode.in_probmap"),
345-
]),
373+
("outputnode.out_probmap", "inputnode.in_probmap")]),
346374
])
347-
wf.connect([
348-
(t1w_template_wf, coregistration_wf, [
349-
("outputnode.out_file", "inputnode.in_t1w"),
350-
]),
351375

352-
(inputnode, coreg_report_wf, [
353-
("t1w", "inputnode.source_file"),
354-
]),
355-
(t1w_preproc_wf, coreg_report_wf, [
356-
("outputnode.t1w_preproc", "inputnode.t1w_preproc"),
357-
("outputnode.t1w_mask", "inputnode.in_mask"),
358-
]),
359-
(coregistration_wf, coreg_report_wf, [
360-
("outputnode.t2w_preproc", "inputnode.t2w_preproc")
361-
]),
362-
])
363-
364-
wf.connect([
365-
# reports
366-
(inputnode, anat_reports_wf, [
367-
("t1w", "inputnode.source_file"),
368-
]),
369-
(outputnode, anat_reports_wf, [
370-
("anat_preproc", "inputnode.t1w_preproc"),
371-
("anat_mask", "inputnode.t1w_mask"),
372-
("anat_dseg", "inputnode.t1w_dseg"),
373-
("std_preproc", "inputnode.std_t1w"),
374-
("std_mask", "inputnode.std_mask"),
375-
]),
376-
(t1w_template_wf, anat_reports_wf, [
377-
("outputnode.out_report", "inputnode.t1w_conform_report"),
378-
]),
379-
(anat_norm_wf, anat_reports_wf, [
380-
("poutputnode.template", "inputnode.template"),
381-
]),
382-
# derivatives
383-
(t1w_template_wf, anat_derivatives_wf, [
384-
("outputnode.valid_list", "inputnode.source_files"),
385-
("outputnode.realign_xfms", "inputnode.t1w_ref_xfms"),
386-
]),
387-
(t2w_template_wf, anat_derivatives_wf, [
388-
("outputnode.valid_list", "inputnode.t2w_source_files"),
389-
]),
390-
(anat_norm_wf, anat_derivatives_wf, [
391-
("outputnode.template", "inputnode.template"),
392-
("outputnode.anat2std_xfm", "inputnode.anat2std_xfm"),
393-
("outputnode.std2anat_xfm", "inputnode.std2anat_xfm"),
394-
]),
395-
(anat_ribbon_wf, anat_derivatives_wf, [
396-
("outputnode.anat_ribbon", "inputnode.anat_ribbon"),
397-
]),
398-
(anat_seg_wf, anat_derivatives_wf, [
399-
("outputnode.anat_dseg", "inputnode.t1w_dseg"),
400-
("outputnode.anat_tpms", "inputnode.t1w_tpms"),
401-
]),
402-
])
403-
# fmt:on
376+
if precomp_aseg:
377+
# Ensure the segmentation is conformed along with the T1w
378+
t1w_template_wf.inputs.inputnode.anat_aseg = precomp_aseg
379+
wf.connect([
380+
(t1w_template_wf, anat_seg_wf, [("outputnode.anat_aseg", "inputnode.anat_aseg")]),
381+
])
404382

405383
if not freesurfer:
406384
return wf
@@ -419,20 +397,23 @@ def init_infant_anat_wf(
419397
use_aseg=use_aseg,
420398
)
421399

400+
# Anatomical ribbon file using HCP signed-distance volume method
401+
anat_ribbon_wf = init_anat_ribbon_wf()
402+
422403
# fmt:off
423404
wf.connect([
424405
(inputnode, surface_recon_wf, [
425406
("subject_id", "inputnode.subject_id"),
426-
("subjects_dir", "inputnode.subjects_dir"),
427-
("t2w", "inputnode.t2w"),
428-
]),
407+
("subjects_dir", "inputnode.subjects_dir")]),
408+
(t2w_preproc_wf, surface_recon_wf, [
409+
("outputnode.anat_preproc", "inputnode.t2w")]),
429410
(anat_seg_wf, surface_recon_wf, [
430411
("outputnode.anat_aseg", "inputnode.ants_segs"),
431412
]),
432413
(t1w_template_wf, surface_recon_wf, [
433414
("outputnode.out_file", "inputnode.t1w"),
434415
]),
435-
(t1w_preproc_wf, surface_recon_wf, [
416+
(coregistration_wf, surface_recon_wf, [
436417
("outputnode.t1w_brain", "inputnode.skullstripped_t1"),
437418
("outputnode.t1w_preproc", "inputnode.corrected_t1"),
438419
]),
@@ -446,7 +427,7 @@ def init_infant_anat_wf(
446427
("outputnode.out_aparc", "anat_aparc"),
447428
("outputnode.out_aseg", "anat_aseg"),
448429
]),
449-
(t1w_preproc_wf, anat_ribbon_wf, [
430+
(coregistration_wf, anat_ribbon_wf, [
450431
("outputnode.t1w_mask", "inputnode.t1w_mask"),
451432
]),
452433
(surface_recon_wf, anat_ribbon_wf, [
@@ -455,6 +436,9 @@ def init_infant_anat_wf(
455436
(anat_ribbon_wf, outputnode, [
456437
("outputnode.anat_ribbon", "anat_ribbon")
457438
]),
439+
(anat_ribbon_wf, anat_derivatives_wf, [
440+
("outputnode.anat_ribbon", "inputnode.anat_ribbon"),
441+
]),
458442
(surface_recon_wf, anat_reports_wf, [
459443
("outputnode.subject_id", "inputnode.subject_id"),
460444
("outputnode.subjects_dir", "inputnode.subjects_dir"),

0 commit comments

Comments
 (0)