Skip to content

Commit dfddc76

Browse files
committed
DiscreteCurve.forward: drop batch dim if it is 1
1 parent 5ba3791 commit dfddc76

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

stochman/curves.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
199199
torch.floor(tt * num_edges).clamp(min=0, max=num_edges - 1).long() # Bx|t|
200200
).unsqueeze(2).repeat(1, 1, D) # Bx|t|xD, this assumes that nodes are equi-distant
201201
result = torch.gather(a, 1, idx) * tt.unsqueeze(2) + torch.gather(b, 1, idx) # Bx|t|xD
202+
if B == 1:
203+
result = result.squeeze(0) # |t|xD
202204
return result
203205

204206
def __getitem__(self, indices: int) -> "DiscreteCurve":

0 commit comments

Comments
 (0)