Skip to content

Commit c95bd5a

Browse files
authored
Support More Data Types in the sklearn Dataset and Decoder (#43)
1 parent b707a5a commit c95bd5a

File tree

8 files changed

+131
-79
lines changed

8 files changed

+131
-79
lines changed

cebra/data/helper.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,44 @@
1919
import scipy.linalg
2020
import torch
2121

22+
import cebra.data.base as cebra_data_base
23+
import cebra.data.multi_session as cebra_data_multisession
24+
import cebra.data.single_session as cebra_data_singlesession
25+
26+
27+
def get_loader_options(dataset: cebra_data_base.Dataset) -> List[type]:
    """Return all possible dataloaders for the given dataset.

    The decision is based on the type of ``dataset`` and on which of its
    ``continuous_index`` / ``discrete_index`` attributes are set (i.e., are
    not ``None``).

    Args:
        dataset: The dataset for which compatible data loaders are listed.

    Returns:
        The list of data loader classes that can be used with ``dataset``.
        For single-session datasets this may contain the continuous, the
        discrete and (if both indices are set) the mixed data loader. For
        multi-session datasets only the continuous multi-session loader is
        currently offered.

    Raises:
        TypeError: If ``dataset`` is neither a single-session nor a
            multi-session dataset.
    """
    loader_options = []

    if isinstance(dataset, cebra_data_singlesession.SingleSessionDataset):
        has_continuous = dataset.continuous_index is not None
        has_discrete = dataset.discrete_index is not None
        if has_continuous:
            loader_options.append(cebra_data_singlesession.ContinuousDataLoader)
        if has_discrete:
            loader_options.append(cebra_data_singlesession.DiscreteDataLoader)
        # The mixed loader needs both kinds of index at the same time.
        if has_continuous and has_discrete:
            loader_options.append(cebra_data_singlesession.MixedDataLoader)
    elif isinstance(dataset, cebra_data_multisession.MultiSessionDataset):
        if dataset.continuous_index is not None:
            loader_options.append(
                cebra_data_multisession.ContinuousMultiSessionDataLoader)
        # Discrete and mixed multi-session loaders are not implemented yet,
        # so a discrete index does not add any option here.
    else:
        raise TypeError(f"Invalid dataset type: {dataset}")

    return loader_options
59+
2260

2361
def _require_numpy_array(array: Union[npt.NDArray, torch.Tensor]):
2462
if not isinstance(array, np.ndarray):

cebra/distributions/discrete.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020

2121
import cebra.distributions.base as abc_
22+
import cebra.helper
2223

2324

2425
class Discrete(abc_.ConditionalDistribution, abc_.HasGenerator):
@@ -38,7 +39,7 @@ def _to_numpy_int(self, samples: Union[torch.Tensor,
3839
npt.NDArray]) -> npt.NDArray:
3940
if isinstance(samples, torch.Tensor):
4041
samples = samples.cpu().numpy()
41-
if samples.dtype not in (np.int32, np.int64):
42+
if not cebra.helper._is_integer(samples):
4243
samples = samples.astype(int)
4344
return samples
4445

cebra/helper.py

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,47 +15,18 @@
1515
import pathlib
1616
import tempfile
1717
import urllib
18+
import warnings
1819
import zipfile
19-
from typing import List
20+
from typing import List, Union
2021

22+
import numpy as np
23+
import numpy.typing as npt
2124
import requests
25+
import torch
2226

2327
import cebra.data
2428

2529

26-
def get_loader_options(dataset: cebra.data.Dataset) -> List[str]:
27-
"""Return all possible dataloaders for the given dataset."""
28-
29-
loader_options = []
30-
if isinstance(dataset, cebra.data.SingleSessionDataset):
31-
mixed = True
32-
if dataset.continuous_index is not None:
33-
loader_options.append(cebra.data.ContinuousDataLoader)
34-
else:
35-
mixed = False
36-
if dataset.discrete_index is not None:
37-
loader_options.append(cebra.data.DiscreteDataLoader)
38-
else:
39-
mixed = False
40-
if mixed:
41-
loader_options.append(cebra.data.MixedDataLoader)
42-
elif isinstance(dataset, cebra.data.MultiSessionDataset):
43-
mixed = True
44-
if dataset.continuous_index is not None:
45-
loader_options.append(cebra.data.ContinuousMultiSessionDataLoader)
46-
else:
47-
mixed = False
48-
if dataset.discrete_index is not None:
49-
pass # not implemented yet
50-
else:
51-
mixed = False
52-
if mixed:
53-
pass # not implemented yet
54-
else:
55-
raise TypeError(f"Invalid dataset type: {dataset}")
56-
return loader_options
57-
58-
5930
def download_file_from_url(url: str) -> str:
6031
"""Download a file from ``url``.
6132
@@ -88,3 +59,53 @@ def download_file_from_zip_url(url, file="montblanc_tracks.h5"):
8859
except zipfile.error:
8960
pass
9061
return pathlib.Path(foldername) / "data" / file
62+
63+
64+
def _is_integer(y: Union[npt.NDArray, torch.Tensor]) -> bool:
    """Check if the values in ``y`` are :py:class:`int`.

    Note:
        For a :py:func:`numpy.array`, the ``dtype`` has to be a sub-dtype of
        :py:class:`numpy.integer` (signed or unsigned, any width). For a
        :py:class:`torch.Tensor`, the check is that the tensor is neither
        floating point nor complex; this is also the case for ``torch.bool``
        tensors, whereas boolean numpy arrays are not considered integer.

    Args:
        y: An array, either as a :py:func:`numpy.array` or a :py:class:`torch.Tensor`.

    Returns:
        ``True`` if ``y`` contains :py:class:`int`. Inputs that are neither a
        numpy array nor a torch tensor yield ``False``.
    """
    if isinstance(y, np.ndarray):
        return bool(np.issubdtype(y.dtype, np.integer))
    if isinstance(y, torch.Tensor):
        return not (torch.is_floating_point(y) or torch.is_complex(y))
    return False
76+
77+
78+
def _is_floating(y: Union[npt.NDArray, torch.Tensor]) -> bool:
    """Check if the values in ``y`` are :py:class:`float`.

    Note:
        For a :py:func:`numpy.array`, the ``dtype`` has to be a sub-dtype of
        :py:class:`numpy.floating` (e.g. ``float16``, ``float32``,
        ``float64``). For a :py:class:`torch.Tensor`, the check relies on
        :py:func:`torch.is_floating_point`.

    Args:
        y: An array, either as a :py:func:`numpy.array` or a :py:class:`torch.Tensor`.

    Returns:
        ``True`` if ``y`` contains :py:class:`float`. Inputs that are neither a
        numpy array nor a torch tensor yield ``False``.
    """
    if isinstance(y, np.ndarray):
        return bool(np.issubdtype(y.dtype, np.floating))
    return isinstance(y, torch.Tensor) and torch.is_floating_point(y)
96+
97+
98+
def get_loader_options(dataset: "cebra.data.Dataset") -> List[type]:
    """Return all possible dataloaders for the given dataset.

    Notes:
        This function is deprecated and will be removed in an upcoming version of CEBRA.
        Please use :py:func:`cebra.data.helper.get_loader_options` instead, to which
        this function forwards its argument.

    Args:
        dataset: The dataset for which compatible data loaders are listed.

    Returns:
        The result of :py:func:`cebra.data.helper.get_loader_options` for ``dataset``.
    """

    # Imported lazily, inside the function, rather than at module import time.
    import cebra.data.helper

    warnings.warn(
        "The 'get_loader_options' function has been moved to 'cebra.data.helper' module. "
        "Please update your imports.",
        DeprecationWarning,
        # Attribute the warning to the caller, not to this wrapper.
        stacklevel=2,
    )
    # Forward the call; returning the bare function object would break
    # callers that expect the list of loader options.
    return cebra.data.helper.get_loader_options(dataset)

cebra/integrations/sklearn/dataset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch
1919

2020
import cebra.data
21+
import cebra.helper
2122
import cebra.integrations.sklearn.utils as cebra_sklearn_utils
2223
import cebra.models
2324
import cebra.solver
@@ -134,12 +135,12 @@ def _parse_labels(self, labels: Optional[tuple]):
134135

135136
# Define the index as either continuous or discrete indices, depending
136137
# on the dtype in the index array.
137-
if y.dtype in (np.float32, np.float64):
138+
if cebra.helper._is_floating(y):
138139
y = torch.from_numpy(y).float()
139140
if y.dim() == 1:
140141
y = y.unsqueeze(1)
141142
continuous_index.append(y)
142-
elif y.dtype in (np.int32, np.int64):
143+
elif cebra.helper._is_integer(y):
143144
y = torch.from_numpy(y).long().squeeze()
144145
if y.dim() > 1:
145146
raise ValueError(

cebra/integrations/sklearn/decoder.py

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,39 +21,7 @@
2121
import sklearn.neighbors
2222
import torch
2323

24-
25-
def _is_integer(y: Union[npt.NDArray, torch.Tensor]) -> bool:
26-
"""Check if the values in ``y`` are :py:class:`int`.
27-
28-
Args:
29-
y: An array, either as a :py:func:`numpy.array` or a :py:class:`torch.Tensor`.
30-
31-
Returns:
32-
``True`` if ``y`` contains :py:class:`int`.
33-
"""
34-
return (isinstance(y, np.ndarray) and np.issubdtype(y.dtype, np.integer)
35-
) or (isinstance(y, torch.Tensor) and
36-
(not torch.is_floating_point(y) and not torch.is_complex(y)))
37-
38-
39-
def _is_floating(y: Union[npt.NDArray, torch.Tensor]) -> bool:
40-
"""Check if the values in ``y`` are :py:class:`int`.
41-
42-
Note:
43-
There is no ``torch`` method to check that the ``dtype`` of a :py:class:`torch.Tensor`
44-
is a :py:class:`float`, consequently, we check that it is not :py:class:`int` nor
45-
:py:class:`complex`.
46-
47-
Args:
48-
y: An array, either as a :py:func:`numpy.array` or a :py:class:`torch.Tensor`.
49-
50-
Returns:
51-
``True`` if ``y`` contains :py:class:`float`.
52-
"""
53-
54-
return (isinstance(y, np.ndarray) and
55-
np.issubdtype(y.dtype, np.floating)) or (isinstance(
56-
y, torch.Tensor) and torch.is_floating_point(y))
24+
import cebra.helper
5725

5826

5927
class Decoder(abc.ABC, sklearn.base.BaseEstimator):
@@ -152,10 +120,10 @@ def fit(
152120
)
153121

154122
# Use regression or classification, based on if the targets are continuous or discrete
155-
if _is_floating(y):
123+
if cebra.helper._is_floating(y):
156124
self.knn = sklearn.neighbors.KNeighborsRegressor(
157125
n_neighbors=self.n_neighbors, metric=self.metric)
158-
elif _is_integer(y):
126+
elif cebra.helper._is_integer(y):
159127
self.knn = sklearn.neighbors.KNeighborsClassifier(
160128
n_neighbors=self.n_neighbors, metric=self.metric)
161129
else:
@@ -237,7 +205,7 @@ def fit(
237205
f"Invalid shape: y and X must have the same number of samples, got y:{len(y)} and X:{len(X)}."
238206
)
239207

240-
if not (_is_integer(y) or _is_floating(y)):
208+
if not (cebra.helper._is_integer(y) or cebra.helper._is_floating(y)):
241209
raise NotImplementedError(
242210
f"Invalid type: targets must be numeric, got y:{y.dtype}")
243211

tests/test_integration_train.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md
1111
#
1212
import itertools
13+
from typing import List
1314

1415
import pytest
1516
import torch
@@ -18,6 +19,7 @@
1819
import cebra
1920
import cebra.config
2021
import cebra.data
22+
import cebra.data.helper as cebra_data_helper
2123
import cebra.datasets
2224
import cebra.helper
2325
import cebra.models
@@ -68,7 +70,7 @@ def _list_data_loaders():
6870
]
6971
# TODO limit this to the valid combinations---however this
7072
# requires to adapt the dataset API slightly; it is currently
71-
# required to initialize the dataset to run cebra.helper.get_loader_options.
73+
# required to initialize the dataset to run cebra_data_helper.get_loader_options.
7274
prefixes = set()
7375
for dataset_name, loader in itertools.product(cebra.datasets.get_options(),
7476
loaders):
@@ -86,7 +88,7 @@ def test_train(dataset_name, loader_type):
8688
args = cebra.config.Config(num_steps=1, device="cuda").as_namespace()
8789

8890
dataset = cebra.datasets.init(dataset_name)
89-
if loader_type not in cebra.helper.get_loader_options(dataset):
91+
if loader_type not in cebra_data_helper.get_loader_options(dataset):
9092
# skip this test, since the data/loader combination is not valid.
9193
pytest.skip("Not a valid dataset/loader combination.")
9294
loader = loader_type(

tests/test_sklearn.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,26 @@ def test_sklearn_dataset():
117117
cebra_data.datasets.DatasetCollection(*sessions)
118118

119119

120+
@pytest.mark.parametrize("int_type", [np.uint8, np.int8, np.int32])
@pytest.mark.parametrize("float_type", [np.float16, np.float32, np.float64])
def test_sklearn_dataset_type_index(int_type, float_type):
    """Check that the label dtype selects a discrete or continuous index.

    Integer-typed labels should end up as the discrete index of the dataset
    built by ``_prepare_fit``, floating-typed labels as the continuous index.
    """
    num_per_class = 100
    X = np.random.uniform(0, 1, (num_per_class * 2, 2))
    labels = np.concatenate([np.zeros(num_per_class), np.ones(num_per_class)])

    # integer type
    discrete_labels = labels.astype(int_type)
    _, _, loader, _ = cebra.CEBRA(batch_size=512)._prepare_fit(
        X, discrete_labels)
    assert loader.dataset.discrete_index is not None
    assert loader.dataset.continuous_index is None

    # floating type (cast from the integer labels, as the values are 0/1)
    continuous_labels = discrete_labels.astype(float_type)
    _, _, loader, _ = cebra.CEBRA(batch_size=512)._prepare_fit(
        X, continuous_labels)
    assert loader.dataset.continuous_index is not None
    assert loader.dataset.discrete_index is None
138+
139+
120140
@_util.parametrize_slow(
121141
arg_names="is_cont,is_disc,is_full,is_multi,is_hybrid",
122142
fast_arguments=list(

tests/test_sklearn_decoder.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pytest
1414
import torch
1515

16+
import cebra.helper
1617
import cebra.integrations.sklearn.decoder as cebra_sklearn_decoder
1718

1819

@@ -104,7 +105,7 @@ def test_sklearn_decoder(decoder):
104105

105106

106107
def test_dtype_checker():
107-
assert cebra_sklearn_decoder._is_floating(torch.Tensor([4.5]))
108-
assert cebra_sklearn_decoder._is_integer(torch.LongTensor([4]))
109-
assert cebra_sklearn_decoder._is_floating(np.array([4.5]))
110-
assert cebra_sklearn_decoder._is_integer(np.array([4]))
108+
assert cebra.helper._is_floating(torch.Tensor([4.5]))
109+
assert cebra.helper._is_integer(torch.LongTensor([4]))
110+
assert cebra.helper._is_floating(np.array([4.5]))
111+
assert cebra.helper._is_integer(np.array([4]))

0 commit comments

Comments
 (0)