Skip to content

Commit 29c752b

Browse files
committed
Support single source geodesics
1 parent 61bb75e commit 29c752b

File tree

1 file changed

+41
-17
lines changed

1 file changed

+41
-17
lines changed

stochman/discretized_manifold.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ def _grid_point(self, p):
225225
idx: an integer correponding to the node index of
226226
the nearest point on the grid.
227227
"""
228+
if p.ndim == 1:
229+
p = p.unsqueeze(0)
228230
return self._grid_dist2(p).argmin().item()
229231

230232
def shortest_path(self, p1, p2):
@@ -256,6 +258,14 @@ def shortest_path(self, p1, p2):
256258
dist += self.G.edges[path[i], path[i + 1]]["weight"]
257259
return curve, dist
258260

261+
def _path_to_curve(self, path, mesh, curve, device=None):
262+
weights = [self.G.edges[path[k], path[k + 1]]["weight"] for k in range(len(path) - 1)]
263+
raw_coordinates = [m.flatten()[path[1:-1]].view(-1, 1) for m in mesh]
264+
coordinates = torch.cat(raw_coordinates, dim=1) # Nx(dim)
265+
t = torch.tensor(weights[:-1], device=device).cumsum(dim=0) / sum(weights)
266+
267+
curve.fit(t, coordinates)
268+
259269
def connecting_geodesic(self, p1, p2, curve=None):
260270
"""Compute the shortest path on the discretized manifold and fit
261271
a smooth curve to the resulting discrete curve.
@@ -277,33 +287,47 @@ def connecting_geodesic(self, p1, p2, curve=None):
277287
curve input.
278288
"""
279289
device = p1.device
280-
if p1.ndim == 1:
281-
p1 = p1.unsqueeze(0) # 1xD
282-
if p2.ndim == 1:
283-
p2 = p2.unsqueeze(0) # 1xD
284-
B = p1.shape[0]
285-
if p1.shape != p2.shape:
286-
raise NameError("shape mismatch")
290+
batch1 = 1 if len(p1.shape) == 1 else p1.shape[0]
291+
batch2 = 1 if len(p2.shape) == 1 else p2.shape[0]
292+
B = max(batch1, batch2)
293+
if batch1 == 1:
294+
p1 = p1.view((1, -1)) # 1xD
295+
if batch2 == 1:
296+
p2 = p2.view((1, -1)) # 1xD
297+
298+
single_source = min(batch1, batch2) == 1 and B > 1
287299

288300
if curve is None:
289301
curve = CubicSpline(p1, p2)
290302
else:
291303
curve.begin = p1
292304
curve.end = p2
293305

294-
for b in range(B):
306+
mesh = torch.meshgrid(*self.grid, indexing="ij")
307+
308+
if single_source:
295309
with torch.no_grad():
296-
idx1 = self._grid_point(p1[b].unsqueeze(0))
297-
idx2 = self._grid_point(p2[b].unsqueeze(0))
310+
if batch1 == 1:
311+
source_idx = self._grid_point(p1)
312+
else:
313+
source_idx = self._grid_point(p2)
314+
paths = nx.single_source_dijkstra_path(
315+
self.G, source=source_idx, weight="weight"
316+
) # list of lists
317+
for b in range(B):
318+
if batch1 == 1:
319+
idx = self._grid_point(p2[b])
320+
else:
321+
idx = self._grid_point(p1[b])
322+
self._path_to_curve(paths[idx], mesh, curve[b], device)
323+
else:
324+
for b in range(B):
325+
with torch.no_grad():
326+
idx1 = self._grid_point(p1[b])
327+
idx2 = self._grid_point(p2[b])
298328
path = nx.shortest_path(
299329
self.G, source=idx1, target=idx2, weight="weight"
300330
) # list with N elements
301-
weights = [self.G.edges[path[k], path[k + 1]]["weight"] for k in range(len(path) - 1)]
302-
mesh = torch.meshgrid(*self.grid, indexing="ij")
303-
raw_coordinates = [m.flatten()[path[1:-1]].view(-1, 1) for m in mesh]
304-
coordinates = torch.cat(raw_coordinates, dim=1) # Nx(dim)
305-
t = torch.tensor(weights[:-1], device=device).cumsum(dim=0) / sum(weights)
306-
307-
curve[b].fit(t, coordinates)
331+
self._path_to_curve(path, mesh, curve[b], device)
308332

309333
return curve, True

0 commit comments

Comments
 (0)