@@ -65,50 +65,47 @@ def __init__(self, field=None, is_deltas=True, reference=None):
65
65
<DenseFieldTransform[3D] (57, 67, 56)>
66
66
67
67
"""
68
+
68
69
if field is None and reference is None :
69
- raise TransformError ("DenseFieldTransforms require a spatial reference " )
70
+ raise TransformError ("cannot initialize field " )
70
71
71
72
super ().__init__ ()
72
73
73
- self ._is_deltas = is_deltas
74
+ if field is not None :
75
+ field = _ensure_image (field )
76
+ # Extract data if nibabel object otherwise assume numpy array
77
+ _data = np .squeeze (
78
+ np .asanyarray (field .dataobj )
79
+ if hasattr (field , "dataobj" )
80
+ else field .copy ()
81
+ )
74
82
75
83
try :
76
84
self .reference = ImageGrid (reference if reference is not None else field )
77
85
except AttributeError :
78
86
raise TransformError (
79
- "Field must be a spatial image if reference is not provided"
87
+ "field must be a spatial image if reference is not provided"
80
88
if reference is None
81
- else "Reference is not a spatial image"
89
+ else "reference is not a spatial image"
82
90
)
83
91
84
92
fieldshape = (* self .reference .shape , self .reference .ndim )
85
- if field is not None :
86
- field = _ensure_image (field )
87
- self ._field = np .squeeze (
88
- np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
89
- )
90
- if fieldshape != self ._field .shape :
91
- raise TransformError (
92
- f"Shape of the field ({ 'x' .join (str (i ) for i in self ._field .shape )} ) "
93
- f"doesn't match that of the reference({ 'x' .join (str (i ) for i in fieldshape )} )"
94
- )
95
- else :
96
- self ._field = np .zeros (fieldshape , dtype = "float32" )
97
- self ._is_deltas = True
98
-
99
- if self ._field .shape [- 1 ] != self .ndim :
93
+ if field is None :
94
+ _data = np .zeros (fieldshape )
95
+ elif fieldshape != _data .shape :
100
96
raise TransformError (
101
- "The number of components of the field (%d) does not match "
102
- "the number of dimensions (%d)" % ( self . _field . shape [ - 1 ], self . ndim )
97
+ f"Shape of the field ({ 'x' . join ( str ( i ) for i in _data . shape ) } ) "
98
+ f"doesn't match that of the reference( { 'x' . join ( str ( i ) for i in fieldshape ) } )"
103
99
)
104
100
101
+ self ._is_deltas = is_deltas
102
+ self ._field = self .reference .ndcoords .reshape (fieldshape )
103
+
105
104
if self .is_deltas :
106
- self ._deltas = (
107
- self ._field .copy ()
108
- ) # IMPORTANT: you don't want to update deltas
109
- # Convert from displacements (deltas) to deformations fields
110
- # (just add its origin to each delta vector)
111
- self ._field += self .reference .ndcoords .T .reshape (fieldshape )
105
+ self ._deltas = _data .copy ()
106
+ self ._field += self ._deltas
107
+ else :
108
+ self ._field = _data .copy ()
112
109
113
110
def __repr__ (self ):
114
111
"""Beautify the python representation."""
@@ -153,7 +150,7 @@ def map(self, x, inverse=False):
153
150
... test_dir / "someones_displacement_field.nii.gz",
154
151
... is_deltas=False,
155
152
... )
156
- >>> xfm.map([-6.5, -36., -19.5]).tolist()
153
+ >>> xfm.map([[ -6.5, -36., -19.5] ]).tolist()
157
154
[[0.0, -0.47516798973083496, 0.0]]
158
155
159
156
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
@@ -170,8 +167,8 @@ def map(self, x, inverse=False):
170
167
... test_dir / "someones_displacement_field.nii.gz",
171
168
... is_deltas=True,
172
169
... )
173
- >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
174
- [[-6.5, -36.47516632080078 , -19.5], [-1.0, -42.03835678100586 , -11.25]]
170
+ >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist() # doctest: +ELLIPSIS
171
+ [[-6.5, -36.475... , -19.5], [-1.0, -42.038... , -11.25]]
175
172
176
173
>>> np.array_str(
177
174
... xfm.map([[-6.7, -36.3, -19.2], [-1., -41.5, -11.25]]),
@@ -185,18 +182,19 @@ def map(self, x, inverse=False):
185
182
if inverse is True :
186
183
raise NotImplementedError
187
184
188
- ijk = self .reference .index (x )
185
+ ijk = self .reference .index (np . array ( x , dtype = "float32" ) )
189
186
indexes = np .round (ijk ).astype ("int" )
187
+ ongrid = np .where (np .linalg .norm (ijk - indexes , axis = 1 ) < 1e-3 )[0 ]
190
188
191
- if np . all ( np .abs ( ijk - indexes ) < 1e-3 ) :
192
- indexes = tuple ( tuple ( i ) for i in indexes )
193
- return self ._field [indexes ]
189
+ if ongrid . size == np .shape ( x )[ 0 ] :
190
+ # return self._field[* indexes.T, :] # From Python 3.11
191
+ return self ._field [tuple ( indexes . T ) + ( np . s_ [:],) ]
194
192
195
- new_map = np .vstack (
193
+ mapped_coords = np .vstack (
196
194
tuple (
197
195
map_coordinates (
198
196
self ._field [..., i ],
199
- ijk ,
197
+ ijk . T ,
200
198
order = 3 ,
201
199
mode = "constant" ,
202
200
cval = np .nan ,
@@ -207,8 +205,8 @@ def map(self, x, inverse=False):
207
205
).T
208
206
209
207
# Set NaN values back to the original coordinates value = no displacement
210
- new_map [np .isnan (new_map )] = np .array (x )[np .isnan (new_map )]
211
- return new_map
208
+ mapped_coords [np .isnan (mapped_coords )] = np .array (x )[np .isnan (mapped_coords )]
209
+ return mapped_coords
212
210
213
211
def __matmul__ (self , b ):
214
212
"""
0 commit comments