Skip to content

Commit 0e2312a

Browse files
CeliaBenquetstesMMathisLab
authored
Extend type checking to all float datatypes (#166)
* Make the type checking less sensitive * Use existing type checking functions * Add better typing * Update cebra/data/datasets.py * Update cebra/data/datasets.py --------- Co-authored-by: Steffen Schneider <[email protected]> Co-authored-by: Mackenzie Mathis <[email protected]>
1 parent 568c5f8 commit 0e2312a

File tree

2 files changed

+38
-8
lines changed

2 files changed

+38
-8
lines changed

cebra/data/datasets.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import abc
2525
import collections
2626
import types
27-
from typing import List, Tuple, Union
27+
from typing import List, Literal, Optional, Tuple, Union
2828

2929
import literate_dataclasses as dataclasses
3030
import numpy as np
@@ -74,23 +74,40 @@ def __init__(self,
7474
offset: int = 1,
7575
device: str = "cpu"):
7676
super().__init__(device=device)
77-
self.neural = self._to_tensor(neural, torch.FloatTensor).float()
78-
self.continuous = self._to_tensor(continuous, torch.FloatTensor)
79-
self.discrete = self._to_tensor(discrete, torch.LongTensor)
77+
self.neural = self._to_tensor(neural, check_dtype="float").float()
78+
self.continuous = self._to_tensor(continuous,
79+
check_dtype="float")
80+
self.discrete = self._to_tensor(discrete, check_dtype="integer")
8081
if self.continuous is None and self.discrete is None:
8182
raise ValueError(
8283
"You have to pass at least one of the arguments 'continuous' or 'discrete'."
8384
)
8485
self.offset = offset
8586

86-
def _to_tensor(self, array, check_dtype=None):
87+
def _to_tensor(
88+
self,
89+
array: Union[torch.Tensor, npt.NDArray],
90+
check_dtype: Optional[Literal["int",
91+
"float"]] = None) -> torch.Tensor:
92+
"""Convert :py:func:`numpy.array` to :py:class:`torch.Tensor` if necessary and check the dtype.
93+
94+
Args:
95+
array: Array to check.
96+
check_dtype (list, optional): If not `None`, list of dtypes to which the values in `array`
97+
must belong to. Defaults to None.
98+
99+
Returns:
100+
The `array` as a :py:class:`torch.Tensor`.
101+
"""
87102
if array is None:
88103
return None
89104
if isinstance(array, np.ndarray):
90105
array = torch.from_numpy(array)
91106
if check_dtype is not None:
92-
if not isinstance(array, check_dtype):
93-
raise TypeError(f"{type(array)} instead of {check_dtype}.")
107+
if (check_dtype == "int" and not cebra.helper._is_integer(array)
108+
) or (check_dtype == "float" and
109+
not cebra.helper._is_floating(array)):
110+
raise TypeError(f"Array has type {array.dtype} instead of {check_dtype}.")
94111
return array
95112

96113
@property

cebra/helper.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _is_integer(y: Union[npt.NDArray, torch.Tensor]) -> bool:
9999

100100

101101
def _is_floating(y: Union[npt.NDArray, torch.Tensor]) -> bool:
102-
"""Check if the values in ``y`` are :py:class:`int`.
102+
"""Check if the values in ``y`` are :py:class:`float`.
103103
104104
Note:
105105
There is no ``torch`` method to check that the ``dtype`` of a :py:class:`torch.Tensor`
@@ -118,6 +118,19 @@ def _is_floating(y: Union[npt.NDArray, torch.Tensor]) -> bool:
118118
y, torch.Tensor) and torch.is_floating_point(y))
119119

120120

121+
def _is_floating_or_integer(y: Union[npt.NDArray, torch.Tensor]) -> bool:
122+
"""Check if the values in ``y`` are :py:class:`int` or :py:class:`float`.
123+
124+
Args:
125+
y: An array, either as a :py:func:`numpy.array` or a :py:class:`torch.Tensor`.
126+
127+
Returns:
128+
``True`` if ``y`` contains :py:class:`float` or :py:class:`int`.
129+
"""
130+
131+
return _is_floating(y) or _is_integer(y)
132+
133+
121134
def get_loader_options(dataset: "cebra.data.Dataset") -> List[str]:
122135
"""Return all possible dataloaders for the given dataset.
123136

0 commit comments

Comments
 (0)