|
21 | 21 | # |
22 | 22 | """Pre-defined datasets.""" |
23 | 23 |
|
24 | | -import abc |
25 | | -import collections |
26 | 24 | import types |
27 | | -from typing import List, Tuple, Union |
| 25 | +from typing import List, Literal, Optional, Tuple, Union |
28 | 26 |
|
29 | | -import literate_dataclasses as dataclasses |
30 | 27 | import numpy as np |
31 | 28 | import numpy.typing as npt |
32 | 29 | import torch |
33 | | -from numpy.typing import NDArray |
34 | 30 |
|
35 | 31 | import cebra.data as cebra_data |
36 | | -import cebra.distributions |
37 | | -from cebra.data.datatypes import Batch |
38 | | -from cebra.data.datatypes import BatchIndex |
| 32 | +import cebra.helper as cebra_helper |
| 33 | +from cebra.data.datatypes import Offset |
39 | 34 |
|
40 | 35 |
|
41 | 36 | class TensorDataset(cebra_data.SingleSessionDataset): |
@@ -71,26 +66,52 @@ def __init__(self, |
71 | 66 | neural: Union[torch.Tensor, npt.NDArray], |
72 | 67 | continuous: Union[torch.Tensor, npt.NDArray] = None, |
73 | 68 | discrete: Union[torch.Tensor, npt.NDArray] = None, |
74 | | - offset: int = 1, |
| 69 | + offset: Offset = Offset(0, 1), |
75 | 70 | device: str = "cpu"): |
76 | 71 | super().__init__(device=device) |
77 | | - self.neural = self._to_tensor(neural, torch.FloatTensor).float() |
78 | | - self.continuous = self._to_tensor(continuous, torch.FloatTensor) |
79 | | - self.discrete = self._to_tensor(discrete, torch.LongTensor) |
| 72 | + self.neural = self._to_tensor(neural, check_dtype="float").float() |
| 73 | + self.continuous = self._to_tensor(continuous, check_dtype="float") |
| 74 | + self.discrete = self._to_tensor(discrete, check_dtype="int") |
80 | 75 | if self.continuous is None and self.discrete is None: |
81 | 76 | raise ValueError( |
82 | 77 | "You have to pass at least one of the arguments 'continuous' or 'discrete'." |
83 | 78 | ) |
84 | 79 | self.offset = offset |
85 | 80 |
|
86 | | - def _to_tensor(self, array, check_dtype=None): |
| 81 | + def _to_tensor( |
| 82 | + self, |
| 83 | + array: Union[torch.Tensor, npt.NDArray], |
| 84 | + check_dtype: Optional[Literal["int", |
| 85 | + "float"]] = None) -> torch.Tensor: |
| 86 | + """Convert :py:func:`numpy.array` to :py:class:`torch.Tensor` if necessary and check the dtype. |
| 87 | +
|
| 88 | + Args: |
| 89 | + array: Array to check. |
| 90 | + check_dtype: If not `None`, list of dtypes to which the values in `array` |
| 91 | + must belong to. Defaults to None. |
| 92 | +
|
| 93 | + Returns: |
| 94 | + The `array` as a :py:class:`torch.Tensor`. |
| 95 | + """ |
87 | 96 | if array is None: |
88 | 97 | return None |
89 | 98 | if isinstance(array, np.ndarray): |
90 | 99 | array = torch.from_numpy(array) |
91 | 100 | if check_dtype is not None: |
92 | | - if not isinstance(array, check_dtype): |
93 | | - raise TypeError(f"{type(array)} instead of {check_dtype}.") |
| 101 | + if check_dtype not in ["int", "float"]: |
| 102 | + raise ValueError( |
| 103 | + f"check_dtype must be 'int' or 'float', got {check_dtype}") |
| 104 | + if (check_dtype == "int" and not cebra_helper._is_integer(array) |
| 105 | + ) or (check_dtype == "float" and |
| 106 | + not cebra_helper._is_floating(array)): |
| 107 | + raise TypeError( |
| 108 | + f"Array has type {array.dtype} instead of {check_dtype}.") |
| 109 | + if cebra_helper._is_floating(array): |
| 110 | + array = array.float() |
| 111 | + if cebra_helper._is_integer(array): |
| 112 | + # NOTE(stes): Required for standardizing number format on |
| 113 | + # windows machines. |
| 114 | + array = array.long() |
94 | 115 | return array |
95 | 116 |
|
96 | 117 | @property |
|
0 commit comments