We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4745b20 commit 3b61114Copy full SHA for 3b61114
cebra/data/datasets.py
@@ -73,8 +73,9 @@ def __init__(
73
continuous: Union[torch.Tensor, npt.NDArray] = None,
74
discrete: Union[torch.Tensor, npt.NDArray] = None,
75
offset: int = 1,
76
+ device: str = "cpu"
77
):
- super().__init__()
78
+ super().__init__(device=device)
79
self.neural = self._to_tensor(neural, torch.FloatTensor).float()
80
self.continuous = self._to_tensor(continuous, torch.FloatTensor)
81
self.discrete = self._to_tensor(discrete, torch.LongTensor)
0 commit comments