Skip to content

Commit 4985ac0

Browse files
committed
FIX: Generate composite transform when using multi-step registration
1 parent 93d1e73 commit 4985ac0

File tree

2 files changed

+52
-40
lines changed

2 files changed

+52
-40
lines changed

nibabies/workflows/anatomical/fit.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,6 @@ def init_infant_anat_fit_wf(
991991
('anat2std_xfm', 'inputnode.anat2std_xfm'),
992992
('std2anat_xfm', 'inputnode.std2anat_xfm'),
993993
]),
994-
(anat_buffer, concat_reg_wf, [('anat_preproc', 'inputnode.anat_preproc')]),
995994
(sourcefile_buffer, ds_concat_reg_wf, [
996995
('anat_source_files', 'inputnode.source_files')
997996
]),
@@ -1909,7 +1908,6 @@ def init_infant_single_anat_fit_wf(
19091908
('anat2std_xfm', 'inputnode.anat2std_xfm'),
19101909
('std2anat_xfm', 'inputnode.std2anat_xfm'),
19111910
]),
1912-
(anat_buffer, concat_reg_wf, [('anat_preproc', 'inputnode.anat_preproc')]),
19131911
(sourcefile_buffer, ds_concat_reg_wf, [
19141912
('anat_source_files', 'inputnode.source_files')
19151913
]),

nibabies/workflows/anatomical/registration.py

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,11 @@
1515
from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms
1616
from smriprep.workflows.fit.registration import (
1717
TemplateDesc,
18-
TemplateFlowSelect,
1918
_fmt_cohort,
2019
get_metadata,
2120
tf_ver,
2221
)
2322

24-
from nibabies.config import DEFAULT_MEMORY_MIN_GB
25-
from nibabies.interfaces.patches import ConcatXFM
26-
2723

