Skip to content

Commit ae9cd3e

Browse files
committed
cleanup + comments
1 parent 9f5ea5d commit ae9cd3e

File tree

3 files changed

+37
-119
lines changed

3 files changed

+37
-119
lines changed

RVGP/dataclass.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
compute_laplacian,
1111
compute_connection_laplacian,
1212
compute_spectrum,
13-
project_to_local_frame,
13+
express_in_local_frame,
1414
project_to_manifold,
1515
manifold_dimension
1616
)
@@ -30,7 +30,8 @@ def __init__(self,
3030
n_eigenpairs=None):
3131

3232
print('Fit graph')
33-
G = manifold_graph(vertices,n_neighbors=n_neighbors)
33+
G = manifold_graph(vertices, n_neighbors=n_neighbors)
34+
3435
print('Fit tangent spaces')
3536
gauges, Sigma = tangent_frames(vertices, G, vertices.shape[1], n_neighbors*frac_geodesic_neighbours)
3637

@@ -93,10 +94,13 @@ def __init__(self,
9394
def random_vector_field(self, seed=0):
9495
"""Generate random vector field over manifold"""
9596

96-
np.random.seed(0)
97+
np.random.seed(seed)
9798

98-
vectors = np.random.uniform(size=(len(self.vertices), 3))-.5
99-
vectors = project_to_manifold(vectors, self.gauges[...,:2])
99+
vectors = np.random.uniform(size=(len(self.vertices),
100+
self.vertices.shape[1])
101+
)
102+
vectors -= .5
103+
vectors = project_to_manifold(vectors, self.gauges[...,:self.dim_man])
100104
vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)
101105

102106
self.vectors = vectors
@@ -106,9 +110,9 @@ def smooth_vector_field(self, t=100):
106110
if hasattr(self, 'vectors'):
107111

108112
"""Smooth vector field over manifold"""
109-
vectors = project_to_local_frame(self.vectors, self.gauges)
113+
vectors = express_in_local_frame(self.vectors, self.gauges)
110114
vectors = vector_diffusion(vectors, t, L=self.L, Lc=self.Lc, method="matrix_exp")
111-
vectors = project_to_local_frame(vectors, self.gauges, reverse=True)
115+
vectors = express_in_local_frame(vectors, self.gauges, reverse=True)
112116

113117
self.vectors = vectors
114118
else:

RVGP/geometry.py

Lines changed: 26 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,20 @@
99

1010
from sklearn.metrics import pairwise_distances
1111
from sklearn.neighbors import kneighbors_graph
12-
from sklearn.neighbors import KDTree
13-
14-
15-
16-
def compute_laplacian(G, normalization=False):
17-
18-
if normalization:
19-
laplacian = sparse.csr_matrix(nx.normalized_laplacian_matrix(G), dtype=np.float64)
20-
else:
21-
laplacian = sparse.csr_matrix(nx.laplacian_matrix(G), dtype=np.float64)
22-
23-
return laplacian
2412

2513

2614
def compute_connection_laplacian(G, R, normalization=None):
27-
r"""Connection Laplacian
15+
"""Connection Laplacian
2816
2917
Args:
3018
data: Pytorch geometric data object.
3119
R (nxnxdxd): Connection matrices between all pairs of nodes. Default is None,
3220
in case of a global coordinate system.
33-
normalization: None, 'sym', 'rw'
21+
normalization: None, 'rw'
3422
1. None: No normalization
3523
:math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`
3624
37-
2. "sym"`: Symmetric normalization
38-
:math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
39-
\mathbf{D}^{-1/2}`
40-
41-
3. "rw"`: Random-walk normalization
25+
2. "rw"`: Random-walk normalization
4226
:math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`
4327
4428
Returns:
@@ -65,12 +49,20 @@ def compute_connection_laplacian(G, R, normalization=None):
6549
deg_inv = deg_inv.repeat(dim, axis=0)
6650
Lc = sparse.diags(deg_inv, 0, format='csr') @ Lc
6751

68-
elif normalization == "sym":
69-
raise NotImplementedError
70-
7152
return Lc
7253

7354

55+
def compute_laplacian(G, normalization=False):
56+
"""Laplacian. Used as a helper function to compute_connection_laplacian()"""
57+
58+
if normalization:
59+
laplacian = sparse.csr_matrix(nx.normalized_laplacian_matrix(G), dtype=np.float64)
60+
else:
61+
laplacian = sparse.csr_matrix(nx.laplacian_matrix(G), dtype=np.float64)
62+
63+
return laplacian
64+
65+
7466
def compute_spectrum(laplacian, n_eigenpairs=None, dtype=tf.float64):
7567

7668
if n_eigenpairs is None:
@@ -88,32 +80,6 @@ def compute_spectrum(laplacian, n_eigenpairs=None, dtype=tf.float64):
8880
return evals, evecs
8981

9082

