Skip to content

Commit 61bb75e

Browse files
committed
Support having only end end-point being batched
1 parent 910b9fa commit 61bb75e

File tree

1 file changed

+24
-8
lines changed

1 file changed

+24
-8
lines changed

stochman/curves.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,32 @@ def __init__(
2121
self._num_nodes = num_nodes
2222
self._requires_grad = requires_grad
2323

24-
# register begin and end as buffers
25-
if len(begin.shape) == 1 or begin.shape[0] == 1:
26-
self.register_buffer("begin", begin.detach().view((1, -1))) # 1xD
24+
# if either begin or end only has one point, while the other has a batch
25+
# then we expand the singular point. End result is that both begin and
26+
# end should have shape BxD
27+
batch_begin = 1 if len(begin.shape) == 1 else begin.shape[0]
28+
batch_end = 1 if len(end.shape) == 1 else end.shape[0]
29+
if batch_begin == 1 and batch_end == 1:
30+
_begin = begin.detach().view((1, -1)) # 1xD
31+
_end = end.detach().view((1, -1)) # 1xD
32+
elif batch_begin == 1: # batch_end > 1
33+
_begin = begin.detach().view((1, -1)).repeat(batch_end, 1) # BxD
34+
_end = end.detach() # BxD
35+
elif batch_end == 1: # batch_begin > 1
36+
_begin = begin.detach() # BxD
37+
_end = end.detach().view((1, -1)).repeat(batch_begin, 1) # BxD
38+
elif batch_begin == batch_end:
39+
_begin = begin.detach() # BxD
40+
_end = end.detach() # BxD
2741
else:
28-
self.register_buffer("begin", begin.detach()) # BxD
42+
raise ValueError(
43+
"BasicCurve.__init__ requires begin and end points to have "
44+
"the same shape"
45+
)
2946

30-
if len(end.shape) == 1 or end.shape[0] == 1:
31-
self.register_buffer("end", end.detach().view((1, -1))) # 1xD
32-
else:
33-
self.register_buffer("end", end.detach()) # BxD
47+
# register begin and end as buffers
48+
self.register_buffer("begin", _begin) # BxD
49+
self.register_buffer("end", _end) # BxD
3450

3551
# overriden by child modules
3652
self._init_params(*args, **kwargs)

0 commit comments

Comments
 (0)