Skip to content

Commit 5ba3791

Browse files
committed
Merge branch 'master' of https://github.com/CenterBioML/stochman
2 parents dbd657e + 6736f5d commit 5ba3791

File tree

1 file changed

+86
-36
lines changed

1 file changed

+86
-36
lines changed

stochman/curves.py

Lines changed: 86 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def plot(self, t0: float = 0.0, t1: float = 1.0, N: int = 100, *plot_args, **plo
7878
return figs
7979
if points.shape[-1] == 2:
8080
for b in range(points.shape[0]):
81-
fig = plt.plot(points[b, :, 0], points[b, :, 1], "-", *plot_args, **plot_kwargs)
81+
fig = plt.plot(points[b, :, 0], points[b, :, 1], *plot_args, **plot_kwargs)
8282
figs.append(fig)
8383
return figs
8484

@@ -140,10 +140,11 @@ def closure():
140140
L.backward()
141141
return L
142142

143-
for _ in range(num_steps):
144-
loss = opt.step(closure=closure)
145-
if torch.max(torch.abs(self.params.grad)) < threshold:
146-
break
143+
with torch.enable_grad():
144+
for _ in range(num_steps):
145+
loss = opt.step(closure=closure)
146+
if torch.max(torch.abs(self.params.grad)) < threshold:
147+
break
147148
return loss
148149

149150

@@ -162,43 +163,48 @@ def _init_params(self, params, *args, **kwargs) -> None:
162163
self.register_buffer(
163164
"t",
164165
torch.linspace(0, 1, self._num_nodes, dtype=self.begin.dtype)[1:-1]
165-
.reshape(-1, 1, 1)
166-
.repeat(1, *self.begin.shape), # (_num_nodes-2)xBxD
166+
.view(1, -1, 1)
167+
.expand(self.begin.shape[0], -1, self.begin.shape[1]), # Bx(_num_nodes-2)xD
167168
)
168169
if params is None:
169-
params = self.t * self.end.unsqueeze(0) + (1 - self.t) * self.begin.unsqueeze(
170-
0
171-
) # (_num_nodes)xBxD
170+
params = self.t * self.end.unsqueeze(1) + \
171+
(1 - self.t) * self.begin.unsqueeze(1) # Bx(_num_nodes)xD
172172
if self._requires_grad:
173173
self.register_parameter("params", nn.Parameter(params))
174174
else:
175175
self.register_buffer("params", params)
176176

177177
def forward(self, t: torch.Tensor) -> torch.Tensor:
178-
start_nodes = torch.cat((self.begin.unsqueeze(0), self.params)) # (num_edges)xBxD
179-
end_nodes = torch.cat((self.params, self.end.unsqueeze(0))) # (num_edges)xBxD
180-
num_edges, B, D = start_nodes.shape
178+
start_nodes = torch.cat((self.begin.unsqueeze(1), self.params), dim=1) # Bx(num_edges)xD
179+
end_nodes = torch.cat((self.params, self.end.unsqueeze(1)), dim=1) # Bx(num_edges)xD
180+
B, num_edges, D = start_nodes.shape
181181
t0 = torch.cat(
182182
(
183-
torch.zeros(1, B, D, dtype=self.t.dtype, device=self.device),
183+
torch.zeros(B, 1, D, dtype=self.t.dtype, device=self.device),
184184
self.t,
185-
torch.ones(1, B, D, dtype=self.t.dtype, device=self.device),
186-
)
187-
)
188-
a = (end_nodes - start_nodes) / (t0[1:] - t0[:-1]) # (num_edges)xBxD
189-
b = start_nodes - a * t0[:-1] # (num_edges)xBxD
190-
185+
torch.ones(B, 1, D, dtype=self.t.dtype, device=self.device),
186+
),
187+
dim=1
188+
) # Bx(num_nodes)xD
189+
a = (end_nodes - start_nodes) / (t0[:, 1:] - t0[:, :-1]) # Bx(num_edges)xD
190+
b = start_nodes - a * t0[:, :-1] # Bx(num_edges)xD
191+
192+
if t.ndim == 1:
193+
tt = t.view((1, -1)).expand(B, -1) # Bx|t|
194+
elif t.ndim == 2:
195+
tt = t # Bx|t|
196+
else:
197+
raise Exception('t must have at most 2 dimensions')
191198
idx = (
192-
torch.floor(t.flatten() * num_edges).clamp(min=0, max=num_edges - 1).long()
193-
) # use this if nodes are equi-distant
194-
tt = t.view((-1, 1, 1)).expand(-1, B, D)
195-
result = a[idx] * tt + b[idx] # (num_edges)xBxD
196-
return result.permute(1, 0, 2).squeeze(0) # Bx(num_edges)xD
199+
torch.floor(tt * num_edges).clamp(min=0, max=num_edges - 1).long() # Bx|t|
200+
).unsqueeze(2).repeat(1, 1, D) # Bx|t|xD, this assumes that nodes are equi-distant
201+
result = torch.gather(a, 1, idx) * tt.unsqueeze(2) + torch.gather(b, 1, idx) # Bx|t|xD
202+
return result
197203

