Skip to content

Commit e8fa8ad

Browse files
authored
Merge pull request #421 from mgxd/rf/nitransforms
FIX: Use nitransforms for most xfm handling
2 parents b7e3fcc + d338dd0 commit e8fa8ad

File tree

5 files changed

+56
-124
lines changed

5 files changed

+56
-124
lines changed

nibabies/interfaces/resampling.py

Lines changed: 3 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@
44
import os
55
from collections.abc import Callable
66
from functools import partial
7-
from pathlib import Path
87
from typing import TypeVar
98

10-
import h5py
119
import nibabel as nb
1210
import nitransforms as nt
1311
import numpy as np
@@ -19,12 +17,13 @@
1917
traits,
2018
)
2119
from nipype.utils.filemanip import fname_presuffix
22-
from nitransforms.io.itk import ITKCompositeH5
2320
from scipy import ndimage as ndi
2421
from scipy.sparse import hstack as sparse_hstack
2522
from sdcflows.transform import grid_bspline_weights
2623
from sdcflows.utils.tools import ensure_positive_cosines
2724

25+
from nibabies.utils.transforms import load_transforms
26+
2827
R = TypeVar('R')
2928

3029

@@ -34,95 +33,6 @@ async def worker(job: Callable[[], R], semaphore: asyncio.Semaphore) -> R:
3433
return await loop.run_in_executor(None, job)
3534

3635

37-
def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.TransformBase:
38-
"""Load a series of transforms as a nitransforms TransformChain
39-
40-
An empty list will return an identity transform
41-
"""
42-
if len(inverse) == 1:
43-
inverse *= len(xfm_paths)
44-
elif len(inverse) != len(xfm_paths):
45-
raise ValueError('Mismatched number of transforms and inverses')
46-
47-
chain = None
48-
for path, inv in zip(xfm_paths[::-1], inverse[::-1], strict=False):
49-
path = Path(path)
50-
if path.suffix == '.h5':
51-
xfm = load_ants_h5(path)
52-
else:
53-
xfm = nt.linear.load(path)
54-
if inv:
55-
xfm = ~xfm
56-
if chain is None:
57-
chain = xfm
58-
else:
59-
chain += xfm
60-
if chain is None:
61-
chain = nt.base.TransformBase()
62-
return chain
63-
64-
65-
FIXED_PARAMS = np.array([
66-
193.0, 229.0, 193.0, # Size
67-
96.0, 132.0, -78.0, # Origin
68-
1.0, 1.0, 1.0, # Spacing
69-
-1.0, 0.0, 0.0, # Directions
70-
0.0, -1.0, 0.0,
71-
0.0, 0.0, 1.0,
72-
]) # fmt:skip
73-
74-
75-
def load_ants_h5(filename: Path) -> nt.base.TransformBase:
76-
"""Load ANTs H5 files as a nitransforms TransformChain"""
77-
# Borrowed from https://github.com/feilong/process
78-
# process.resample.parse_combined_hdf5()
79-
#
80-
# Changes:
81-
# * Tolerate a missing displacement field
82-
# * Return the original affine without a round-trip
83-
# * Always return a nitransforms TransformChain
84-
#
85-
# This should be upstreamed into nitransforms
86-
h = h5py.File(filename)
87-
xform = ITKCompositeH5.from_h5obj(h)
88-
89-
# nt.Affine
90-
transforms = [nt.Affine(xform[0].to_ras())]
91-
92-
if '2' not in h['TransformGroup']:
93-
return transforms[0]
94-
95-
transform2 = h['TransformGroup']['2']
96-
97-
# Confirm these transformations are applicable
98-
if transform2['TransformType'][:][0] not in (
99-
b'DisplacementFieldTransform_float_3_3',
100-
b'DisplacementFieldTransform_double_3_3',
101-
):
102-
msg = 'Unknown transform type [2]\n'
103-
for i in h['TransformGroup'].keys():
104-
msg += f'[{i}]: {h["TransformGroup"][i]["TransformType"][:][0]}\n'
105-
raise ValueError(msg)
106-
107-
fixed_params = transform2['TransformFixedParameters'][:]
108-
shape = tuple(fixed_params[:3].astype(int))
109-
# ITK stores warps in Fortran-order, where the vector components change fastest
110-
# Nitransforms expects 3 volumes, not a volume of three-vectors, so transpose
111-
warp = np.reshape(
112-
transform2['TransformParameters'],
113-
(3, *shape),
114-
order='F',
115-
).transpose(1, 2, 3, 0)
116-
117-
warp_affine = np.eye(4)
118-
warp_affine[:3, :3] = fixed_params[9:].reshape((3, 3))
119-
warp_affine[:3, 3] = fixed_params[3:6]
120-
lps_to_ras = np.eye(4) * np.array([-1, -1, 1, 1])
121-
warp_affine = lps_to_ras @ warp_affine
122-
transforms.insert(0, nt.DenseFieldTransform(nb.Nifti1Image(warp, warp_affine)))
123-
return nt.TransformChain(transforms)
124-
125-
12636
class ResampleSeriesInputSpec(TraitedSpec):
12737
in_file = File(exists=True, mandatory=True, desc='3D or 4D image file to resample')
12838
ref_file = File(exists=True, mandatory=True, desc='File to resample in_file to')
@@ -788,7 +698,7 @@ def reconstruct_fieldmap(
788698
)
789699

