Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 1 addition & 39 deletions nibabies/interfaces/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
from nipype.interfaces.ants.registration import (
CompositeTransformUtil as _CompositeTransformUtil,
)
from nipype.interfaces.ants.registration import (
CompositeTransformUtilInputSpec as _CompositeTransformUtilInputSpec,
)
from nipype.interfaces.ants.registration import (
CompositeTransformUtilOutputSpec as _CompositeTransformUtilOutputSpec,
)
Expand Down Expand Up @@ -116,22 +113,13 @@ def _list_outputs(self):
return outputs


class CompositeTransformUtilInputSpec(_CompositeTransformUtilInputSpec):
order_transforms = traits.Bool(
True,
usedefault=True,
desc='Order disassembled transforms into [Affine, Displacement] pairs.',
)


class CompositeTransformUtilOutputSpec(_CompositeTransformUtilOutputSpec):
out_transforms = traits.List(desc='list of transform components')
out_transforms = traits.List(desc='list of ordered transform components')


class CompositeTransformUtil(_CompositeTransformUtil):
"""Outputs have changed in newer versions of ANTs."""

input_spec = CompositeTransformUtilInputSpec
output_spec = CompositeTransformUtilOutputSpec

def _list_outputs(self):
Expand All @@ -145,9 +133,6 @@ def _list_outputs(self):
str(Path(x).absolute())
for x in sorted(Path().glob(f'{self.inputs.output_prefix}_*'))
]

if self.inputs.order_transforms:
transforms = _order_xfms(transforms)
outputs['out_transforms'] = transforms

# Potentially could be more than one affine / displacement per composite transform...
Expand All @@ -160,26 +145,3 @@ def _list_outputs(self):
elif self.inputs.process == 'assemble':
outputs['out_file'] = Path(self.inputs.out_file).absolute()
return outputs


def _order_xfms(vals):
"""
Assumes [affine, displacement] or [displacement, affine] transform pairs.

>>> _order_xfms(['DisplacementFieldTransform.nii.gz', 'AffineTransform.mat'])
['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz']

>>> _order_xfms(['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz'])
['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz']

>>> _order_xfms(['DisplacementFieldTransform.nii.gz', 'AffineTransform.mat', \
'AffineTransform.mat'])
['AffineTransform.mat', 'DisplacementFieldTransform.nii.gz', 'AffineTransform.mat']
"""
for i in range(0, len(vals) - 1, 2):
if (
'DisplacementFieldTransform' in Path(vals[i]).name
and 'AffineTransform' in Path(vals[i + 1]).name
):
vals[i], vals[i + 1] = vals[i + 1], vals[i]
return vals
19 changes: 0 additions & 19 deletions nibabies/utils/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.Trans
if path.suffix == '.h5':
# Load as a TransformChain
xfm = nt.manip.load(path)
if len(xfm.transforms) == 4:
# MG: This behavior should be ported to nitransforms
xfm = nt.manip.TransformChain(reverse_pairs(xfm.transforms))
else:
xfm = nt.linear.load(path)
if inv:
Expand All @@ -35,19 +32,3 @@ def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.Trans
if chain is None:
chain = nt.Affine() # Identity
return chain


def reverse_pairs(arr: list) -> list:
"""
Reverse the order of pairs in a list.

>>> reverse_pairs([1, 2, 3, 4])
[3, 4, 1, 2]

>>> reverse_pairs([1, 2, 3, 4, 5, 6])
[5, 6, 3, 4, 1, 2]
"""
rev = []
for i in range(len(arr), 0, -2):
rev.extend(arr[i - 2 : i])
return rev
8 changes: 4 additions & 4 deletions nibabies/workflows/anatomical/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@
t2w_mask = precomputed.get('t2w_mask')
anat_mask = precomputed.get(f'{anat}_mask')
refine_mask = False
# T1w masking - define pre-emptively

Check failure on line 456 in nibabies/workflows/anatomical/fit.py

View workflow job for this annotation

GitHub Actions / Check for spelling errors

pre-emptively ==> preemptively
apply_t1w_mask = pe.Node(ApplyMask(), name='apply_t1w_mask')
apply_t2w_mask = apply_t1w_mask.clone(name='apply_t2w_mask')

