Skip to content

Commit 9d3c1bc

Browse files
committed
Merge branch 'master' of https://github.com/CenterBioML/stochman
2 parents 02e3023 + 6a91d6a commit 9d3c1bc

File tree

1 file changed

+47
-46
lines changed

1 file changed

+47
-46
lines changed

stochman/curves.py

Lines changed: 47 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def _init_params(self, params, *args, **kwargs) -> None:
168168
)
169169
if params is None:
170170
params = self.t * self.end.unsqueeze(1) + \
171-
(1 - self.t) * self.begin.unsqueeze(1) # Bx(_num_nodes)xD
171+
(1 - self.t) * self.begin.unsqueeze(1) # Bx(_num_nodes)xD
172172
if self._requires_grad:
173173
self.register_parameter("params", nn.Parameter(params))
174174
else:
@@ -185,7 +185,7 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
185185
torch.ones(B, 1, D, dtype=self.t.dtype, device=self.device),
186186
),
187187
dim=1
188-
) # Bx(num_nodes)xD
188+
) # Bx(num_nodes)xD
189189
a = (end_nodes - start_nodes) / (t0[:, 1:] - t0[:, :-1]) # Bx(num_edges)xD
190190
b = start_nodes - a * t0[:, :-1] # Bx(num_edges)xD
191191

@@ -199,6 +199,8 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
199199
torch.floor(tt * num_edges).clamp(min=0, max=num_edges - 1).long() # Bx|t|
200200
).unsqueeze(2).repeat(1, 1, D) # Bx|t|xD, this assumes that nodes are equi-distant
201201
result = torch.gather(a, 1, idx) * tt.unsqueeze(2) + torch.gather(b, 1, idx) # Bx|t|xD
202+
if B == 1:
203+
result = result.squeeze(0) # |t|xD
202204
return result
203205

204206
def __getitem__(self, indices: int) -> "DiscreteCurve":
@@ -217,47 +219,46 @@ def __getitem__(self, indices: int) -> "DiscreteCurve":
217219
def __setitem__(self, indices, curves) -> None:
218220
self.params[indices] = curves.params.squeeze()
219221

220-
# def constant_speed(
221-
# self, metric=None, t: Optional[torch.Tensor] = None
222-
# ) -> Tuple[torch.Tensor, torch.Tensor]:
223-
# """
224-
# Reparametrize the curve to have constant speed.
225-
226-
# Optional input:
227-
# metric: the Manifold under which the curve should have constant speed.
228-
# If None then the Euclidean metric is applied.
229-
# Default: None.
230-
231-
# Note: It is not possible to back-propagate through this function.
232-
# """
233-
# from stochman import CubicSpline
234-
235-
# with torch.no_grad():
236-
# if t is None:
237-
# t = torch.linspace(0, 1, 100) # N
238-
# Ct = self(t) # NxD or BxNxD
239-
# if Ct.dim() == 2:
240-
# Ct.unsqueeze_(0) # BxNxD
241-
# B, N, D = Ct.shape
242-
# delta = Ct[:, 1:] - Ct[:, :-1] # Bx(N-1)xD
243-
# if metric is None:
244-
# local_len = delta.norm(dim=2) # Bx(N-1)
245-
# else:
246-
# local_len = (
247-
# metric.inner(Ct[:, :-1].reshape(-1, D), delta.view(-1, D), delta.view(-1, D))
248-
# .view(B, N - 1)
249-
# .sqrt()
250-
# ) # Bx(N-1)
251-
# cs = local_len.cumsum(dim=1) # Bx(N-1)
252-
# zero = torch.zeros(B, 1) # Bx1 -- XXX: missing dtype and device
253-
# one = torch.ones(B, 1) # Bx1 -- XXX: ditto
254-
# new_t = torch.cat((zero, cs / cs[:, -1].unsqueeze(1)), dim=1) # BxN
255-
# S = CubicSpline(zero, one)
256-
# _ = S.fit(new_t, t.unsqueeze(0).expand(B, -1).unsqueeze(2))
257-
# new_params = self(S(self.t[:, 0, 0]).squeeze(-1)) # B
258-
259-
# from IPython import embed; embed()
260-
# return new_t, Ct
222+
def constant_speed(
223+
self, metric=None, t: Optional[torch.Tensor] = None
224+
) -> Tuple[torch.Tensor, torch.Tensor]:
225+
"""
226+
Reparametrize the curve to have constant speed.
227+
228+
Optional input:
229+
metric: the Manifold under which the curve should have constant speed.
230+
If None then the Euclidean metric is applied.
231+
Default: None.
232+
233+
Note: It is not possible to back-propagate through this function.
234+
"""
235+
from stochman import CubicSpline
236+
237+
with torch.no_grad():
238+
if t is None:
239+
t = torch.linspace(0, 1, 100) # N
240+
Ct = self(t) # NxD or BxNxD
241+
if Ct.ndim == 2:
242+
Ct.unsqueeze_(0) # BxNxD
243+
B, N, D = Ct.shape
244+
delta = Ct[:, 1:] - Ct[:, :-1] # Bx(N-1)xD
245+
if metric is None:
246+
local_len = delta.norm(dim=2) # Bx(N-1)
247+
else:
248+
local_len = (
249+
metric.inner(Ct[:, :-1].reshape(-1, D), delta.view(-1, D), delta.view(-1, D))
250+
.view(B, N - 1)
251+
.sqrt()
252+
) # Bx(N-1)
253+
cs = local_len.cumsum(dim=1) # Bx(N-1)
254+
zero = torch.zeros(B, 1, dtype=cs.dtype, device=cs.device) # Bx1
255+
one = torch.ones(B, 1, dtype=cs.dtype, device=cs.device) # Bx1
256+
new_t = torch.cat((zero, cs / cs[:, -1].unsqueeze(1)), dim=1) # BxN
257+
S = CubicSpline(zero, one)
258+
_ = S.fit(new_t, t.unsqueeze(0).expand(B, -1).unsqueeze(2))
259+
new_params = self(S(self.t[:, :, 0]).squeeze(-1)) # Bx(num_nodes-2)xD
260+
self.params = nn.Parameter(new_params)
261+
return new_t, Ct
261262

262263
def tospline(self):
263264
from stochman import CubicSpline
@@ -466,9 +467,9 @@ def todiscrete(self, num_nodes=None):
466467

467468
if num_nodes is None:
468469
num_nodes = self._num_nodes
469-
t = torch.linspace(0, 1, num_nodes)[1:-1] # (num_nodes-2)
470-
Ct = self(t) # Bx(num_nodes-2)xD
471-
470+
t = torch.linspace(0, 1, num_nodes)[1:-1] # (num_nodes-2)
471+
Ct = self(t) # Bx(num_nodes-2)xD
472+
472473
return DiscreteCurve(
473474
begin=self.begin,
474475
end=self.end,

0 commit comments

Comments
 (0)