790700
if not direct:
791-
fmap_img = transforms.apply(fmap_img, reference=target)
701+
fmap_img = nt.apply(transforms, fmap_img, reference=target)
792702

793703
fmap_img.header.set_intent('estimate', name='fieldmap Hz')
794704
fmap_img.header.set_data_dtype('float32')

nibabies/utils/transforms.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Utilities for loading transforms for resampling"""
2+
3+
from pathlib import Path
4+
5+
import nitransforms as nt
6+
7+
8+
def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.TransformBase:
9+
"""Load a series of transforms as a nitransforms TransformChain
10+
11+
An empty list will return an identity transform
12+
"""
13+
if len(inverse) == 1:
14+
inverse *= len(xfm_paths)
15+
elif len(inverse) != len(xfm_paths):
16+
raise ValueError('Mismatched number of transforms and inverses')
17+
18+
chain = None
19+
for path, inv in zip(xfm_paths[::-1], inverse[::-1], strict=False):
20+
path = Path(path)
21+
if path.suffix == '.h5':
22+
# Load as a TransformChain
23+
xfm = nt.manip.load(path)
24+
else:
25+
xfm = nt.linear.load(path)
26+
if inv:
27+
xfm = ~xfm
28+
if chain is None:
29+
chain = xfm
30+
else:
31+
chain += xfm
32+
if chain is None:
33+
chain = nt.Affine() # Identity
34+
return chain

nibabies/workflows/bold/registration.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -704,11 +704,11 @@ def compare_xforms(lta_list, norm_threshold=15):
704704
second transform relative to the first (default: `15`)
705705
706706
"""
707+
import nitransforms as nt
707708
from nipype.algorithms.rapidart import _calc_norm_affine
708-
from niworkflows.interfaces.surf import load_transform
709709

710-
bbr_affine = load_transform(lta_list[0])
711-
fallback_affine = load_transform(lta_list[1])
710+
bbr_affine = nt.linear.load(lta_list[0]).matrix
711+
fallback_affine = nt.linear.load(lta_list[1]).matrix
712712

713713
norm, _ = _calc_norm_affine([fallback_affine, bbr_affine], use_differences=True)
714714

@@ -741,14 +741,16 @@ def _conditional_downsampling(in_file, in_mask, zoom_th=4.0):
741741
offset = old_center - newrot.dot((newshape - 1) * 0.5)
742742
newaffine = nb.affines.from_matvec(newrot, offset)
743743

744+
identity = nt.Affine()
745+
744746
newref = nb.Nifti1Image(np.zeros(newshape, dtype=np.uint8), newaffine)
745-
nt.Affine(reference=newref).apply(img).to_filename(out_file)
747+
nt.apply(identity, img, reference=newref).to_filename(out_file)
746748

