Skip to content

Commit 6b2b910

Browse files
author
Beat Buesser
committed
Update device transfers for PyTorchGoturn
Signed-off-by: Beat Buesser <[email protected]>
1 parent 0794bef commit 6b2b910

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

art/estimators/object_tracking/pytorch_goturn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def _get_losses(
233233
labels_t = y_preprocessed # type: ignore
234234

235235
if isinstance(y[0]["boxes"], np.ndarray):
236-
y_init = torch.from_numpy(y[0]["boxes"])
236+
y_init = torch.from_numpy(y[0]["boxes"]).to(self.device)
237237
else:
238238
y_init = y[0]["boxes"]
239239

@@ -632,7 +632,7 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[s
632632
if y_init is None:
633633
raise ValueError("y_init is a required argument for method `predict`.")
634634

635-
y_init = torch.from_numpy(y_init).to(self._device).float()
635+
y_init = torch.from_numpy(y_init).to(self._device).float().to(self.device)
636636

637637
predictions = list()
638638

0 commit comments

Comments
 (0)