99
1010from sklearn .metrics import pairwise_distances
1111from sklearn .neighbors import kneighbors_graph
12- from sklearn .neighbors import KDTree
13-
14-
15-
16- def compute_laplacian (G , normalization = False ):
17-
18- if normalization :
19- laplacian = sparse .csr_matrix (nx .normalized_laplacian_matrix (G ), dtype = np .float64 )
20- else :
21- laplacian = sparse .csr_matrix (nx .laplacian_matrix (G ), dtype = np .float64 )
22-
23- return laplacian
2412
2513
2614def compute_connection_laplacian (G , R , normalization = None ):
27- r """Connection Laplacian
15+ """Connection Laplacian
2816
2917 Args:
3018 data: Pytorch geometric data object.
3119 R (nxnxdxd): Connection matrices between all pairs of nodes. Default is None,
3220 in case of a global coordinate system.
33- normalization: None, 'sym', ' rw'
21+ normalization: None, 'rw'
3422 1. None: No normalization
3523 :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`
3624
37- 2. "sym"`: Symmetric normalization
38- :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
39- \mathbf{D}^{-1/2}`
40-
41- 3. "rw"`: Random-walk normalization
25+ 2. "rw"`: Random-walk normalization
4226 :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`
4327
4428 Returns:
@@ -65,12 +49,20 @@ def compute_connection_laplacian(G, R, normalization=None):
6549 deg_inv = deg_inv .repeat (dim , axis = 0 )
6650 Lc = sparse .diags (deg_inv , 0 , format = 'csr' ) @ Lc
6751
68- elif normalization == "sym" :
69- raise NotImplementedError
70-
7152 return Lc
7253
7354
55+ def compute_laplacian (G , normalization = False ):
56+ """Laplacian. Used as a helper function to compute_connection_laplacian()"""
57+
58+ if normalization :
59+ laplacian = sparse .csr_matrix (nx .normalized_laplacian_matrix (G ), dtype = np .float64 )
60+ else :
61+ laplacian = sparse .csr_matrix (nx .laplacian_matrix (G ), dtype = np .float64 )
62+
63+ return laplacian
64+
65+
7466def compute_spectrum (laplacian , n_eigenpairs = None , dtype = tf .float64 ):
7567
7668 if n_eigenpairs is None :
@@ -88,32 +80,6 @@ def compute_spectrum(laplacian, n_eigenpairs=None, dtype=tf.float64):
8880 return evals , evecs
8981
9082
91- def sample_from_convex_hull (points , num_samples , k = 5 ):
92-
93- tree = scipy .spatial .KDTree (points )
94-
95- if num_samples > len (points ):
96- num_samples = len (points )
97-
98- sample_points = np .random .choice (len (points ), size = num_samples , replace = False )
99- sample_points = points [sample_points ]
100-
101- # Generate samples
102- samples = []
103- for current_point in sample_points :
104- _ , nn_ind = tree .query (current_point , k = k , p = 2 )
105- nn_hull = points [nn_ind ]
106-
107- barycentric_coords = np .random .uniform (size = nn_hull .shape [0 ])
108- barycentric_coords /= np .sum (barycentric_coords )
109-
110- current_point = np .sum (nn_hull .T * barycentric_coords , axis = 1 )
111-
112- samples .append (current_point )
113-
114- return np .array (samples )
115-
116-
11783def manifold_dimension (Sigma , frac_explained = 0.9 ):
11884 """Estimate manifold dimension based on singular vectors"""
11985
@@ -134,7 +100,14 @@ def manifold_dimension(Sigma, frac_explained=0.9):
134100def manifold_graph (X , typ = 'knn' , n_neighbors = 5 ):
135101 """Fit graph over a pointset X"""
136102 if typ == 'knn' :
137- A = kneighbors_graph (X , n_neighbors , mode = 'connectivity' , metric = 'minkowski' , p = 2 , metric_params = None , include_self = False , n_jobs = None )
103+ A = kneighbors_graph (X ,
104+ n_neighbors ,
105+ mode = 'connectivity' ,
106+ metric = 'minkowski' ,
107+ p = 2 ,
108+ metric_params = None ,
109+ include_self = False ,
110+ n_jobs = None )
138111 A += sparse .eye (A .shape [0 ])
139112 G = nx .from_scipy_sparse_array (A )
140113
@@ -150,53 +123,6 @@ def manifold_graph(X, typ = 'knn', n_neighbors=5):
150123 return G
151124
152125
153- def find_nn (x_query , X , nn = 3 , r = None ):
154- """
155- Find nearest neighbors of a point on the manifold
156-
157- Parameters
158- ----------
159- ind_query : 2d np array, list[2d np array]
160- Index of points whose neighbors are needed.
161- x : nxd array (dimensions are columns!)
162- Coordinates of n points on a manifold in d-dimensional space.
163- nn : int, optional
164- Number of nearest neighbors. The default is 1.
165-
166- Returns
167- -------
168- dist : list[list]
169- Distance of nearest neighbors.
170- ind : list[list]
171- Index of nearest neighbors.
172-
173- """
174-
175- #Fit neighbor estimator object
176- kdt = KDTree (X , leaf_size = 30 , metric = 'euclidean' )
177-
178- if r is not None :
179- ind , dist = kdt .query_radius (x_query , r = r , return_distance = True , sort_results = True )
180- ind = ind [0 ]
181- dist = dist [0 ]
182- else :
183- # apparently, the outputs are reversed here compared to query_radius()
184- dist , ind = kdt .query (x_query , k = nn )
185-
186- return dist , ind .flatten ()
187-
188-
189- def closest_manifold_point (x_query , d , nn = 3 ):
190- dist , ind = find_nn (x_query , d .vertices , nn = nn )
191- w = 1 / (dist .T + 0.00001 )
192- w /= w .sum ()
193- positional_encoding = d .evecs_Lc .reshape (d .n , - 1 )
194- pe_manifold = (positional_encoding [ind ]* w ).sum (0 , keepdims = True )
195- x_manifold = d .vertices [ind ]
196-
197- return x_manifold , pe_manifold
198-
199-
200126def furthest_point_sampling (x , N = None , stop_crit = 0.1 , start_idx = 0 ):
201127 """A greedy O(N^2) algorithm to do furthest points sampling
202128
@@ -217,8 +143,6 @@ def furthest_point_sampling(x, N=None, stop_crit=0.1, start_idx=0):
217143 n = D .shape [0 ] if N is None else N
218144 diam = D .max ()
219145
220- start_idx = 5
221-
222146 perm = np .zeros (n , dtype = np .int32 )
223147 perm [0 ] = start_idx
224148 lambdas = np .zeros (n )
@@ -239,23 +163,14 @@ def furthest_point_sampling(x, N=None, stop_crit=0.1, start_idx=0):
239163
240164
241165def project_to_manifold (x , gauges ):
166+ """Project vectors to local coordinates over manifold"""
242167 coeffs = np .einsum ("bij,bi->bj" , gauges , x )
243168 return np .einsum ("bj,bij->bi" , coeffs , gauges )
244169
245170
246- def project_to_local_frame (x , gauges , reverse = False ):
171+ def express_in_local_frame (x , gauges , reverse = False ):
172+ """Express vectors in local coordinates over manifold"""
247173 if reverse :
248174 return np .einsum ("bji,bi->bj" , gauges , x )
249175 else :
250- return np .einsum ("bij,bi->bj" , gauges , x )
251-
252-
253- def local_to_global (x , gauges ):
254- return np .einsum ("bj,bij->bi" , x , gauges )
255-
256-
257- def node_eigencoords (node_ind , evecs_Lc , dim ):
258- r , c = evecs_Lc .shape
259- evecs_Lc = evecs_Lc .reshape (- 1 , c * dim )
260- node_coords = evecs_Lc [node_ind ]
261- return node_coords .reshape (- 1 , c )
176+ return np .einsum ("bij,bi->bj" , gauges , x )
0 commit comments