Skip to content

Commit 57222fd

Browse files
committed
fix_rebase
1 parent 79b7b50 commit 57222fd

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

nitransforms/surface.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from nitransforms.base import (
1919
SurfaceMesh
2020
)
21+
import nibabel as nb
22+
from scipy.spatial import KDTree
23+
from scipy.spatial.distance import cdist
2124

2225

2326
class SurfaceTransformBase():
@@ -113,6 +116,7 @@ def __add__(self, other):
113116
return self.__class__(self.reference, other.moving)
114117
raise NotImplementedError
115118

119+
116120
def _to_hdf5(self, x5_root):
117121
"""Write transform to HDF5 file."""
118122
triangles = x5_root.create_group("Triangles")
@@ -211,6 +215,7 @@ def __init__(self, reference, moving, interpolation_method='barycentric', mat=No
211215
interpolation_method : str
212216
Only barycentric is currently implemented
213217
"""
218+
214219
super().__init__(SurfaceMesh(reference), SurfaceMesh(moving), spherical=True)
215220

216221
self.reference.set_radius()
@@ -226,6 +231,39 @@ def __init__(self, reference, moving, interpolation_method='barycentric', mat=No
226231
# transform
227232
if mat is None:
228233
self.__calculate_mat()
234+
r_tree = KDTree(self.reference._coords)
235+
m_tree = KDTree(self.moving._coords)
236+
kmr_dists, kmr_closest = m_tree.query(self.reference._coords, k=10)
237+
238+
# invert the triangles to generate a lookup table from vertices to triangle index
239+
tri_lut = dict()
240+
for i, idxs in enumerate(self.moving._triangles):
241+
for x in idxs:
242+
if not x in tri_lut:
243+
tri_lut[x] = [i]
244+
else:
245+
tri_lut[x].append(i)
246+
247+
# calculate the barycentric interpolation weights
248+
bc_weights = []
249+
enclosing = []
250+
for sidx, (point, kmrv) in enumerate(zip(self.reference._coords, kmr_closest)):
251+
close_tris = _find_close_tris(kmrv, tri_lut, self.moving)
252+
ww, ee = _find_weights(point, close_tris, m_tree)
253+
bc_weights.append(ww)
254+
enclosing.append(ee)
255+
256+
# build sparse matrix
257+
# commenting out code for barycentric nearest neighbor
258+
#bary_nearest = []
259+
mat = sparse.lil_array((self.reference._npoints, self.moving._npoints))
260+
for s_ix, dd in enumerate(bc_weights):
261+
for k, v in dd.items():
262+
mat[s_ix, k] = v
263+
# bary_nearest.append(np.array(list(dd.keys()))[np.array(list(dd.values())).argmax()])
264+
# bary_nearest = np.array(bary_nearest)
265+
# transpose so that number of out vertices is columns
266+
self.mat = sparse.csr_array(mat.T)
229267
else:
230268
if isinstance(mat, sparse.csr_array):
231269
self.mat = mat
@@ -283,7 +321,6 @@ def map(self, x):
283321
return x
284322

285323
def __add__(self, other):
286-
287324
if (isinstance(other, SurfaceResampler)
288325
and (other.interpolation_method == self.interpolation_method)):
289326
return self.__class__(
@@ -455,6 +492,7 @@ def from_filename(cls, filename=None, reference_path=None, moving_path=None,
455492

456493

457494
def _points_to_triangles(points, triangles):
495+
458496
"""Implementation that vectorizes project of a point to a set of triangles.
459497
from: https://stackoverflow.com/a/32529589
460498
"""
@@ -495,6 +533,7 @@ def _points_to_triangles(points, triangles):
495533
m2 = v < 0
496534
m3 = d < 0
497535
m4 = a + d > b + e
536+
498537
m5 = ce > bd
499538

500539
t0 = m0 & m1 & m2 & m3
@@ -588,6 +627,7 @@ def _find_close_tris(kdsv, tri_lut, surface):
588627
def _find_weights(point, close_tris, d_tree):
589628
point = point[np.newaxis, :]
590629
tri_dists = cdist(point, _points_to_triangles(point, close_tris).squeeze())
630+
591631
closest_tri = close_tris[(tri_dists == tri_dists.min()).squeeze()]
592632
# make sure a single closest triangle was found
593633
if closest_tri.shape[0] != 1:
@@ -599,6 +639,7 @@ def _find_weights(point, close_tris, d_tree):
599639
# Make sure point is actually inside triangle
600640
enclosing = True
601641
if np.all((point > closest_tri).sum(0) != 3):
642+
602643
enclosing = False
603644
_, ct_idxs = d_tree.query(closest_tri)
604645
a = closest_tri[0]

nitransforms/tests/test_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
import h5py
66

7+
78
from ..base import (
89
SpatialReference,
910
SampledSpatialData,

nitransforms/tests/test_surface.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
SurfaceResampler
1313
)
1414

15+
from nitransforms.base import SurfaceMesh
16+
from nitransforms.surface import SurfaceCoordinateTransform, SurfaceResampler
17+
18+
1519
# def test_surface_transform_npz():
1620
# mat = sparse.random(10, 10, density=0.5)
1721
# xfm = SurfaceCoordinateTransform(mat)
@@ -42,6 +46,7 @@
4246
# y_none = xfm.apply(x, normalize="none")
4347
# assert y_none.sum() != y_element.sum()
4448
# assert y_none.sum() != y_sum.sum()
49+
4550
def test_SurfaceTransformBase(testdata_path):
4651
# note these transformations are a bit of a weird use of surface transformation, but I'm
4752
# just testing the base class and the io
@@ -205,3 +210,4 @@ def test_SurfaceResampler(testdata_path, tmpdir):
205210
assert resampling3 == resampling
206211
resampled_thickness3 = resampling3.apply(subj_thickness.agg_data(), normalize='element')
207212
assert np.all(resampled_thickness3 == resampled_thickness)
213+

0 commit comments

Comments
 (0)