Skip to content

Commit 3b61114

Browse files
authored
Add device as arg to TensorDataset (#170)
1 parent 4745b20 commit 3b61114

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

cebra/data/datasets.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,9 @@ def __init__(
7373
continuous: Union[torch.Tensor, npt.NDArray] = None,
7474
discrete: Union[torch.Tensor, npt.NDArray] = None,
7575
offset: int = 1,
76+
device: str = "cpu"
7677
):
77-
super().__init__()
78+
super().__init__(device=device)
7879
self.neural = self._to_tensor(neural, torch.FloatTensor).float()
7980
self.continuous = self._to_tensor(continuous, torch.FloatTensor)
8081
self.discrete = self._to_tensor(discrete, torch.LongTensor)

0 commit comments

Comments
 (0)