Skip to content

Commit a537934

Browse files
author
Beat Buesser
committed
Account for object arrays
Signed-off-by: Beat Buesser <[email protected]>
1 parent e542bf7 commit a537934

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

art/estimators/object_tracking/pytorch_goturn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -665,11 +665,12 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[s
665665

666666
for i in range(x.shape[0]):
667667
if isinstance(x, np.ndarray):
668-
x_i = torch.from_numpy(x[[i]]).to(self.device)
668+
x_i = torch.from_numpy(x[i]).to(self.device)
669669
else:
670-
x_i = x[[i]].to(self.device)
670+
x_i = x[i].to(self.device)
671671

672672
# Apply preprocessing
673+
x_i = torch.unsqueeze(x_i, dim=0)
673674
x_i, _ = self._apply_preprocessing(x_i, y=None, fit=False, no_grad=False)
674675
x_i = torch.squeeze(x_i)
675676

0 commit comments

Comments
 (0)