2824
def init_coregistration_wf(
2925
*,
@@ -312,7 +308,7 @@ def init_concat_registrations_wf(
312308
name='concat_registrations_wf',
313309
):
314310
"""
315-
Concatenate two transforms to produce a single transform, from native to ``template``.
311+
Concatenate two transforms to produce a single composite transform from native to template.
316312
317313
Parameters
318314
----------
@@ -347,6 +343,8 @@ def init_concat_registrations_wf(
347343
further use in downstream nodes.
348344
349345
"""
346+
from nibabies.interfaces.patches import CompositeTransformUtil
347+
350348
ntpls = len(templates)
351349
workflow = Workflow(name=name)
352350

@@ -384,9 +382,7 @@ def init_concat_registrations_wf(
384382
workflow.__desc__ += '.\n' if template == templates[-1] else ', '
385383

386384
inputnode = pe.Node(
387-
niu.IdentityInterface(
388-
fields=['template', 'anat_preproc', 'anat2std_xfm', 'intermediate', 'std2anat_xfm']
389-
),
385+
niu.IdentityInterface(fields=['template', 'intermediate', 'anat2std_xfm', 'std2anat_xfm']),
390386
name='inputnode',
391387
)
392388
inputnode.inputs.template = templates
@@ -413,29 +409,39 @@ def init_concat_registrations_wf(
413409
TemplateDesc(), run_without_submitting=True, iterfield='template', name='split_desc'
414410
)
415411

416-
tf_select = pe.MapNode(
417-
TemplateFlowSelect(resolution=1),
418-
name='tf_select',
419-
run_without_submitting=True,
420-
iterfield=['template', 'template_spec'],
412+
merge_anat2std = pe.Node(niu.Merge(2), name='merge_anat2std', run_without_submitting=True)
413+
merge_std2anat = merge_anat2std.clone('merge_std2anat')
414+
415+
disassemble_anat2std = pe.MapNode(
416+
CompositeTransformUtil(process='disassemble', output_prefix='anat2std'),
417+
iterfield=['in_file'],
418+
name='disassemble_anat2std',
421419
)
422420

423-
merge_anat2std = pe.MapNode(
424-
niu.Merge(2), name='merge_anat2std', iterfield=['in1', 'in2'], run_without_submitting=True
421+
disassemble_std2anat = pe.MapNode(
422+
CompositeTransformUtil(process='disassemble', output_prefix='std2anat', inverse=True),
423+
iterfield=['in_file'],
424+
name='disassemble_std2anat',
425+
)
426+
427+
order_anat2std_composites = pe.Node(
428+
niu.Function(function=_order_composites),
429+
name='order_anat2std_composites',
430+
)
431+
432+
order_std2anat_composites = pe.Node(
433+
niu.Function(function=_order_composites),
434+
name='order_std2anat_composites',
425435
)
426-
merge_std2anat = merge_anat2std.clone('merge_std2anat')
427436

428-
concat_anat2std = pe.MapNode(
429-
ConcatXFM(),
430-
name='concat_anat2std',
431-
mem_gb=DEFAULT_MEMORY_MIN_GB,
432-
iterfield=['transforms', 'reference_image'],
437+
assemble_anat2std = pe.Node(
438+
CompositeTransformUtil(process='assemble', out_file='anat2std.h5'),
439+
name='assemble_anat2std',
433440
)
434-
concat_std2anat = pe.MapNode(
435-
ConcatXFM(),
436-
name='concat_std2anat',
437-
mem_gb=DEFAULT_MEMORY_MIN_GB,
438-
iterfield=['transforms', 'reference_image'],
441+
442+
assemble_std2anat = pe.Node(
443+
CompositeTransformUtil(process='assemble', out_file='std2anat.h5'),
444+
name='assemble_std2anat',
439445
)
440446

441447
fmt_cohort = pe.MapNode(
@@ -446,24 +452,30 @@ def init_concat_registrations_wf(
446452
)
447453

448454
workflow.connect([
455+
# Template concatenation
449456
(inputnode, merge_anat2std, [('anat2std_xfm', 'in2')]),
450457
(inputnode, merge_std2anat, [('std2anat_xfm', 'in2')]),
451-
(inputnode, concat_std2anat, [('anat_preproc', 'reference_image')]),
452458
(inputnode, intermed_xfms, [('intermediate', 'intermediate')]),
453459
(inputnode, intermed_xfms, [('template', 'std')]),
454-
455460
(intermed_xfms, merge_anat2std, [('int2std_xfm', 'in1')]),
456461
(intermed_xfms, merge_std2anat, [('std2int_xfm', 'in1')]),
462+
(merge_anat2std, disassemble_anat2std, [('out', 'in_file')]),
463+
(merge_std2anat, disassemble_std2anat, [('out', 'in_file')]),
464+
(disassemble_anat2std, order_anat2std_composites, [
465+
('affine_transform', 'affines'),
466+
('displacement_field', 'displacements'),
467+
]),
468+
(disassemble_std2anat, order_std2anat_composites, [
469+
('affine_transform', 'affines'),
470+
('displacement_field', 'displacements'),
471+
]),
472+
(order_anat2std_composites, assemble_anat2std, [('out', 'in_file')]),
473+
(order_std2anat_composites, assemble_std2anat, [('out', 'in_file')]),
474+
(assemble_anat2std, outputnode, [('out_file', 'anat2std_xfm')]),
475+
(assemble_std2anat, outputnode, [('out_file', 'std2anat_xfm')]),
457476

458-
(merge_anat2std, concat_anat2std, [('out', 'transforms')]),
459-
(merge_std2anat, concat_std2anat, [('out', 'transforms')]),
460-
477+
# Template name wrangling
461478
(inputnode, split_desc, [('template', 'template')]),
462-
(split_desc, tf_select, [
463-
('name', 'template'),
464-
('spec', 'template_spec'),
465-
]),
466-
(tf_select, concat_anat2std, [('t1w_file', 'reference_image')]),
467479
(split_desc, fmt_cohort, [
468480
('name', 'template'),
469481
('spec', 'spec'),
@@ -472,8 +484,6 @@ def init_concat_registrations_wf(
472484
('template', 'template'),
473485
('spec', 'template_spec'),
474486
]),
475-
(concat_anat2std, outputnode, [('out_xfm', 'anat2std_xfm')]),
476-
(concat_std2anat, outputnode, [('out_xfm', 'std2anat_xfm')]),
477487
]) # fmt:skip
478488

479489
return workflow
@@ -507,3 +517,7 @@ def _load_intermediate_xfms(intermediate, std):
507517
)
508518

509519
return int2std, std2int
520+
521+
522+
def _order_composites(affines, displacements):
523+
return [affines[0], displacements[0], affines[1], displacements[1]]

0 commit comments

Comments
 (0)