44import os
55from collections .abc import Callable
66from functools import partial
7- from pathlib import Path
87from typing import TypeVar
98
10- import h5py
119import nibabel as nb
1210import nitransforms as nt
1311import numpy as np
1917 traits ,
2018)
2119from nipype .utils .filemanip import fname_presuffix
22- from nitransforms .io .itk import ITKCompositeH5
2320from scipy import ndimage as ndi
2421from scipy .sparse import hstack as sparse_hstack
2522from sdcflows .transform import grid_bspline_weights
2623from sdcflows .utils .tools import ensure_positive_cosines
2724
25+ from nibabies .utils .transforms import load_transforms
26+
2827R = TypeVar ('R' )
2928
3029
@@ -34,95 +33,6 @@ async def worker(job: Callable[[], R], semaphore: asyncio.Semaphore) -> R:
3433 return await loop .run_in_executor (None , job )
3534
3635
37- def load_transforms (xfm_paths : list [Path ], inverse : list [bool ]) -> nt .base .TransformBase :
38- """Load a series of transforms as a nitransforms TransformChain
39-
40- An empty list will return an identity transform
41- """
42- if len (inverse ) == 1 :
43- inverse *= len (xfm_paths )
44- elif len (inverse ) != len (xfm_paths ):
45- raise ValueError ('Mismatched number of transforms and inverses' )
46-
47- chain = None
48- for path , inv in zip (xfm_paths [::- 1 ], inverse [::- 1 ], strict = False ):
49- path = Path (path )
50- if path .suffix == '.h5' :
51- xfm = load_ants_h5 (path )
52- else :
53- xfm = nt .linear .load (path )
54- if inv :
55- xfm = ~ xfm
56- if chain is None :
57- chain = xfm
58- else :
59- chain += xfm
60- if chain is None :
61- chain = nt .base .TransformBase ()
62- return chain
63-
64-
65- FIXED_PARAMS = np .array ([
66- 193.0 , 229.0 , 193.0 , # Size
67- 96.0 , 132.0 , - 78.0 , # Origin
68- 1.0 , 1.0 , 1.0 , # Spacing
69- - 1.0 , 0.0 , 0.0 , # Directions
70- 0.0 , - 1.0 , 0.0 ,
71- 0.0 , 0.0 , 1.0 ,
72- ]) # fmt:skip
73-
74-
75- def load_ants_h5 (filename : Path ) -> nt .base .TransformBase :
76- """Load ANTs H5 files as a nitransforms TransformChain"""
77- # Borrowed from https://github.com/feilong/process
78- # process.resample.parse_combined_hdf5()
79- #
80- # Changes:
81- # * Tolerate a missing displacement field
82- # * Return the original affine without a round-trip
83- # * Always return a nitransforms TransformChain
84- #
85- # This should be upstreamed into nitransforms
86- h = h5py .File (filename )
87- xform = ITKCompositeH5 .from_h5obj (h )
88-
89- # nt.Affine
90- transforms = [nt .Affine (xform [0 ].to_ras ())]
91-
92- if '2' not in h ['TransformGroup' ]:
93- return transforms [0 ]
94-
95- transform2 = h ['TransformGroup' ]['2' ]
96-
97- # Confirm these transformations are applicable
98- if transform2 ['TransformType' ][:][0 ] not in (
99- b'DisplacementFieldTransform_float_3_3' ,
100- b'DisplacementFieldTransform_double_3_3' ,
101- ):
102- msg = 'Unknown transform type [2]\n '
103- for i in h ['TransformGroup' ].keys ():
104- msg += f'[{ i } ]: { h ["TransformGroup" ][i ]["TransformType" ][:][0 ]} \n '
105- raise ValueError (msg )
106-
107- fixed_params = transform2 ['TransformFixedParameters' ][:]
108- shape = tuple (fixed_params [:3 ].astype (int ))
109- # ITK stores warps in Fortran-order, where the vector components change fastest
110- # Nitransforms expects 3 volumes, not a volume of three-vectors, so transpose
111- warp = np .reshape (
112- transform2 ['TransformParameters' ],
113- (3 , * shape ),
114- order = 'F' ,
115- ).transpose (1 , 2 , 3 , 0 )
116-
117- warp_affine = np .eye (4 )
118- warp_affine [:3 , :3 ] = fixed_params [9 :].reshape ((3 , 3 ))
119- warp_affine [:3 , 3 ] = fixed_params [3 :6 ]
120- lps_to_ras = np .eye (4 ) * np .array ([- 1 , - 1 , 1 , 1 ])
121- warp_affine = lps_to_ras @ warp_affine
122- transforms .insert (0 , nt .DenseFieldTransform (nb .Nifti1Image (warp , warp_affine )))
123- return nt .TransformChain (transforms )
124-
125-
12636class ResampleSeriesInputSpec (TraitedSpec ):
12737 in_file = File (exists = True , mandatory = True , desc = '3D or 4D image file to resample' )
12838 ref_file = File (exists = True , mandatory = True , desc = 'File to resample in_file to' )
@@ -788,7 +698,7 @@ def reconstruct_fieldmap(
788698 )
789699
790700 if not direct :
791- fmap_img = transforms .apply (fmap_img , reference = target )
701+ fmap_img = nt .apply (transforms , fmap_img , reference = target )
792702
793703 fmap_img .header .set_intent ('estimate' , name = 'fieldmap Hz' )
794704 fmap_img .header .set_data_dtype ('float32' )
0 commit comments