747749
mask = nb.load(in_mask)
748750
mask.set_data_dtype(float)
749751
mdata = gaussian_filter(mask.get_fdata(dtype=float), scaling)
750752
floatmask = nb.Nifti1Image(mdata, mask.affine, mask.header)
751-
newmask = nt.Affine(reference=newref).apply(floatmask)
753+
newmask = nt.apply(identity, floatmask, reference=newref)
752754
hdr = newmask.header.copy()
753755
hdr.set_data_dtype(np.uint8)
754756
newmaskdata = (newmask.get_fdata(dtype=float) > 0.5).astype(np.uint8)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ dependencies = [
2424
"nipype >= 1.8.5",
2525
"nireports >= 23.2.0",
2626
"nitime",
27-
"nitransforms >= 23.0.1",
27+
"nitransforms >= 24.1.1",
2828
"niworkflows >= 1.12.1",
2929
"numpy >= 1.21.0",
3030
"packaging",

requirements.txt

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ annexremote==1.6.6
1717
# datalad-osf
1818
astor==0.8.1
1919
# via formulaic
20-
attrs==24.2.0
20+
attrs==24.3.0
2121
# via
2222
# jsonschema
2323
# niworkflows
@@ -31,16 +31,14 @@ bidsschematools==1.0.0
3131
# via bids-validator
3232
bokeh==3.5.2
3333
# via tedana
34-
boto3==1.35.80
34+
boto3==1.35.83
3535
# via datalad
36-
botocore==1.35.80
36+
botocore==1.35.83
3737
# via
3838
# boto3
3939
# s3transfer
40-
certifi==2024.8.30
40+
certifi==2024.12.14
4141
# via requests
42-
cffi==1.17.1
43-
# via cryptography
4442
chardet==5.2.0
4543
# via datalad
4644
charset-normalizer==3.4.0
@@ -58,11 +56,9 @@ contourpy==1.3.1
5856
# via
5957
# bokeh
6058
# matplotlib
61-
cryptography==44.0.0
62-
# via secretstorage
6359
cycler==0.12.1
6460
# via matplotlib
65-
datalad==1.1.4
61+
datalad==1.1.5
6662
# via
6763
# datalad-next
6864
# datalad-osf
@@ -87,8 +83,6 @@ formulaic==0.5.2
8783
# via pybids
8884
fsspec==2024.10.0
8985
# via universal-pathlib
90-
greenlet==3.1.1
91-
# via sqlalchemy
9286
h5py==3.12.1
9387
# via nitransforms
9488
humanize==4.11.0
@@ -123,10 +117,6 @@ jaraco-context==6.0.1
123117
# keyrings-alt
124118
jaraco-functools==4.1.0
125119
# via keyring
126-
jeepney==0.8.0
127-
# via
128-
# keyring
129-
# secretstorage
130120
jinja2==3.1.4
131121
# via
132122
# bokeh
@@ -171,7 +161,7 @@ mapca==0.0.5
171161
# via tedana
172162
markupsafe==3.0.2
173163
# via jinja2
174-
matplotlib==3.9.3
164+
matplotlib==3.10.0
175165
# via
176166
# nireports
177167
# nitime
@@ -214,7 +204,7 @@ nilearn==0.10.4
214204
# nireports
215205
# niworkflows
216206
# tedana
217-
nipype==1.9.1
207+
nipype==1.9.2
218208
# via
219209
# nibabies (pyproject.toml)
220210
# nireports
@@ -225,7 +215,7 @@ nireports==24.0.3
225215
# via nibabies (pyproject.toml)
226216
nitime==0.11
227217
# via nibabies (pyproject.toml)
228-
nitransforms==24.1.0
218+
nitransforms==24.1.1
229219
# via
230220
# nibabies (pyproject.toml)
231221
# niworkflows
@@ -235,7 +225,7 @@ niworkflows==1.12.1
235225
# nibabies (pyproject.toml)
236226
# sdcflows
237227
# smriprep
238-
num2words==0.5.13
228+
num2words==0.5.14
239229
# via pybids
240230
numpy==2.1.1
241231
# via
@@ -327,8 +317,6 @@ pybtex==0.24.0
327317
# via tedana
328318
pybtex-apa-style==1.3
329319
# via tedana
330-
pycparser==2.22
331-
# via cffi
332320
pydot==3.0.3
333321
# via nipype
334322
pyparsing==3.2.0
@@ -343,7 +331,7 @@ python-dateutil==2.9.0.post0
343331
# nipype
344332
# pandas
345333
# prov
346-
python-gitlab==5.1.0
334+
python-gitlab==5.2.0
347335
# via datalad
348336
pytz==2024.2
349337
# via pandas
@@ -384,7 +372,7 @@ rpds-py==0.22.3
384372
# referencing
385373
s3transfer==0.10.4
386374
# via boto3
387-
scikit-image==0.24.0
375+
scikit-image==0.25.0
388376
# via
389377
# niworkflows
390378
# sdcflows
@@ -415,8 +403,6 @@ seaborn==0.13.2
415403
# via
416404
# nireports
417405
# niworkflows
418-
secretstorage==3.3.3
419-
# via keyring
420406
simplejson==3.19.3
421407
# via nipype
422408
six==1.17.0

0 commit comments

Comments
 (0)