@@ -225,6 +225,8 @@ def _grid_point(self, p):
225
225
idx: an integer correponding to the node index of
226
226
the nearest point on the grid.
227
227
"""
228
+ if p .ndim == 1 :
229
+ p = p .unsqueeze (0 )
228
230
return self ._grid_dist2 (p ).argmin ().item ()
229
231
230
232
def shortest_path (self , p1 , p2 ):
@@ -256,6 +258,14 @@ def shortest_path(self, p1, p2):
256
258
dist += self .G .edges [path [i ], path [i + 1 ]]["weight" ]
257
259
return curve , dist
258
260
261
+ def _path_to_curve (self , path , mesh , curve , device = None ):
262
+ weights = [self .G .edges [path [k ], path [k + 1 ]]["weight" ] for k in range (len (path ) - 1 )]
263
+ raw_coordinates = [m .flatten ()[path [1 :- 1 ]].view (- 1 , 1 ) for m in mesh ]
264
+ coordinates = torch .cat (raw_coordinates , dim = 1 ) # Nx(dim)
265
+ t = torch .tensor (weights [:- 1 ], device = device ).cumsum (dim = 0 ) / sum (weights )
266
+
267
+ curve .fit (t , coordinates )
268
+
259
269
def connecting_geodesic (self , p1 , p2 , curve = None ):
260
270
"""Compute the shortest path on the discretized manifold and fit
261
271
a smooth curve to the resulting discrete curve.
@@ -277,33 +287,47 @@ def connecting_geodesic(self, p1, p2, curve=None):
277
287
curve input.
278
288
"""
279
289
device = p1 .device
280
- if p1 .ndim == 1 :
281
- p1 = p1 .unsqueeze (0 ) # 1xD
282
- if p2 .ndim == 1 :
283
- p2 = p2 .unsqueeze (0 ) # 1xD
284
- B = p1 .shape [0 ]
285
- if p1 .shape != p2 .shape :
286
- raise NameError ("shape mismatch" )
290
+ batch1 = 1 if len (p1 .shape ) == 1 else p1 .shape [0 ]
291
+ batch2 = 1 if len (p2 .shape ) == 1 else p2 .shape [0 ]
292
+ B = max (batch1 , batch2 )
293
+ if batch1 == 1 :
294
+ p1 = p1 .view ((1 , - 1 )) # 1xD
295
+ if batch2 == 1 :
296
+ p2 = p2 .view ((1 , - 1 )) # 1xD
297
+
298
+ single_source = min (batch1 , batch2 ) == 1 and B > 1
287
299
288
300
if curve is None :
289
301
curve = CubicSpline (p1 , p2 )
290
302
else :
291
303
curve .begin = p1
292
304
curve .end = p2
293
305
294
- for b in range (B ):
306
+ mesh = torch .meshgrid (* self .grid , indexing = "ij" )
307
+
308
+ if single_source :
295
309
with torch .no_grad ():
296
- idx1 = self ._grid_point (p1 [b ].unsqueeze (0 ))
297
- idx2 = self ._grid_point (p2 [b ].unsqueeze (0 ))
310
+ if batch1 == 1 :
311
+ source_idx = self ._grid_point (p1 )
312
+ else :
313
+ source_idx = self ._grid_point (p2 )
314
+ paths = nx .single_source_dijkstra_path (
315
+ self .G , source = source_idx , weight = "weight"
316
+ ) # list of lists
317
+ for b in range (B ):
318
+ if batch1 == 1 :
319
+ idx = self ._grid_point (p2 [b ])
320
+ else :
321
+ idx = self ._grid_point (p1 [b ])
322
+ self ._path_to_curve (paths [idx ], mesh , curve [b ], device )
323
+ else :
324
+ for b in range (B ):
325
+ with torch .no_grad ():
326
+ idx1 = self ._grid_point (p1 [b ])
327
+ idx2 = self ._grid_point (p2 [b ])
298
328
path = nx .shortest_path (
299
329
self .G , source = idx1 , target = idx2 , weight = "weight"
300
330
) # list with N elements
301
- weights = [self .G .edges [path [k ], path [k + 1 ]]["weight" ] for k in range (len (path ) - 1 )]
302
- mesh = torch .meshgrid (* self .grid , indexing = "ij" )
303
- raw_coordinates = [m .flatten ()[path [1 :- 1 ]].view (- 1 , 1 ) for m in mesh ]
304
- coordinates = torch .cat (raw_coordinates , dim = 1 ) # Nx(dim)
305
- t = torch .tensor (weights [:- 1 ], device = device ).cumsum (dim = 0 ) / sum (weights )
306
-
307
- curve [b ].fit (t , coordinates )
331
+ self ._path_to_curve (path , mesh , curve [b ], device )
308
332
309
333
return curve , True
0 commit comments