4
4
import os
5
5
from collections .abc import Callable
6
6
from functools import partial
7
- from pathlib import Path
8
7
from typing import TypeVar
9
8
10
- import h5py
11
9
import nibabel as nb
12
10
import nitransforms as nt
13
11
import numpy as np
19
17
traits ,
20
18
)
21
19
from nipype .utils .filemanip import fname_presuffix
22
- from nitransforms .io .itk import ITKCompositeH5
23
20
from scipy import ndimage as ndi
24
21
from scipy .sparse import hstack as sparse_hstack
25
22
from sdcflows .transform import grid_bspline_weights
26
23
from sdcflows .utils .tools import ensure_positive_cosines
27
24
25
+ from nibabies .utils .transforms import load_transforms
26
+
28
27
R = TypeVar ('R' )
29
28
30
29
@@ -34,95 +33,6 @@ async def worker(job: Callable[[], R], semaphore: asyncio.Semaphore) -> R:
34
33
return await loop .run_in_executor (None , job )
35
34
36
35
37
- def load_transforms (xfm_paths : list [Path ], inverse : list [bool ]) -> nt .base .TransformBase :
38
- """Load a series of transforms as a nitransforms TransformChain
39
-
40
- An empty list will return an identity transform
41
- """
42
- if len (inverse ) == 1 :
43
- inverse *= len (xfm_paths )
44
- elif len (inverse ) != len (xfm_paths ):
45
- raise ValueError ('Mismatched number of transforms and inverses' )
46
-
47
- chain = None
48
- for path , inv in zip (xfm_paths [::- 1 ], inverse [::- 1 ], strict = False ):
49
- path = Path (path )
50
- if path .suffix == '.h5' :
51
- xfm = load_ants_h5 (path )
52
- else :
53
- xfm = nt .linear .load (path )
54
- if inv :
55
- xfm = ~ xfm
56
- if chain is None :
57
- chain = xfm
58
- else :
59
- chain += xfm
60
- if chain is None :
61
- chain = nt .base .TransformBase ()
62
- return chain
63
-
64
-
65
- FIXED_PARAMS = np .array ([
66
- 193.0 , 229.0 , 193.0 , # Size
67
- 96.0 , 132.0 , - 78.0 , # Origin
68
- 1.0 , 1.0 , 1.0 , # Spacing
69
- - 1.0 , 0.0 , 0.0 , # Directions
70
- 0.0 , - 1.0 , 0.0 ,
71
- 0.0 , 0.0 , 1.0 ,
72
- ]) # fmt:skip
73
-
74
-
75
- def load_ants_h5 (filename : Path ) -> nt .base .TransformBase :
76
- """Load ANTs H5 files as a nitransforms TransformChain"""
77
- # Borrowed from https://github.com/feilong/process
78
- # process.resample.parse_combined_hdf5()
79
- #
80
- # Changes:
81
- # * Tolerate a missing displacement field
82
- # * Return the original affine without a round-trip
83
- # * Always return a nitransforms TransformChain
84
- #
85
- # This should be upstreamed into nitransforms
86
- h = h5py .File (filename )
87
- xform = ITKCompositeH5 .from_h5obj (h )
88
-
89
- # nt.Affine
90
- transforms = [nt .Affine (xform [0 ].to_ras ())]
91
-
92
- if '2' not in h ['TransformGroup' ]:
93
- return transforms [0 ]
94
-
95
- transform2 = h ['TransformGroup' ]['2' ]
96
-
97
- # Confirm these transformations are applicable
98
- if transform2 ['TransformType' ][:][0 ] not in (
99
- b'DisplacementFieldTransform_float_3_3' ,
100
- b'DisplacementFieldTransform_double_3_3' ,
101
- ):
102
- msg = 'Unknown transform type [2]\n '
103
- for i in h ['TransformGroup' ].keys ():
104
- msg += f'[{ i } ]: { h ["TransformGroup" ][i ]["TransformType" ][:][0 ]} \n '
105
- raise ValueError (msg )
106
-
107
- fixed_params = transform2 ['TransformFixedParameters' ][:]
108
- shape = tuple (fixed_params [:3 ].astype (int ))
109
- # ITK stores warps in Fortran-order, where the vector components change fastest
110
- # Nitransforms expects 3 volumes, not a volume of three-vectors, so transpose
111
- warp = np .reshape (
112
- transform2 ['TransformParameters' ],
113
- (3 , * shape ),
114
- order = 'F' ,
115
- ).transpose (1 , 2 , 3 , 0 )
116
-
117
- warp_affine = np .eye (4 )
118
- warp_affine [:3 , :3 ] = fixed_params [9 :].reshape ((3 , 3 ))
119
- warp_affine [:3 , 3 ] = fixed_params [3 :6 ]
120
- lps_to_ras = np .eye (4 ) * np .array ([- 1 , - 1 , 1 , 1 ])
121
- warp_affine = lps_to_ras @ warp_affine
122
- transforms .insert (0 , nt .DenseFieldTransform (nb .Nifti1Image (warp , warp_affine )))
123
- return nt .TransformChain (transforms )
124
-
125
-
126
36
class ResampleSeriesInputSpec (TraitedSpec ):
127
37
in_file = File (exists = True , mandatory = True , desc = '3D or 4D image file to resample' )
128
38
ref_file = File (exists = True , mandatory = True , desc = 'File to resample in_file to' )
@@ -788,7 +698,7 @@ def reconstruct_fieldmap(
788
698
)
789
699
790
700
if not direct :
791
- fmap_img = transforms .apply (fmap_img , reference = target )
701
+ fmap_img = nt .apply (transforms , fmap_img , reference = target )
792
702
793
703
fmap_img .header .set_intent ('estimate' , name = 'fieldmap Hz' )
794
704
fmap_img .header .set_data_dtype ('float32' )
0 commit comments