@@ -266,7 +266,7 @@ def _path_to_curve(self, path, mesh, curve, device=None):
266
266
267
267
curve .fit (t , coordinates )
268
268
269
- def connecting_geodesic (self , p1 , p2 , curve = None ):
269
+ def connecting_geodesic (self , p1 , p2 , init_curve = None ):
270
270
"""Compute the shortest path on the discretized manifold and fit
271
271
a smooth curve to the resulting discrete curve.
272
272
@@ -276,7 +276,8 @@ def connecting_geodesic(self, p1, p2, curve=None):
276
276
p2: a torch Tensor corresponding to another latent point.
277
277
278
278
Optional input:
279
- curve: a curve that should be fitted to the discrete graph
279
+ init_curve:
280
+ a curve that should be fitted to the discrete graph
280
281
geodesic. By default this is None and a CubicSpline
281
282
with default paramaters will be constructed.
282
283
@@ -297,11 +298,12 @@ def connecting_geodesic(self, p1, p2, curve=None):
297
298
298
299
single_source = min (batch1 , batch2 ) == 1 and B > 1
299
300
300
- if curve is None :
301
+ if init_curve is None :
301
302
curve = CubicSpline (p1 , p2 )
302
303
else :
303
- curve .begin = p1
304
- curve .end = p2
304
+ curve = init_curve
305
+ curve .begin = p1 if batch1 > 1 else p1 .repeat (B , 1 )
306
+ curve .end = p2 if batch2 > 1 else p2 .repeat (B , 1 )
305
307
306
308
mesh = torch .meshgrid (* self .grid , indexing = "ij" )
307
309
0 commit comments