Expand Down Expand Up @@ -988,8 +988,8 @@
(concat_std2anat_buffer, select_infant_mni, [('out', 'std2anat_xfm')]),
(select_infant_mni, concat_reg_wf, [
('key', 'inputnode.intermediate'),
('anat2std_xfm', 'inputnode.anat2std_xfm'),
('std2anat_xfm', 'inputnode.std2anat_xfm'),
('anat2std_xfm', 'inputnode.anat2int_xfm'),
('std2anat_xfm', 'inputnode.int2anat_xfm'),
]),
(sourcefile_buffer, ds_concat_reg_wf, [
('anat_source_files', 'inputnode.source_files')
Expand Down Expand Up @@ -1905,8 +1905,8 @@
(concat_std2anat_buffer, select_infant_mni, [('out', 'std2anat_xfm')]),
(select_infant_mni, concat_reg_wf, [
('key', 'inputnode.intermediate'),
('anat2std_xfm', 'inputnode.anat2std_xfm'),
('std2anat_xfm', 'inputnode.std2anat_xfm'),
('anat2std_xfm', 'inputnode.anat2int_xfm'),
('std2anat_xfm', 'inputnode.int2anat_xfm'),
]),
(sourcefile_buffer, ds_concat_reg_wf, [
('anat_source_files', 'inputnode.source_files')
Expand Down
134 changes: 97 additions & 37 deletions nibabies/workflows/anatomical/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,14 @@ def init_concat_registrations_wf(
workflow.__desc__ += '.\n' if template == templates[-1] else ', '

inputnode = pe.Node(
niu.IdentityInterface(fields=['template', 'intermediate', 'anat2std_xfm', 'std2anat_xfm']),
niu.IdentityInterface(
fields=[
'template', # template identifier (name[+cohort])
'intermediate', # intermediate space (name[+cohort])
'anat2int_xfm', # anatomical -> intermediate
'int2anat_xfm', # intermediate -> anatomical
]
),
name='inputnode',
)
inputnode.inputs.template = templates
Expand All @@ -401,68 +408,83 @@ def init_concat_registrations_wf(
),
name='intermed_xfms',
iterfield=['std'],
overwrite=True, # otherwise, cache hits but not guarantee files are present on reruns
run_without_submitting=True,
)

split_desc = pe.MapNode(
TemplateDesc(), run_without_submitting=True, iterfield='template', name='split_desc'
)

merge_anat2std = pe.Node(niu.Merge(2), name='merge_anat2std', run_without_submitting=True)
merge_std2anat = merge_anat2std.clone('merge_std2anat')
fmt_cohort = pe.MapNode(
niu.Function(function=_fmt_cohort, output_names=['template', 'spec']),
name='fmt_cohort',
run_without_submitting=True,
iterfield=['template', 'spec'],
)

disassemble_anat2std = pe.MapNode(
CompositeTransformUtil(process='disassemble', output_prefix='anat2std'),
iterfield=['in_file'],
name='disassemble_anat2std',
# Disassemble each composite transform individually for readability
dis_anat2int = pe.Node(
CompositeTransformUtil(process='disassemble', output_prefix='anat2int'),
name='dis_anat2int',
)

disassemble_std2anat = pe.MapNode(
CompositeTransformUtil(process='disassemble', output_prefix='std2anat'),
iterfield=['in_file'],
name='disassemble_std2anat',
dis_int2std = pe.Node(
CompositeTransformUtil(process='disassemble', output_prefix='int2std'),
name='dis_int2std',
)

merge_anat2std_composites = pe.Node(
niu.Merge(1, ravel_inputs=True),
name='merge_anat2std_composites',
dis_std2int = pe.Node(
CompositeTransformUtil(process='disassemble', output_prefix='std2int'),
name='dis_std2int',
)
merge_std2anat_composites = pe.Node(
niu.Merge(1, ravel_inputs=True),
name='merge_std2anat_composites',

dis_int2anat = pe.Node(
CompositeTransformUtil(process='disassemble', output_prefix='int2anat'),
name='dis_int2anat',
)

order_anat2std = pe.Node(niu.Merge(4), name='order_anat2std', run_without_submitting=True)
order_std2anat = pe.Node(niu.Merge(4), name='order_std2anat', run_without_submitting=True)

assemble_anat2std = pe.Node(
CompositeTransformUtil(process='assemble', out_file='anat2std.h5'),
name='assemble_anat2std',
)
# https://github.com/ANTsX/ANTs/issues/1827
# Until CompositeTransformUtil accepts warps as first transform,
# Use SimpleITK to concatenate
assemble_std2anat = pe.Node(
CompositeTransformUtil(process='assemble', out_file='std2anat.h5'),
niu.Function(function=_create_inverse_composite, output_names=['out_file']),
name='assemble_std2anat',
)

fmt_cohort = pe.MapNode(
niu.Function(function=_fmt_cohort, output_names=['template', 'spec']),
name='fmt_cohort',
run_without_submitting=True,
iterfield=['template', 'spec'],
)

workflow.connect([
# Template concatenation
(inputnode, merge_anat2std, [('anat2std_xfm', 'in2')]),
(inputnode, merge_std2anat, [('std2anat_xfm', 'in2')]),
# Transform concatenation
(inputnode, dis_anat2int, [('anat2int_xfm', 'in_file')]),
(inputnode, dis_int2anat, [('int2anat_xfm', 'in_file')]),
(inputnode, intermed_xfms, [('intermediate', 'intermediate')]),
(inputnode, intermed_xfms, [('template', 'std')]),
(intermed_xfms, merge_anat2std, [('int2std_xfm', 'in1')]),
(intermed_xfms, merge_std2anat, [('std2int_xfm', 'in1')]),
(merge_anat2std, disassemble_anat2std, [('out', 'in_file')]),
(merge_std2anat, disassemble_std2anat, [('out', 'in_file')]),
(disassemble_anat2std, merge_anat2std_composites, [('out_transforms', 'in1')]),
(disassemble_std2anat, merge_std2anat_composites, [('out_transforms', 'in1')]),
(merge_anat2std_composites, assemble_anat2std, [('out', 'in_file')]),
(merge_std2anat_composites, assemble_std2anat, [('out', 'in_file')]),
(intermed_xfms, dis_int2std, [('int2std_xfm', 'in_file')]),
(intermed_xfms, dis_std2int, [('std2int_xfm', 'in_file')]),
(dis_anat2int, order_anat2std, [
('affine_transform', 'in1'),
('displacement_field', 'in2'),
]),
(dis_int2std, order_anat2std, [
('affine_transform', 'in3'),
('displacement_field', 'in4'),
]),
# Because std2anat are inverse transforms, warp is first
(dis_std2int, order_std2anat, [
('affine_transform', 'in2'),
('displacement_field', 'in1'),
]),
(dis_int2anat, order_std2anat, [
('affine_transform', 'in4'),
('displacement_field', 'in3'),
]),
(order_anat2std, assemble_anat2std, [('out', 'in_file')]),
(order_std2anat, assemble_std2anat, [('out', 'in_file')]),
(assemble_anat2std, outputnode, [('out_file', 'anat2std_xfm')]),
(assemble_std2anat, outputnode, [('out_file', 'std2anat_xfm')]),

Expand All @@ -483,6 +505,7 @@ def init_concat_registrations_wf(

def _load_intermediate_xfms(intermediate, std):
import json
from pathlib import Path

import pooch

Expand All @@ -496,6 +519,7 @@ def _load_intermediate_xfms(intermediate, std):
int2std_meta = xfms[int2std_name]
int2std = pooch.retrieve(
url=int2std_meta['url'],
path=Path.cwd(),
known_hash=int2std_meta['hash'],
fname=int2std_name,
)
Expand All @@ -504,8 +528,44 @@ def _load_intermediate_xfms(intermediate, std):
std2int_meta = xfms[std2int_name]
std2int = pooch.retrieve(
url=std2int_meta['url'],
path=Path.cwd(),
known_hash=std2int_meta['hash'],
fname=std2int_name,
)

return int2std, std2int


def _create_inverse_composite(in_file, out_file='inverse_composite.h5'):
"""Build a composite transform with SimpleITK.

This serves as a workaround for a bug in ANTs's CompositeTransformUtil
where composite transforms cannot be created with a displacement field placed first.

Parameters
----------
in_file : list of str
List of input transforms to concatenate into a composite transform.
out_file : str, optional
File to write the composite transform to.

Returns
-------
out_file : str
Absolute path to the composite transform.
from pathlib import Path

import SimpleITK as sitk

composite = sitk.CompositeTransform(3)
for xfm_file in in_file:
if xfm_file.endswith('mat'):
xfm = sitk.ReadTransform(xfm_file)
else:
xfm = sitk.DisplacementFieldTransform(sitk.ReadImage(xfm_file))

composite.AddTransform(xfm)

out_file = str(Path(out_file).absolute())
sitk.WriteTransform(composite, out_file)
return out_file
7 changes: 7 additions & 0 deletions nibabies/workflows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,13 @@ def init_workflow_spaces(execution_spaces: SpatialReferences, age_months: int):
if not spaces.is_cached():
spaces.checkpoint()

# Ensure one cohort of MNIInfant is always available as an internal space
if not any(
space.startswith('MNIInfant') for space in spaces.get_spaces(nonstandard=False, dim=(3,))
):
cohort = cohort_by_months('MNIInfant', age_months)
spaces.add(Reference('MNIInfant', {'cohort': cohort}))

if config.workflow.cifti_output:
# CIFTI grayordinates to corresponding FSL-MNI resolutions.
vol_res = '2' if config.workflow.cifti_output == '91k' else '1'
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
"psutil >= 5.4",
"pybids >= 0.15.0",
"requests",
"SimpleITK",
"sdcflows >= 2.10.0",
"smriprep >= 0.17.0",
"tedana >= 23.0.2",
Expand Down
Loading