Skip to content

Commit 22cf106

Browse files
committed
DiscreteCurve.constant_speed: new function
1 parent dfddc76 commit 22cf106

File tree

1 file changed

+40
-41
lines changed

1 file changed

+40
-41
lines changed

stochman/curves.py

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -219,47 +219,46 @@ def __getitem__(self, indices: int) -> "DiscreteCurve":
219219
def __setitem__(self, indices, curves) -> None:
220220
self.params[indices] = curves.params.squeeze()
221221

222-
# def constant_speed(
223-
# self, metric=None, t: Optional[torch.Tensor] = None
224-
# ) -> Tuple[torch.Tensor, torch.Tensor]:
225-
# """
226-
# Reparametrize the curve to have constant speed.
227-
228-
# Optional input:
229-
# metric: the Manifold under which the curve should have constant speed.
230-
# If None then the Euclidean metric is applied.
231-
# Default: None.
232-
233-
# Note: It is not possible to back-propagate through this function.
234-
# """
235-
# from stochman import CubicSpline
236-
237-
# with torch.no_grad():
238-
# if t is None:
239-
# t = torch.linspace(0, 1, 100) # N
240-
# Ct = self(t) # NxD or BxNxD
241-
# if Ct.dim() == 2:
242-
# Ct.unsqueeze_(0) # BxNxD
243-
# B, N, D = Ct.shape
244-
# delta = Ct[:, 1:] - Ct[:, :-1] # Bx(N-1)xD
245-
# if metric is None:
246-
# local_len = delta.norm(dim=2) # Bx(N-1)
247-
# else:
248-
# local_len = (
249-
# metric.inner(Ct[:, :-1].reshape(-1, D), delta.view(-1, D), delta.view(-1, D))
250-
# .view(B, N - 1)
251-
# .sqrt()
252-
# ) # Bx(N-1)
253-
# cs = local_len.cumsum(dim=1) # Bx(N-1)
254-
# zero = torch.zeros(B, 1) # Bx1 -- XXX: missing dtype and device
255-
# one = torch.ones(B, 1) # Bx1 -- XXX: ditto
256-
# new_t = torch.cat((zero, cs / cs[:, -1].unsqueeze(1)), dim=1) # BxN
257-
# S = CubicSpline(zero, one)
258-
# _ = S.fit(new_t, t.unsqueeze(0).expand(B, -1).unsqueeze(2))
259-
# new_params = self(S(self.t[:, 0, 0]).squeeze(-1)) # B
260-
261-
# from IPython import embed; embed()
262-
# return new_t, Ct
222+
def constant_speed(
223+
self, metric=None, t: Optional[torch.Tensor] = None
224+
) -> Tuple[torch.Tensor, torch.Tensor]:
225+
"""
226+
Reparametrize the curve to have constant speed.
227+
228+
Optional input:
229+
metric: the Manifold under which the curve should have constant speed.
230+
If None then the Euclidean metric is applied.
231+
Default: None.
232+
233+
Note: It is not possible to back-propagate through this function.
234+
"""
235+
from stochman import CubicSpline
236+
237+
with torch.no_grad():
238+
if t is None:
239+
t = torch.linspace(0, 1, 100) # N
240+
Ct = self(t) # NxD or BxNxD
241+
if Ct.ndim == 2:
242+
Ct.unsqueeze_(0) # BxNxD
243+
B, N, D = Ct.shape
244+
delta = Ct[:, 1:] - Ct[:, :-1] # Bx(N-1)xD
245+
if metric is None:
246+
local_len = delta.norm(dim=2) # Bx(N-1)
247+
else:
248+
local_len = (
249+
metric.inner(Ct[:, :-1].reshape(-1, D), delta.view(-1, D), delta.view(-1, D))
250+
.view(B, N - 1)
251+
.sqrt()
252+
) # Bx(N-1)
253+
cs = local_len.cumsum(dim=1) # Bx(N-1)
254+
zero = torch.zeros(B, 1, dtype=cs.dtype, device=cs.device) # Bx1
255+
one = torch.ones(B, 1, dtype=cs.dtype, device=cs.device) # Bx1
256+
new_t = torch.cat((zero, cs / cs[:, -1].unsqueeze(1)), dim=1) # BxN
257+
S = CubicSpline(zero, one)
258+
_ = S.fit(new_t, t.unsqueeze(0).expand(B, -1).unsqueeze(2))
259+
new_params = self(S(self.t[:, :, 0]).squeeze(-1)) # Bx(num_nodes-2)xD
260+
self.params = nn.Parameter(new_params)
261+
return new_t, Ct
263262

264263
def tospline(self):
265264
from stochman import CubicSpline

0 commit comments

Comments
 (0)