Skip to content

Commit 74fc232

Browse files
authored
Merge branch 'main' into icarosadero-patch-2
2 parents aaf912f + 5f46c32 commit 74fc232

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+521
-361
lines changed

.github/workflows/build.yml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,16 @@ jobs:
1919
# as well as selected previous versions on
2020
# https://pytorch.org/get-started/previous-versions/
2121
torch-version: ["2.2.2", "2.4.0"]
22+
sklearn-version: ["latest"]
2223
include:
2324
- os: windows-latest
2425
torch-version: 2.4.0
2526
python-version: "3.10"
27+
sklearn-version: "latest"
28+
- os: ubuntu-latest
29+
torch-version: 2.4.0
30+
python-version: "3.10"
31+
sklearn-version: "legacy"
2632

2733
runs-on: ${{ matrix.os }}
2834

@@ -32,7 +38,7 @@ jobs:
3238
uses: actions/cache@v3
3339
with:
3440
path: ~/.cache/pip
35-
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}
41+
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}-sklearn_${{ matrix.sklearn-version }}
3642

3743
- name: Checkout code
3844
uses: actions/checkout@v2
@@ -48,6 +54,11 @@ jobs:
4854
python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
4955
pip install '.[dev,datasets,integrations]'
5056
57+
- name: Check sklearn legacy version
58+
if: matrix.sklearn-version == 'legacy'
59+
run: |
60+
pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]'
61+
5162
- name: Run the formatter
5263
run: |
5364
make format

.pre-commit-config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,7 @@ repos:
2020
- id: isort
2121
additional_dependencies:
2222
- pyproject.toml
23+
- repo: https://github.com/astral-sh/ruff-pre-commit
24+
rev: v0.0.280
25+
hooks:
26+
- id: ruff

cebra/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from cebra.integrations.sklearn.decoder import L1LinearRegressor
3434

3535
is_sklearn_available = True
36-
except ImportError as e:
36+
except ImportError:
3737
# silently fail for now
3838
pass
3939

@@ -42,7 +42,7 @@
4242
from cebra.integrations.matplotlib import *
4343

4444
is_matplotlib_available = True
45-
except ImportError as e:
45+
except ImportError:
4646
# silently fail for now
4747
pass
4848

@@ -51,7 +51,7 @@
5151
from cebra.integrations.plotly import *
5252

5353
is_plotly_available = True
54-
except ImportError as e:
54+
except ImportError:
5555
# silently fail for now
5656
pass
5757

@@ -92,11 +92,11 @@ def __getattr__(key):
9292

9393
return CEBRA
9494
elif key == "KNNDecoder":
95-
from cebra.integrations.sklearn.decoder import KNNDecoder
95+
from cebra.integrations.sklearn.decoder import KNNDecoder # noqa: F811
9696

9797
return KNNDecoder
9898
elif key == "L1LinearRegressor":
99-
from cebra.integrations.sklearn.decoder import L1LinearRegressor
99+
from cebra.integrations.sklearn.decoder import L1LinearRegressor # noqa: F811
100100

101101
return L1LinearRegressor
102102
elif not key.startswith("_"):

cebra/__main__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,7 @@
2727
import argparse
2828
import sys
2929

30-
import numpy as np
31-
import torch
32-
3330
import cebra
34-
import cebra.distributions as cebra_distr
3531

3632

3733
def train(parser, kwargs):

cebra/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#
2222
import argparse
2323
import json
24-
from dataclasses import MISSING
2524
from typing import Literal, Optional
2625

2726
import literate_dataclasses as dataclasses

cebra/data/base.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,8 @@
2222
"""Base classes for datasets and loaders."""
2323

2424
import abc
25-
import collections
26-
from typing import List
2725

2826
import literate_dataclasses as dataclasses
29-
import numpy as np
3027
import torch
3128

3229
import cebra.data.assets as cebra_data_assets

cebra/data/datasets.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,16 @@
2121
#
2222
"""Pre-defined datasets."""
2323

24-
import abc
25-
import collections
2624
import types
27-
from typing import List, Tuple, Union
25+
from typing import List, Literal, Optional, Tuple, Union
2826

29-
import literate_dataclasses as dataclasses
3027
import numpy as np
3128
import numpy.typing as npt
3229
import torch
33-
from numpy.typing import NDArray
3430

3531
import cebra.data as cebra_data
36-
import cebra.distributions
37-
from cebra.data.datatypes import Batch
38-
from cebra.data.datatypes import BatchIndex
32+
import cebra.helper as cebra_helper
33+
from cebra.data.datatypes import Offset
3934

4035