198204
def __getitem__(self, indices: int) -> "DiscreteCurve":
199-
params = self.params[:, indices]
205+
params = self.params[indices]
200206
if params.dim() == 2:
201-
params = params.unsqueeze(1)
207+
params = params.unsqueeze(0)
202208
C = DiscreteCurve(
203209
begin=self.begin[indices],
204210
end=self.end[indices],
@@ -209,7 +215,49 @@ def __getitem__(self, indices: int) -> "DiscreteCurve":
209215
return C
210216

211217
def __setitem__(self, indices, curves) -> None:
212-
self.params[:, indices] = curves.params.squeeze()
218+
self.params[indices] = curves.params.squeeze()
219+
220+
# def constant_speed(
221+
# self, metric=None, t: Optional[torch.Tensor] = None
222+
# ) -> Tuple[torch.Tensor, torch.Tensor]:
223+
# """
224+
# Reparametrize the curve to have constant speed.
225+
226+
# Optional input:
227+
# metric: the Manifold under which the curve should have constant speed.
228+
# If None then the Euclidean metric is applied.
229+
# Default: None.
230+
231+
# Note: It is not possible to back-propagate through this function.
232+
# """
233+
# from stochman import CubicSpline
234+
235+
# with torch.no_grad():
236+
# if t is None:
237+
# t = torch.linspace(0, 1, 100) # N
238+
# Ct = self(t) # NxD or BxNxD
239+
# if Ct.dim() == 2:
240+
# Ct.unsqueeze_(0) # BxNxD
241+
# B, N, D = Ct.shape
242+
# delta = Ct[:, 1:] - Ct[:, :-1] # Bx(N-1)xD
243+
# if metric is None:
244+
# local_len = delta.norm(dim=2) # Bx(N-1)
245+
# else:
246+
# local_len = (
247+
# metric.inner(Ct[:, :-1].reshape(-1, D), delta.view(-1, D), delta.view(-1, D))
248+
# .view(B, N - 1)
249+
# .sqrt()
250+
# ) # Bx(N-1)
251+
# cs = local_len.cumsum(dim=1) # Bx(N-1)
252+
# zero = torch.zeros(B, 1) # Bx1 -- XXX: missing dtype and device
253+
# one = torch.ones(B, 1) # Bx1 -- XXX: ditto
254+
# new_t = torch.cat((zero, cs / cs[:, -1].unsqueeze(1)), dim=1) # BxN
255+
# S = CubicSpline(zero, one)
256+
# _ = S.fit(new_t, t.unsqueeze(0).expand(B, -1).unsqueeze(2))
257+
# new_params = self(S(self.t[:, 0, 0]).squeeze(-1)) # B
258+
259+
# from IPython import embed; embed()
260+
# return new_t, Ct
213261

214262
def tospline(self):
215263
from stochman import CubicSpline
@@ -220,7 +268,7 @@ def tospline(self):
220268
num_nodes=self._num_nodes,
221269
requires_grad=self._requires_grad,
222270
)
223-
c.fit(self.t[:, 0, 0], self.params.squeeze(1))
271+
_ = c.fit(self.t[0, :, 0], self.params)
224272
return c
225273

226274

@@ -329,8 +377,6 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
329377
if no_batch:
330378
t = t.expand(coeffs.shape[0], -1) # Bx|t|
331379
retval = self._eval_polynomials(t, coeffs) # Bx|t|xD
332-
# tt = t.view((-1, 1)).unsqueeze(0).expand(retval.shape[0], -1, -1) # Bx|t|x1
333-
# retval += (1-tt).bmm(self.begin.unsqueeze(1)) + tt.bmm(self.end.unsqueeze(1)) # Bx|t|xD
334380
retval += self._eval_straight_line(t)
335381
if no_batch and retval.shape[0] == 1:
336382
retval.squeeze_(0) # |t|xD
@@ -393,7 +439,6 @@ def constant_speed(
393439
Default: None.
394440
395441
Note: It is not possible to back-propagate through this function.
396-
Note: This function does currently not support batching.
397442
"""
398443
with torch.no_grad():
399444
if t is None:
@@ -413,16 +458,21 @@ def constant_speed(
413458
) # Bx(N-1)
414459
cs = local_len.cumsum(dim=1) # Bx(N-1)
415460
new_t = torch.cat((torch.zeros(B, 1), cs / cs[:, -1].unsqueeze(1)), dim=1) # BxN
416-
with torch.enable_grad():
417-
_ = self.fit(new_t, Ct)
461+
_ = self.fit(new_t, Ct)
418462
return new_t, Ct
419463

420-
def todiscrete(self):
464+
def todiscrete(self, num_nodes=None):
421465
from stochman import DiscreteCurve
422466

467+
if num_nodes is None:
468+
num_nodes = self._num_nodes
469+
t = torch.linspace(0, 1, num_nodes)[1:-1] # (num_nodes-2)
470+
Ct = self(t) # Bx(num_nodes-2)xD
471+
423472
return DiscreteCurve(
424473
begin=self.begin,
425474
end=self.end,
426-
num_nodes=self._num_nodes,
475+
num_nodes=num_nodes,
427476
requires_grad=self._requires_grad,
477+
params=Ct,
428478
)

0 commit comments

Comments
 (0)