1515from niworkflows .interfaces .fixes import FixHeaderApplyTransforms as ApplyTransforms
1616from smriprep .workflows .fit .registration import (
1717 TemplateDesc ,
18- TemplateFlowSelect ,
1918 _fmt_cohort ,
2019 get_metadata ,
2120 tf_ver ,
2221)
2322
24- from nibabies .config import DEFAULT_MEMORY_MIN_GB
25- from nibabies .interfaces .patches import ConcatXFM
26-
2723
2824def init_coregistration_wf (
2925 * ,
@@ -312,7 +308,7 @@ def init_concat_registrations_wf(
312308 name = 'concat_registrations_wf' ,
313309):
314310 """
315- Concatenate two transforms to produce a single transform, from native to `` template`` .
311+ Concatenate two transforms to produce a single composite transform from native to template.
316312
317313 Parameters
318314 ----------
@@ -347,6 +343,8 @@ def init_concat_registrations_wf(
347343 further use in downstream nodes.
348344
349345 """
346+ from nibabies .interfaces .patches import CompositeTransformUtil
347+
350348 ntpls = len (templates )
351349 workflow = Workflow (name = name )
352350
@@ -384,9 +382,7 @@ def init_concat_registrations_wf(
384382 workflow .__desc__ += '.\n ' if template == templates [- 1 ] else ', '
385383
386384 inputnode = pe .Node (
387- niu .IdentityInterface (
388- fields = ['template' , 'anat_preproc' , 'anat2std_xfm' , 'intermediate' , 'std2anat_xfm' ]
389- ),
385+ niu .IdentityInterface (fields = ['template' , 'intermediate' , 'anat2std_xfm' , 'std2anat_xfm' ]),
390386 name = 'inputnode' ,
391387 )
392388 inputnode .inputs .template = templates
@@ -413,29 +409,37 @@ def init_concat_registrations_wf(
413409 TemplateDesc (), run_without_submitting = True , iterfield = 'template' , name = 'split_desc'
414410 )
415411
416- tf_select = pe .MapNode (
417- TemplateFlowSelect (resolution = 1 ),
418- name = 'tf_select' ,
419- run_without_submitting = True ,
420- iterfield = ['template' , 'template_spec' ],
412+ merge_anat2std = pe .Node (niu .Merge (2 ), name = 'merge_anat2std' , run_without_submitting = True )
413+ merge_std2anat = merge_anat2std .clone ('merge_std2anat' )
414+
415+ disassemble_anat2std = pe .MapNode (
416+ CompositeTransformUtil (process = 'disassemble' , output_prefix = 'anat2std' ),
417+ iterfield = ['in_file' ],
418+ name = 'disassemble_anat2std' ,
421419 )
422420
423- merge_anat2std = pe .MapNode (
424- niu .Merge (2 ), name = 'merge_anat2std' , iterfield = ['in1' , 'in2' ], run_without_submitting = True
421+ disassemble_std2anat = pe .MapNode (
422+ CompositeTransformUtil (process = 'disassemble' , output_prefix = 'std2anat' ),
423+ iterfield = ['in_file' ],
424+ name = 'disassemble_std2anat' ,
425425 )
426- merge_std2anat = merge_anat2std .clone ('merge_std2anat' )
427426
428- concat_anat2std = pe .MapNode (
429- ConcatXFM (),
430- name = 'concat_anat2std' ,
431- mem_gb = DEFAULT_MEMORY_MIN_GB ,
432- iterfield = ['transforms' , 'reference_image' ],
427+ merge_anat2std_composites = pe .Node (
428+ niu .Merge (1 , ravel_inputs = True ),
429+ name = 'merge_anat2std_composites' ,
433430 )
434- concat_std2anat = pe .MapNode (
435- ConcatXFM (),
436- name = 'concat_std2anat' ,
437- mem_gb = DEFAULT_MEMORY_MIN_GB ,
438- iterfield = ['transforms' , 'reference_image' ],
431+ merge_std2anat_composites = pe .Node (
432+ niu .Merge (1 , ravel_inputs = True ),
433+ name = 'merge_std2anat_composites' ,
434+ )
435+
436+ assemble_anat2std = pe .Node (
437+ CompositeTransformUtil (process = 'assemble' , out_file = 'anat2std.h5' ),
438+ name = 'assemble_anat2std' ,
439+ )
440+ assemble_std2anat = pe .Node (
441+ CompositeTransformUtil (process = 'assemble' , out_file = 'std2anat.h5' ),
442+ name = 'assemble_std2anat' ,
439443 )
440444
441445 fmt_cohort = pe .MapNode (
@@ -446,24 +450,24 @@ def init_concat_registrations_wf(
446450 )
447451
448452 workflow .connect ([
453+ # Template concatenation
449454 (inputnode , merge_anat2std , [('anat2std_xfm' , 'in2' )]),
450455 (inputnode , merge_std2anat , [('std2anat_xfm' , 'in2' )]),
451- (inputnode , concat_std2anat , [('anat_preproc' , 'reference_image' )]),
452456 (inputnode , intermed_xfms , [('intermediate' , 'intermediate' )]),
453457 (inputnode , intermed_xfms , [('template' , 'std' )]),
454-
455458 (intermed_xfms , merge_anat2std , [('int2std_xfm' , 'in1' )]),
456459 (intermed_xfms , merge_std2anat , [('std2int_xfm' , 'in1' )]),
457-
458- (merge_anat2std , concat_anat2std , [('out' , 'transforms' )]),
459- (merge_std2anat , concat_std2anat , [('out' , 'transforms' )]),
460-
460+ (merge_anat2std , disassemble_anat2std , [('out' , 'in_file' )]),
461+ (merge_std2anat , disassemble_std2anat , [('out' , 'in_file' )]),
462+ (disassemble_anat2std , merge_anat2std_composites , [('out_transforms' , 'in1' )]),
463+ (disassemble_std2anat , merge_std2anat_composites , [('out_transforms' , 'in1' )]),
464+ (merge_anat2std_composites , assemble_anat2std , [('out' , 'in_file' )]),
465+ (merge_std2anat_composites , assemble_std2anat , [('out' , 'in_file' )]),
466+ (assemble_anat2std , outputnode , [('out_file' , 'anat2std_xfm' )]),
467+ (assemble_std2anat , outputnode , [('out_file' , 'std2anat_xfm' )]),
468+
469+ # Template name wrangling
461470 (inputnode , split_desc , [('template' , 'template' )]),
462- (split_desc , tf_select , [
463- ('name' , 'template' ),
464- ('spec' , 'template_spec' ),
465- ]),
466- (tf_select , concat_anat2std , [('t1w_file' , 'reference_image' )]),
467471 (split_desc , fmt_cohort , [
468472 ('name' , 'template' ),
469473 ('spec' , 'spec' ),
@@ -472,8 +476,6 @@ def init_concat_registrations_wf(
472476 ('template' , 'template' ),
473477 ('spec' , 'template_spec' ),
474478 ]),
475- (concat_anat2std , outputnode , [('out_xfm' , 'anat2std_xfm' )]),
476- (concat_std2anat , outputnode , [('out_xfm' , 'std2anat_xfm' )]),
477479 ]) # fmt:skip
478480
479481 return workflow
0 commit comments