1515from niworkflows .interfaces .fixes import FixHeaderApplyTransforms as ApplyTransforms
1616from smriprep .workflows .fit .registration import (
1717 TemplateDesc ,
18- TemplateFlowSelect ,
1918 _fmt_cohort ,
2019 get_metadata ,
2120 tf_ver ,
2221)
2322
24- from nibabies .config import DEFAULT_MEMORY_MIN_GB
25- from nibabies .interfaces .patches import ConcatXFM
26-
2723
2824def init_coregistration_wf (
2925 * ,
@@ -312,7 +308,7 @@ def init_concat_registrations_wf(
312308 name = 'concat_registrations_wf' ,
313309):
314310 """
315- Concatenate two transforms to produce a single transform, from native to `` template`` .
311+ Concatenate two transforms to produce a single composite transform from native to template.
316312
317313 Parameters
318314 ----------
@@ -347,6 +343,8 @@ def init_concat_registrations_wf(
347343 further use in downstream nodes.
348344
349345 """
346+ from nibabies .interfaces .patches import CompositeTransformUtil
347+
350348 ntpls = len (templates )
351349 workflow = Workflow (name = name )
352350
@@ -384,9 +382,7 @@ def init_concat_registrations_wf(
384382 workflow .__desc__ += '.\n ' if template == templates [- 1 ] else ', '
385383
386384 inputnode = pe .Node (
387- niu .IdentityInterface (
388- fields = ['template' , 'anat_preproc' , 'anat2std_xfm' , 'intermediate' , 'std2anat_xfm' ]
389- ),
385+ niu .IdentityInterface (fields = ['template' , 'intermediate' , 'anat2std_xfm' , 'std2anat_xfm' ]),
390386 name = 'inputnode' ,
391387 )
392388 inputnode .inputs .template = templates
@@ -413,29 +409,39 @@ def init_concat_registrations_wf(
413409 TemplateDesc (), run_without_submitting = True , iterfield = 'template' , name = 'split_desc'
414410 )
415411
416- tf_select = pe .MapNode (
417- TemplateFlowSelect (resolution = 1 ),
418- name = 'tf_select' ,
419- run_without_submitting = True ,
420- iterfield = ['template' , 'template_spec' ],
412+ merge_anat2std = pe .Node (niu .Merge (2 ), name = 'merge_anat2std' , run_without_submitting = True )
413+ merge_std2anat = merge_anat2std .clone ('merge_std2anat' )
414+
415+ disassemble_anat2std = pe .MapNode (
416+ CompositeTransformUtil (process = 'disassemble' , output_prefix = 'anat2std' ),
417+ iterfield = ['in_file' ],
418+ name = 'disassemble_anat2std' ,
421419 )
422420
423- merge_anat2std = pe .MapNode (
424- niu .Merge (2 ), name = 'merge_anat2std' , iterfield = ['in1' , 'in2' ], run_without_submitting = True
421+ disassemble_std2anat = pe .MapNode (
422+ CompositeTransformUtil (process = 'disassemble' , output_prefix = 'std2anat' , inverse = True ),
423+ iterfield = ['in_file' ],
424+ name = 'disassemble_std2anat' ,
425+ )
426+
427+ order_anat2std_composites = pe .Node (
428+ niu .Function (function = _order_composites ),
429+ name = 'order_anat2std_composites' ,
430+ )
431+
432+ order_std2anat_composites = pe .Node (
433+ niu .Function (function = _order_composites ),
434+ name = 'order_std2anat_composites' ,
425435 )
426- merge_std2anat = merge_anat2std .clone ('merge_std2anat' )
427436
428- concat_anat2std = pe .MapNode (
429- ConcatXFM (),
430- name = 'concat_anat2std' ,
431- mem_gb = DEFAULT_MEMORY_MIN_GB ,
432- iterfield = ['transforms' , 'reference_image' ],
437+ assemble_anat2std = pe .Node (
438+ CompositeTransformUtil (process = 'assemble' , out_file = 'anat2std.h5' ),
439+ name = 'assemble_anat2std' ,
433440 )
434- concat_std2anat = pe .MapNode (
435- ConcatXFM (),
436- name = 'concat_std2anat' ,
437- mem_gb = DEFAULT_MEMORY_MIN_GB ,
438- iterfield = ['transforms' , 'reference_image' ],
441+
442+ assemble_std2anat = pe .Node (
443+ CompositeTransformUtil (process = 'assemble' , out_file = 'std2anat.h5' ),
444+ name = 'assemble_std2anat' ,
439445 )
440446
441447 fmt_cohort = pe .MapNode (
@@ -446,24 +452,30 @@ def init_concat_registrations_wf(
446452 )
447453
448454 workflow .connect ([
455+ # Template concatenation
449456 (inputnode , merge_anat2std , [('anat2std_xfm' , 'in2' )]),
450457 (inputnode , merge_std2anat , [('std2anat_xfm' , 'in2' )]),
451- (inputnode , concat_std2anat , [('anat_preproc' , 'reference_image' )]),
452458 (inputnode , intermed_xfms , [('intermediate' , 'intermediate' )]),
453459 (inputnode , intermed_xfms , [('template' , 'std' )]),
454-
455460 (intermed_xfms , merge_anat2std , [('int2std_xfm' , 'in1' )]),
456461 (intermed_xfms , merge_std2anat , [('std2int_xfm' , 'in1' )]),
462+ (merge_anat2std , disassemble_anat2std , [('out' , 'in_file' )]),
463+ (merge_std2anat , disassemble_std2anat , [('out' , 'in_file' )]),
464+ (disassemble_anat2std , order_anat2std_composites , [
465+ ('affine_transform' , 'affines' ),
466+ ('displacement_field' , 'displacements' ),
467+ ]),
468+ (disassemble_std2anat , order_std2anat_composites , [
469+ ('affine_transform' , 'affines' ),
470+ ('displacement_field' , 'displacements' ),
471+ ]),
472+ (order_anat2std_composites , assemble_anat2std , [('out' , 'in_file' )]),
473+ (order_std2anat_composites , assemble_std2anat , [('out' , 'in_file' )]),
474+ (assemble_anat2std , outputnode , [('out_file' , 'anat2std_xfm' )]),
475+ (assemble_std2anat , outputnode , [('out_file' , 'std2anat_xfm' )]),
457476
458- (merge_anat2std , concat_anat2std , [('out' , 'transforms' )]),
459- (merge_std2anat , concat_std2anat , [('out' , 'transforms' )]),
460-
477+ # Template name wrangling
461478 (inputnode , split_desc , [('template' , 'template' )]),
462- (split_desc , tf_select , [
463- ('name' , 'template' ),
464- ('spec' , 'template_spec' ),
465- ]),
466- (tf_select , concat_anat2std , [('t1w_file' , 'reference_image' )]),
467479 (split_desc , fmt_cohort , [
468480 ('name' , 'template' ),
469481 ('spec' , 'spec' ),
@@ -472,8 +484,6 @@ def init_concat_registrations_wf(
472484 ('template' , 'template' ),
473485 ('spec' , 'template_spec' ),
474486 ]),
475- (concat_anat2std , outputnode , [('out_xfm' , 'anat2std_xfm' )]),
476- (concat_std2anat , outputnode , [('out_xfm' , 'std2anat_xfm' )]),
477487 ]) # fmt:skip
478488
479489 return workflow
@@ -507,3 +517,7 @@ def _load_intermediate_xfms(intermediate, std):
507517 )
508518
509519 return int2std , std2int
520+
521+
522+ def _order_composites (affines , displacements ):
523+ return [affines [0 ], displacements [0 ], affines [1 ], displacements [1 ]]
0 commit comments