18
18
from nitransforms .base import (
19
19
SurfaceMesh
20
20
)
21
+ import nibabel as nb
22
+ from scipy .spatial import KDTree
23
+ from scipy .spatial .distance import cdist
21
24
22
25
23
26
class SurfaceTransformBase ():
@@ -113,6 +116,7 @@ def __add__(self, other):
113
116
return self .__class__ (self .reference , other .moving )
114
117
raise NotImplementedError
115
118
119
+
116
120
def _to_hdf5 (self , x5_root ):
117
121
"""Write transform to HDF5 file."""
118
122
triangles = x5_root .create_group ("Triangles" )
@@ -211,6 +215,7 @@ def __init__(self, reference, moving, interpolation_method='barycentric', mat=No
211
215
interpolation_method : str
212
216
Only barycentric is currently implemented
213
217
"""
218
+
214
219
super ().__init__ (SurfaceMesh (reference ), SurfaceMesh (moving ), spherical = True )
215
220
216
221
self .reference .set_radius ()
@@ -226,6 +231,39 @@ def __init__(self, reference, moving, interpolation_method='barycentric', mat=No
226
231
# transform
227
232
if mat is None :
228
233
self .__calculate_mat ()
234
+ r_tree = KDTree (self .reference ._coords )
235
+ m_tree = KDTree (self .moving ._coords )
236
+ kmr_dists , kmr_closest = m_tree .query (self .reference ._coords , k = 10 )
237
+
238
+ # invert the triangles to generate a lookup table from vertices to triangle index
239
+ tri_lut = dict ()
240
+ for i , idxs in enumerate (self .moving ._triangles ):
241
+ for x in idxs :
242
+ if not x in tri_lut :
243
+ tri_lut [x ] = [i ]
244
+ else :
245
+ tri_lut [x ].append (i )
246
+
247
+ # calculate the barycentric interpolation weights
248
+ bc_weights = []
249
+ enclosing = []
250
+ for sidx , (point , kmrv ) in enumerate (zip (self .reference ._coords , kmr_closest )):
251
+ close_tris = _find_close_tris (kmrv , tri_lut , self .moving )
252
+ ww , ee = _find_weights (point , close_tris , m_tree )
253
+ bc_weights .append (ww )
254
+ enclosing .append (ee )
255
+
256
+ # build sparse matrix
257
+ # commenting out code for barycentric nearest neighbor
258
+ #bary_nearest = []
259
+ mat = sparse .lil_array ((self .reference ._npoints , self .moving ._npoints ))
260
+ for s_ix , dd in enumerate (bc_weights ):
261
+ for k , v in dd .items ():
262
+ mat [s_ix , k ] = v
263
+ # bary_nearest.append(np.array(list(dd.keys()))[np.array(list(dd.values())).argmax()])
264
+ # bary_nearest = np.array(bary_nearest)
265
+ # transpose so that number of out vertices is columns
266
+ self .mat = sparse .csr_array (mat .T )
229
267
else :
230
268
if isinstance (mat , sparse .csr_array ):
231
269
self .mat = mat
@@ -283,7 +321,6 @@ def map(self, x):
283
321
return x
284
322
285
323
def __add__ (self , other ):
286
-
287
324
if (isinstance (other , SurfaceResampler )
288
325
and (other .interpolation_method == self .interpolation_method )):
289
326
return self .__class__ (
@@ -455,6 +492,7 @@ def from_filename(cls, filename=None, reference_path=None, moving_path=None,
455
492
456
493
457
494
def _points_to_triangles (points , triangles ):
495
+
458
496
"""Implementation that vectorizes project of a point to a set of triangles.
459
497
from: https://stackoverflow.com/a/32529589
460
498
"""
@@ -495,6 +533,7 @@ def _points_to_triangles(points, triangles):
495
533
m2 = v < 0
496
534
m3 = d < 0
497
535
m4 = a + d > b + e
536
+
498
537
m5 = ce > bd
499
538
500
539
t0 = m0 & m1 & m2 & m3
@@ -588,6 +627,7 @@ def _find_close_tris(kdsv, tri_lut, surface):
588
627
def _find_weights (point , close_tris , d_tree ):
589
628
point = point [np .newaxis , :]
590
629
tri_dists = cdist (point , _points_to_triangles (point , close_tris ).squeeze ())
630
+
591
631
closest_tri = close_tris [(tri_dists == tri_dists .min ()).squeeze ()]
592
632
# make sure a single closest triangle was found
593
633
if closest_tri .shape [0 ] != 1 :
@@ -599,6 +639,7 @@ def _find_weights(point, close_tris, d_tree):
599
639
# Make sure point is actually inside triangle
600
640
enclosing = True
601
641
if np .all ((point > closest_tri ).sum (0 ) != 3 ):
642
+
602
643
enclosing = False
603
644
_ , ct_idxs = d_tree .query (closest_tri )
604
645
a = closest_tri [0 ]
0 commit comments