4136
class TensorDataset(cebra_data.SingleSessionDataset):
@@ -71,26 +66,52 @@ def __init__(self,
7166
neural: Union[torch.Tensor, npt.NDArray],
7267
continuous: Union[torch.Tensor, npt.NDArray] = None,
7368
discrete: Union[torch.Tensor, npt.NDArray] = None,
74-
offset: int = 1,
69+
offset: Offset = Offset(0, 1),
7570
device: str = "cpu"):
7671
super().__init__(device=device)
77-
self.neural = self._to_tensor(neural, torch.FloatTensor).float()
78-
self.continuous = self._to_tensor(continuous, torch.FloatTensor)
79-
self.discrete = self._to_tensor(discrete, torch.LongTensor)
72+
self.neural = self._to_tensor(neural, check_dtype="float").float()
73+
self.continuous = self._to_tensor(continuous, check_dtype="float")
74+
self.discrete = self._to_tensor(discrete, check_dtype="int")
8075
if self.continuous is None and self.discrete is None:
8176
raise ValueError(
8277
"You have to pass at least one of the arguments 'continuous' or 'discrete'."
8378
)
8479
self.offset = offset
8580

86-
def _to_tensor(self, array, check_dtype=None):
81+
def _to_tensor(
82+
self,
83+
array: Union[torch.Tensor, npt.NDArray],
84+
check_dtype: Optional[Literal["int",
85+
"float"]] = None) -> torch.Tensor:
86+
"""Convert :py:func:`numpy.array` to :py:class:`torch.Tensor` if necessary and check the dtype.
87+
88+
Args:
89+
array: Array to check.
90+
check_dtype: If not `None`, the dtype family (``"int"`` or ``"float"``) to which the values in `array`
91+
must belong. Defaults to None.
92+
93+
Returns:
94+
The `array` as a :py:class:`torch.Tensor`.
95+
"""
8796
if array is None:
8897
return None
8998
if isinstance(array, np.ndarray):
9099
array = torch.from_numpy(array)
91100
if check_dtype is not None:
92-
if not isinstance(array, check_dtype):
93-
raise TypeError(f"{type(array)} instead of {check_dtype}.")
101+
if check_dtype not in ["int", "float"]:
102+
raise ValueError(
103+
f"check_dtype must be 'int' or 'float', got {check_dtype}")
104+
if (check_dtype == "int" and not cebra_helper._is_integer(array)
105+
) or (check_dtype == "float" and
106+
not cebra_helper._is_floating(array)):
107+
raise TypeError(
108+
f"Array has type {array.dtype} instead of {check_dtype}.")
109+
if cebra_helper._is_floating(array):
110+
array = array.float()
111+
if cebra_helper._is_integer(array):
112+
# NOTE(stes): Required for standardizing number format on
113+
# Windows machines.
114+
array = array.long()
94115
return array
95116

96117
@property

cebra/data/datatypes.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
# limitations under the License.
2121
#
2222
import collections
23-
from typing import Tuple
24-
25-
import torch
2623

2724
__all__ = ["Batch", "BatchIndex", "Offset"]
2825

cebra/data/helper.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,15 @@ class OrthogonalProcrustesAlignment:
9494
9595
For each dataset, the data and labels to align the data on is provided.
9696
97-
1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to the labels of the reference dataset (``ref_label``) are selected and used to sample from the dataset to align (``data``).
98-
2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number of samples ``subsample``.
99-
3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`, on those subsampled datasets.
100-
4. The resulting orthongonal matrix ``_transform`` can be used to map the original ``data`` to the ``ref_data``.
97+
1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to
98+
the labels of the reference dataset (``ref_label``) are selected and used to sample
99+
from the dataset to align (``data``).
100+
2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number
101+
of samples ``subsample``.
102+
3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`,
103+
on those subsampled datasets.
104+
4. The resulting orthogonal matrix ``_transform`` can be used to map the original ``data``
105+
to the ``ref_data``.
101106
102107
Note:
103108
``data`` and ``ref_data`` can be of different sample size (axis 0) but **must** have the same number
@@ -181,14 +186,14 @@ def fit(
181186
elif ref_data.shape[0] == data.shape[0] and (ref_label is None or
182187
label is None):
183188
raise ValueError(
184-
f"Missing labels: the data to align are the same shape but you provided only "
185-
f"one of the sets of labels. Either provide both the reference and alignment "
186-
f"labels or none.")
189+
"Missing labels: the data to align are the same shape but you provided only "
190+
"one of the sets of labels. Either provide both the reference and alignment "
191+
"labels or none.")
187192
else:
188193
if ref_label is None or label is None:
189194
raise ValueError(
190-
f"Missing labels: the data to align are not the same shape, "
191-
f"provide labels to align the data and reference data.")
195+
"Missing labels: the data to align are not the same shape, "
196+
"provide labels to align the data and reference data.")
192197

193198
if len(ref_label.shape) == 1:
194199
ref_label = np.expand_dims(ref_label, axis=1)

cebra/data/load.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,8 @@ def load(
663663
- if no key is provided, the first data structure found upon iteration of the collection will be loaded;
664664
- if a key is provided, it needs to correspond to an existing item of the collection;
665665
- if a key is provided, the data value accessed needs to be a data structure;
666-
- the function loads data for only one data structure, even if the file contains more. The function can be called again with the corresponding key to get the other ones.
666+
- the function loads data for only one data structure, even if the file contains more. The function can be
667+
called again with the corresponding key to get the other ones.
667668
668669
Args:
669670
file: The path to the given file to load, in a supported format.

0 commit comments

Comments
 (0)