91-
def sample_from_convex_hull(points, num_samples, k=5):
92-
93-
tree = scipy.spatial.KDTree(points)
94-
95-
if num_samples > len(points):
96-
num_samples = len(points)
97-
98-
sample_points = np.random.choice(len(points), size=num_samples, replace=False)
99-
sample_points = points[sample_points]
100-
101-
# Generate samples
102-
samples = []
103-
for current_point in sample_points:
104-
_, nn_ind = tree.query(current_point, k=k, p=2)
105-
nn_hull = points[nn_ind]
106-
107-
barycentric_coords = np.random.uniform(size=nn_hull.shape[0])
108-
barycentric_coords /= np.sum(barycentric_coords)
109-
110-
current_point = np.sum(nn_hull.T * barycentric_coords, axis=1)
111-
112-
samples.append(current_point)
113-
114-
return np.array(samples)
115-
116-
11783
def manifold_dimension(Sigma, frac_explained=0.9):
11884
"""Estimate manifold dimension based on singular vectors"""
11985

@@ -134,7 +100,14 @@ def manifold_dimension(Sigma, frac_explained=0.9):
134100
def manifold_graph(X, typ = 'knn', n_neighbors=5):
135101
"""Fit graph over a pointset X"""
136102
if typ == 'knn':
137-
A = kneighbors_graph(X, n_neighbors, mode='connectivity', metric='minkowski', p=2, metric_params=None, include_self=False, n_jobs=None)
103+
A = kneighbors_graph(X,
104+
n_neighbors,
105+
mode='connectivity',
106+
metric='minkowski',
107+
p=2,
108+
metric_params=None,
109+
include_self=False,
110+
n_jobs=None)
138111
A += sparse.eye(A.shape[0])
139112
G = nx.from_scipy_sparse_array(A)
140113

@@ -150,53 +123,6 @@ def manifold_graph(X, typ = 'knn', n_neighbors=5):
150123
return G
151124

152125

153-
def find_nn(x_query, X, nn=3, r=None):
154-
"""
155-
Find nearest neighbors of a point on the manifold
156-
157-
Parameters
158-
----------
159-
ind_query : 2d np array, list[2d np array]
160-
Index of points whose neighbors are needed.
161-
x : nxd array (dimensions are columns!)
162-
Coordinates of n points on a manifold in d-dimensional space.
163-
nn : int, optional
164-
Number of nearest neighbors. The default is 1.
165-
166-
Returns
167-
-------
168-
dist : list[list]
169-
Distance of nearest neighbors.
170-
ind : list[list]
171-
Index of nearest neighbors.
172-
173-
"""
174-
175-
#Fit neighbor estimator object
176-
kdt = KDTree(X, leaf_size=30, metric='euclidean')
177-
178-
if r is not None:
179-
ind, dist = kdt.query_radius(x_query, r=r, return_distance=True, sort_results=True)
180-
ind = ind[0]
181-
dist = dist[0]
182-
else:
183-
# apparently, the outputs are reversed here compared to query_radius()
184-
dist, ind = kdt.query(x_query, k=nn)
185-
186-
return dist, ind.flatten()
187-
188-
189-
def closest_manifold_point(x_query, d, nn=3):
190-
dist, ind = find_nn(x_query, d.vertices, nn=nn)
191-
w = 1/(dist.T+0.00001)
192-
w /= w.sum()
193-
positional_encoding = d.evecs_Lc.reshape(d.n, -1)
194-
pe_manifold = (positional_encoding[ind]*w).sum(0, keepdims=True)
195-
x_manifold = d.vertices[ind]
196-
197-
return x_manifold, pe_manifold
198-
199-
200126
def furthest_point_sampling(x, N=None, stop_crit=0.1, start_idx=0):
201127
"""A greedy O(N^2) algorithm to do furthest points sampling
202128
@@ -217,8 +143,6 @@ def furthest_point_sampling(x, N=None, stop_crit=0.1, start_idx=0):
217143
n = D.shape[0] if N is None else N
218144
diam = D.max()
219145

220-
start_idx = 5
221-
222146
perm = np.zeros(n, dtype=np.int32)
223147
perm[0] = start_idx
224148
lambdas = np.zeros(n)
@@ -239,23 +163,14 @@ def furthest_point_sampling(x, N=None, stop_crit=0.1, start_idx=0):
239163

240164

241165
def project_to_manifold(x, gauges):
166+
"""Project vectors to local coordinates over manifold"""
242167
coeffs = np.einsum("bij,bi->bj", gauges, x)
243168
return np.einsum("bj,bij->bi", coeffs, gauges)
244169

245170

246-
def project_to_local_frame(x, gauges, reverse=False):
171+
def express_in_local_frame(x, gauges, reverse=False):
172+
"""Express vectors in local coordinates over manifold"""
247173
if reverse:
248174
return np.einsum("bji,bi->bj", gauges, x)
249175
else:
250-
return np.einsum("bij,bi->bj", gauges, x)
251-
252-
253-
def local_to_global(x, gauges):
254-
return np.einsum("bj,bij->bi", x, gauges)
255-
256-
257-
def node_eigencoords(node_ind, evecs_Lc, dim):
258-
r, c = evecs_Lc.shape
259-
evecs_Lc = evecs_Lc.reshape(-1, c*dim)
260-
node_coords = evecs_Lc[node_ind]
261-
return node_coords.reshape(-1, c)
176+
return np.einsum("bij,bi->bj", gauges, x)

examples/eeg_example/run_eeg_vector_field_interp.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
# Load EEG vector field data
1717
# =============================================================================
1818

19-
2019
def find_mat_files(directory):
2120
mat_files = {}
2221
for root, dirs, files in os.walk(directory):

0 commit comments

Comments
 (0)