@@ -55,7 +55,7 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
55
55
56
56
dim = len (grid )
57
57
if len (grid ) != 2 :
58
- raise Exception (' Currently we only support 2D grids -- sorry!' )
58
+ raise Exception (" Currently we only support 2D grids -- sorry!" )
59
59
60
60
# Add nodes to graph
61
61
xsize , ysize = len (grid [0 ]), len (grid [1 ])
@@ -64,7 +64,7 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
64
64
65
65
point_set = torch .cartesian_prod (
66
66
torch .linspace (0 , xsize - 1 , xsize , dtype = torch .long ),
67
- torch .linspace (0 , ysize - 1 , ysize , dtype = torch .long )
67
+ torch .linspace (0 , ysize - 1 , ysize , dtype = torch .long ),
68
68
) # (big)x2
69
69
70
70
point_sets = [] # these will be [N, 2] matrices of index points
@@ -93,8 +93,8 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
93
93
t = torch .linspace (0 , 1 , 2 )
94
94
for ps , nf in zip (point_sets , neighbour_funcs ):
95
95
for i in range (ceil (ps .shape [0 ] / batch_size )):
96
- x = ps [batch_size * i : batch_size * (i + 1 ), 0 ]
97
- y = ps [batch_size * i : batch_size * (i + 1 ), 1 ]
96
+ x = ps [batch_size * i : batch_size * (i + 1 ), 0 ]
97
+ y = ps [batch_size * i : batch_size * (i + 1 ), 1 ]
98
98
xn , yn = nf [0 ](x ), nf [1 ](y )
99
99
100
100
bs = x .shape [0 ] # may be different from batch size for the last batch
@@ -132,7 +132,7 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
132
132
self .__metric__ = M .view ([* self .grid_size , d , d ]) # e.g. (xsize)x(ysize)x(d)x(d)
133
133
134
134
# Compute interpolation weights. We use the mean function of a GP regressor.
135
- mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
135
+ mesh = torch .meshgrid (* self .grid , indexing = "ij" )
136
136
grid_points = torch .cat (
137
137
[m .unsqueeze (- 1 ) for m in mesh ], dim = - 1
138
138
) # e.g. 100x100x2 a 2D grid with 100 points in each dim
@@ -145,6 +145,7 @@ def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise
145
145
) # (num_grid)x(d²) or (num_grid)x(d)
146
146
except :
147
147
import warnings
148
+
148
149
warnings .warn ("It appears that your model does not implement a metric." )
149
150
# XXX: Down the road, we should be able to estimate the metric from the observed distances
150
151
@@ -191,9 +192,9 @@ def _grid_dist2(self, p):
191
192
"""
192
193
193
194
dist2 = torch .zeros (p .shape [0 ], self .G .number_of_nodes ())
194
- mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
195
+ mesh = torch .meshgrid (* self .grid , indexing = "ij" )
195
196
for mesh_dim , dim in zip (mesh , range (len (self .grid ))):
196
- dist2 += (p [:, dim ].view (- 1 , 1 ) - mesh_dim .reshape (1 , - 1 ))** 2
197
+ dist2 += (p [:, dim ].view (- 1 , 1 ) - mesh_dim .reshape (1 , - 1 )) ** 2
197
198
return dist2
198
199
199
200
def _kernel (self , p ):
@@ -205,12 +206,12 @@ def _kernel(self, p):
205
206
Output:
206
207
val: a torch Tensor with the kernel values.
207
208
"""
208
- lengthscales = [(g [1 ] - g [0 ])** 2 for g in self .grid ]
209
+ lengthscales = [(g [1 ] - g [0 ]) ** 2 for g in self .grid ]
209
210
210
211
dist2 = torch .zeros (p .shape [0 ], self .G .number_of_nodes ())
211
- mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
212
+ mesh = torch .meshgrid (* self .grid , indexing = "ij" )
212
213
for mesh_dim , dim in zip (mesh , range (len (self .grid ))):
213
- dist2 += (p [:, dim ].view (- 1 , 1 ) - mesh_dim .reshape (1 , - 1 ))** 2 / lengthscales [dim ]
214
+ dist2 += (p [:, dim ].view (- 1 , 1 ) - mesh_dim .reshape (1 , - 1 )) ** 2 / lengthscales [dim ]
214
215
215
216
return torch .exp (- dist2 )
216
217
@@ -241,9 +242,9 @@ def shortest_path(self, p1, p2):
241
242
"""
242
243
idx1 = self ._grid_point (p1 )
243
244
idx2 = self ._grid_point (p2 )
244
- path = nx .shortest_path (self .G , source = idx1 , target = idx2 , weight = ' weight' ) # list with N elements
245
+ path = nx .shortest_path (self .G , source = idx1 , target = idx2 , weight = " weight" ) # list with N elements
245
246
# coordinates = self.grid.view(self.grid.shape[0], -1)[:, path] # (dim)xN
246
- mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
247
+ mesh = torch .meshgrid (* self .grid , indexing = "ij" )
247
248
raw_coordinates = [m .flatten ()[path ].view (1 , - 1 ) for m in mesh ]
248
249
coordinates = torch .cat (raw_coordinates , dim = 0 ) # (dim)xN
249
250
N = len (path )
@@ -252,7 +253,7 @@ def shortest_path(self, p1, p2):
252
253
curve .parameters [:, :] = coordinates [:, 1 :- 1 ].t ()
253
254
dist = 0
254
255
for i in range (N - 1 ):
255
- dist += self .G .edges [path [i ], path [i + 1 ]][' weight' ]
256
+ dist += self .G .edges [path [i ], path [i + 1 ]][" weight" ]
256
257
return curve , dist
257
258
258
259
def connecting_geodesic (self , p1 , p2 , curve = None ):
@@ -282,7 +283,7 @@ def connecting_geodesic(self, p1, p2, curve=None):
282
283
p2 = p2 .unsqueeze (0 ) # 1xD
283
284
B = p1 .shape [0 ]
284
285
if p1 .shape != p2 .shape :
285
- raise NameError (' shape mismatch' )
286
+ raise NameError (" shape mismatch" )
286
287
287
288
if curve is None :
288
289
curve = CubicSpline (p1 , p2 )
@@ -295,10 +296,10 @@ def connecting_geodesic(self, p1, p2, curve=None):
295
296
idx1 = self ._grid_point (p1 [b ].unsqueeze (0 ))
296
297
idx2 = self ._grid_point (p2 [b ].unsqueeze (0 ))
297
298
path = nx .shortest_path (
298
- self .G , source = idx1 , target = idx2 , weight = ' weight'
299
+ self .G , source = idx1 , target = idx2 , weight = " weight"
299
300
) # list with N elements
300
- weights = [self .G .edges [path [k ], path [k + 1 ]][' weight' ] for k in range (len (path ) - 1 )]
301
- mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
301
+ weights = [self .G .edges [path [k ], path [k + 1 ]][" weight" ] for k in range (len (path ) - 1 )]
302
+ mesh = torch .meshgrid (* self .grid , indexing = "ij" )
302
303
raw_coordinates = [m .flatten ()[path [1 :- 1 ]].view (- 1 , 1 ) for m in mesh ]
303
304
coordinates = torch .cat (raw_coordinates , dim = 1 ) # Nx(dim)
304
305
t = torch .tensor (weights [:- 1 ], device = device ).cumsum (dim = 0 ) / sum (weights )
0 commit comments