Skip to content

Commit f327dd6

Browse files
committed
RF: Perform transform ordering when disassembling
1 parent f5d97dd commit f327dd6

File tree

2 files changed

+62
-37
lines changed

2 files changed

+62
-37
lines changed

nibabies/interfaces/patches.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from nipype.interfaces.ants.registration import (
1111
CompositeTransformUtilInputSpec as _CompositeTransformUtilInputSpec,
1212
)
13+
from nipype.interfaces.ants.registration import (
14+
CompositeTransformUtilOutputSpec as _CompositeTransformUtilOutputSpec,
15+
)
1316
from nipype.interfaces.base import File, InputMultiObject, TraitedSpec, traits
1417

1518

@@ -114,35 +117,69 @@ def _list_outputs(self):
114117

115118

116119
class CompositeTransformUtilInputSpec(_CompositeTransformUtilInputSpec):
117-
inverse = traits.Bool(
118-
False,
120+
order_transforms = traits.Bool(
121+
True,
119122
usedefault=True,
120-
desc='When disassembling an inverse component transform, the indexing will be reversed.',
123+
desc='Order disassembled transforms into [Affine, Displacement] pairs.',
121124
)
122125

123126

127+
class CompositeTransformUtilOutputSpec(_CompositeTransformUtilOutputSpec):
128+
out_transforms = traits.List(desc='list of transform components')
129+
130+
124131
class CompositeTransformUtil(_CompositeTransformUtil):
125132
"""Outputs have changed in newer versions of ANTs."""
126133

127134
input_spec = CompositeTransformUtilInputSpec
135+
output_spec = CompositeTransformUtilOutputSpec
128136

129137
def _list_outputs(self):
130138
outputs = self.output_spec().get()
131139

132-
# Index may change depending on forward/inverse transform
140+
# Ordering may change depending on forward/inverse transform
133141
# Forward: <prefix>_00_AffineTransform.mat, <prefix>_01_DisplacementFieldTransform.nii.gz
134142
# Inverse: <prefix>_01_AffineTransform.mat, <prefix>_00_DisplacementFieldTransform.nii.gz
135-
idx = ['00', '01']
136-
if self.inputs.inverse:
137-
idx = idx[::-1]
138-
139143
if self.inputs.process == 'disassemble':
140-
outputs['affine_transform'] = Path(
141-
f'{self.inputs.output_prefix}_{idx[0]}_AffineTransform.mat'
142-
).absolute()
143-
outputs['displacement_field'] = Path(
144-
f'{self.inputs.output_prefix}_{idx[1]}_DisplacementFieldTransform.nii.gz'
145-
).absolute()
144+
transforms = [
145+
str(Path(x).absolute())
146+
for x in sorted(Path().glob(f'{self.inputs.output_prefix}_*'))
147+
]
148+
149+
if self.inputs.order_transforms:
150+
transforms = _order_xfms(transforms)
151+
outputs['out_transforms'] = transforms
152+
153+
# Potentially could be more than one affine / displacement per composite transform...
154+
outputs['affine_transform'] = [
155+
x for x in transforms if 'AffineTransform' in Path(x).name
156+
][0]
157+
outputs['displacement_field'] = [
158+
x for x in transforms if 'DisplacementFieldTransform' in Path(x).name
159+
][0]
146160
elif self.inputs.process == 'assemble':
147161
outputs['out_file'] = Path(self.inputs.out_file).absolute()
148162
return outputs
163+
164+
165+
def _order_xfms(vals):
166+
"""
167+
Assumes [affine, displacement] or [displacement, affine] transform pairs.
168+
169+
>>> _order_xfms(['DisplacementFieldTransform.nii.gz', 'AffineTransform.mat'])
170+
['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz']
171+
172+
>>> _order_xfms(['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz'])
173+
['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz']
174+
175+
>>> _order_xfms(['DisplacementFieldTransform.nii.gz', 'AffineTransform.mat', \
176+
'AffineTransform.mat'])
177+
['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz', 'AffineTransform.mat']
178+
"""
179+
for i in range(0, len(vals) - 1, 2):
180+
if (
181+
'DisplacementFieldTransform' in Path(vals[i]).name
182+
and 'AffineTransform' in Path(vals[i + 1]).name
183+
):
184+
vals[i], vals[i + 1] = vals[i + 1], vals[i]
185+
return vals

nibabies/workflows/anatomical/registration.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -419,26 +419,24 @@ def init_concat_registrations_wf(
419419
)
420420

421421
disassemble_std2anat = pe.MapNode(
422-
CompositeTransformUtil(process='disassemble', output_prefix='std2anat', inverse=True),
422+
CompositeTransformUtil(process='disassemble', output_prefix='std2anat'),
423423
iterfield=['in_file'],
424424
name='disassemble_std2anat',
425425
)
426426

427-
order_anat2std_composites = pe.Node(
428-
niu.Function(function=_order_composites),
429-
name='order_anat2std_composites',
427+
merge_anat2std_composites = pe.Node(
428+
niu.Merge(1, ravel_inputs=True),
429+
name='merge_anat2std_composites',
430430
)
431-
432-
order_std2anat_composites = pe.Node(
433-
niu.Function(function=_order_composites),
434-
name='order_std2anat_composites',
431+
merge_std2anat_composites = pe.Node(
432+
niu.Merge(1, ravel_inputs=True),
433+
name='merge_std2anat_composites',
435434
)
436435

437436
assemble_anat2std = pe.Node(
438437
CompositeTransformUtil(process='assemble', out_file='anat2std.h5'),
439438
name='assemble_anat2std',
440439
)
441-
442440
assemble_std2anat = pe.Node(
443441
CompositeTransformUtil(process='assemble', out_file='std2anat.h5'),
444442
name='assemble_std2anat',
@@ -461,16 +459,10 @@ def init_concat_registrations_wf(
461459
(intermed_xfms, merge_std2anat, [('std2int_xfm', 'in1')]),
462460
(merge_anat2std, disassemble_anat2std, [('out', 'in_file')]),
463461
(merge_std2anat, disassemble_std2anat, [('out', 'in_file')]),
464-
(disassemble_anat2std, order_anat2std_composites, [
465-
('affine_transform', 'affines'),
466-
('displacement_field', 'displacements'),
467-
]),
468-
(disassemble_std2anat, order_std2anat_composites, [
469-
('affine_transform', 'affines'),
470-
('displacement_field', 'displacements'),
471-
]),
472-
(order_anat2std_composites, assemble_anat2std, [('out', 'in_file')]),
473-
(order_std2anat_composites, assemble_std2anat, [('out', 'in_file')]),
462+
(disassemble_anat2std, merge_anat2std_composites, [('out_transforms', 'in1')]),
463+
(disassemble_std2anat, merge_std2anat_composites, [('out_transforms', 'in1')]),
464+
(merge_anat2std_composites, assemble_anat2std, [('out', 'in_file')]),
465+
(merge_std2anat_composites, assemble_std2anat, [('out', 'in_file')]),
474466
(assemble_anat2std, outputnode, [('out_file', 'anat2std_xfm')]),
475467
(assemble_std2anat, outputnode, [('out_file', 'std2anat_xfm')]),
476468

@@ -517,7 +509,3 @@ def _load_intermediate_xfms(intermediate, std):
517509
)
518510

519511
return int2std, std2int
520-
521-
522-
def _order_composites(affines, displacements):
523-
return [affines[0], displacements[0], affines[1], displacements[1]]

0 commit comments

Comments
 (0)