Skip to content

Commit 7b0c82d

Browse files
committed
Rename argument curve to init_curve to be compatible with the smooth interface
1 parent 29c752b commit 7b0c82d

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

stochman/discretized_manifold.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def _path_to_curve(self, path, mesh, curve, device=None):
266266

267267
curve.fit(t, coordinates)
268268

269-
def connecting_geodesic(self, p1, p2, curve=None):
269+
def connecting_geodesic(self, p1, p2, init_curve=None):
270270
"""Compute the shortest path on the discretized manifold and fit
271271
a smooth curve to the resulting discrete curve.
272272
@@ -276,7 +276,8 @@ def connecting_geodesic(self, p1, p2, curve=None):
276276
p2: a torch Tensor corresponding to another latent point.
277277
278278
Optional input:
279-
curve: a curve that should be fitted to the discrete graph
279+
init_curve:
280+
a curve that should be fitted to the discrete graph
280281
geodesic. By default this is None and a CubicSpline
281282
with default paramaters will be constructed.
282283
@@ -297,11 +298,12 @@ def connecting_geodesic(self, p1, p2, curve=None):
297298

298299
single_source = min(batch1, batch2) == 1 and B > 1
299300

300-
if curve is None:
301+
if init_curve is None:
301302
curve = CubicSpline(p1, p2)
302303
else:
303-
curve.begin = p1
304-
curve.end = p2
304+
curve = init_curve
305+
curve.begin = p1 if batch1 > 1 else p1.repeat(B, 1)
306+
curve.end = p2 if batch2 > 1 else p2.repeat(B, 1)
305307

306308
mesh = torch.meshgrid(*self.grid, indexing="ij")
307309

0 commit comments

Comments
 (0)