Commit 6008052

Fix linting issues (#185)
* Apply auto-fixes
* Fix issues in allen datasets
  - lines too long
  - unused variables
  - missing imports
  - duplicate names for classes
* Fix line length + minor issues in plot integr.
* Fix linting issues in sklearn integration
* Fix linting issues in datasets
  - missing paths
  - long lines
  - unused variables
  - typos
* Fix minor linting issues
* Fix docstrings
* fix formatting issue in docstring
* Fix plotly docstrings
* Fix missing import
1 parent 0e2312a commit 6008052


45 files changed: +251 -305 lines

cebra/__init__.py

Lines changed: 5 additions & 5 deletions
@@ -33,7 +33,7 @@
     from cebra.integrations.sklearn.decoder import L1LinearRegressor

     is_sklearn_available = True
-except ImportError as e:
+except ImportError:
     # silently fail for now
     pass

@@ -42,7 +42,7 @@
     from cebra.integrations.matplotlib import *

     is_matplotlib_available = True
-except ImportError as e:
+except ImportError:
     # silently fail for now
     pass

@@ -51,7 +51,7 @@
     from cebra.integrations.plotly import *

     is_plotly_available = True
-except ImportError as e:
+except ImportError:
     # silently fail for now
     pass

@@ -92,11 +92,11 @@ def __getattr__(key):

         return CEBRA
     elif key == "KNNDecoder":
-        from cebra.integrations.sklearn.decoder import KNNDecoder
+        from cebra.integrations.sklearn.decoder import KNNDecoder  # noqa: F811

         return KNNDecoder
     elif key == "L1LinearRegressor":
-        from cebra.integrations.sklearn.decoder import L1LinearRegressor
+        from cebra.integrations.sklearn.decoder import L1LinearRegressor  # noqa: F811

         return L1LinearRegressor
     elif not key.startswith("_"):
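
Two flake8 rules drive this file's changes: dropping the unused `as e` binding addresses F841 (local variable assigned but never used), and the `# noqa: F811` markers suppress the redefinition warning that fires because `__getattr__` re-imports names already imported at module level. A minimal sketch of the same optional-dependency plus lazy-import pattern, using hypothetical names (`somelib`, `FancyDecoder`):

    # Sketch only; `somelib` and `FancyDecoder` are hypothetical.
    try:
        from somelib.decoder import FancyDecoder  # noqa: F401

        is_somelib_available = True
    except ImportError:  # no `as e`: the exception object is never used (F841)
        is_somelib_available = False


    def __getattr__(key):
        # PEP 562 module-level __getattr__: called when `key` is not found
        # in the module namespace, enabling lazy attribute access.
        if key == "FancyDecoder":
            # The re-import shadows the module-level name, hence `# noqa: F811`.
            from somelib.decoder import FancyDecoder  # noqa: F811
            return FancyDecoder
        raise AttributeError(f"module {__name__!r} has no attribute {key!r}")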

cebra/__main__.py

Lines changed: 0 additions & 4 deletions
@@ -27,11 +27,7 @@
 import argparse
 import sys

-import numpy as np
-import torch
-
 import cebra
-import cebra.distributions as cebra_distr


 def train(parser, kwargs):
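
These deletions are the classic F401 fix (module imported but unused): `numpy`, `torch`, and `cebra.distributions` were imported at the top of the CLI module but never referenced. A tiny illustration of the rule, with a hypothetical script:

    import sys    # used below, so flake8 keeps quiet
    import json   # F401: imported but never used -> the lint fix deletes this line

    print(sys.argv)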

cebra/config.py

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@
 #
 import argparse
 import json
-from dataclasses import MISSING
 from typing import Literal, Optional

 import literate_dataclasses as dataclasses

cebra/data/base.py

Lines changed: 0 additions & 3 deletions
@@ -22,11 +22,8 @@
 """Base classes for datasets and loaders."""

 import abc
-import collections
-from typing import List

 import literate_dataclasses as dataclasses
-import numpy as np
 import torch

 import cebra.data.assets as cebra_data_assets

cebra/data/datasets.py

Lines changed: 6 additions & 12 deletions
@@ -21,21 +21,15 @@
 #
 """Pre-defined datasets."""

-import abc
-import collections
 import types
 from typing import List, Literal, Optional, Tuple, Union

-import literate_dataclasses as dataclasses
 import numpy as np
 import numpy.typing as npt
 import torch
-from numpy.typing import NDArray

 import cebra.data as cebra_data
-import cebra.distributions
-from cebra.data.datatypes import Batch
-from cebra.data.datatypes import BatchIndex
+import cebra.helper as cebra_helper


 class TensorDataset(cebra_data.SingleSessionDataset):
@@ -75,8 +69,7 @@ def __init__(self,
                  device: str = "cpu"):
         super().__init__(device=device)
         self.neural = self._to_tensor(neural, check_dtype="float").float()
-        self.continuous = self._to_tensor(continuous,
-                                          check_dtype="float")
+        self.continuous = self._to_tensor(continuous, check_dtype="float")
         self.discrete = self._to_tensor(discrete, check_dtype="integer")
         if self.continuous is None and self.discrete is None:
             raise ValueError(
@@ -104,10 +97,11 @@ def _to_tensor(
         if isinstance(array, np.ndarray):
             array = torch.from_numpy(array)
         if check_dtype is not None:
-            if (check_dtype == "int" and not cebra.helper._is_integer(array)
+            if (check_dtype == "int" and not cebra_helper._is_integer(array)
                ) or (check_dtype == "float" and
-                     not cebra.helper._is_floating(array)):
-                raise TypeError(f"Array has type {array.dtype} instead of {check_dtype}.")
+                     not cebra_helper._is_floating(array)):
+                raise TypeError(
+                    f"Array has type {array.dtype} instead of {check_dtype}.")
         return array

     @property
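
The switch from `cebra.helper` to the explicit `cebra_helper` alias matches the new top-level import, and the wrapped `raise TypeError(...)` keeps the line under the length limit. A rough sketch of how the dtype guard behaves, using `torch.is_floating_point` as a stand-in for the private `cebra_helper._is_floating` helper:

    import numpy as np
    import torch

    def to_tensor(array, check_dtype=None):
        # Stand-in for TensorDataset._to_tensor: convert, then verify dtype.
        if isinstance(array, np.ndarray):
            array = torch.from_numpy(array)
        if check_dtype == "float" and not torch.is_floating_point(array):
            raise TypeError(f"Array has type {array.dtype} instead of {check_dtype}.")
        return array

    to_tensor(np.zeros((10, 3), dtype="float32"), check_dtype="float")   # passes
    # to_tensor(np.zeros((10, 3), dtype="int64"), check_dtype="float")   # TypeError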

cebra/data/datatypes.py

Lines changed: 0 additions & 3 deletions
@@ -20,9 +20,6 @@
 # limitations under the License.
 #
 import collections
-from typing import Tuple
-
-import torch

 __all__ = ["Batch", "BatchIndex", "Offset"]

cebra/data/helper.py

Lines changed: 14 additions & 9 deletions
@@ -94,10 +94,15 @@ class OrthogonalProcrustesAlignment:

     For each dataset, the data and labels to align the data on is provided.

-    1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to the labels of the reference dataset (``ref_label``) are selected and used to sample from the dataset to align (``data``).
-    2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number of samples ``subsample``.
-    3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`, on those subsampled datasets.
-    4. The resulting orthongonal matrix ``_transform`` can be used to map the original ``data`` to the ``ref_data``.
+    1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to
+       the labels of the reference dataset (``ref_label``) are selected and used to sample
+       from the dataset to align (``data``).
+    2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number
+       of samples ``subsample``.
+    3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`,
+       on those subsampled datasets.
+    4. The resulting orthongonal matrix ``_transform`` can be used to map the original ``data``
+       to the ``ref_data``.

     Note:
         ``data`` and ``ref_data`` can be of different sample size (axis 0) but **must** have the same number
@@ -181,14 +186,14 @@ def fit(
         elif ref_data.shape[0] == data.shape[0] and (ref_label is None or
                                                      label is None):
             raise ValueError(
-                f"Missing labels: the data to align are the same shape but you provided only "
-                f"one of the sets of labels. Either provide both the reference and alignment "
-                f"labels or none.")
+                "Missing labels: the data to align are the same shape but you provided only "
+                "one of the sets of labels. Either provide both the reference and alignment "
+                "labels or none.")
         else:
             if ref_label is None or label is None:
                 raise ValueError(
-                    f"Missing labels: the data to align are not the same shape, "
-                    f"provide labels to align the data and reference data.")
+                    "Missing labels: the data to align are not the same shape, "
+                    "provide labels to align the data and reference data.")

         if len(ref_label.shape) == 1:
             ref_label = np.expand_dims(ref_label, axis=1)
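
The reflowed docstring keeps the four alignment steps intact. For context, a toy sketch of steps 3 and 4 on already-paired data (the real class additionally performs the `top_k` label matching and subsampling from steps 1 and 2):

    import numpy as np
    from scipy.linalg import orthogonal_procrustes

    rng = np.random.default_rng(0)
    ref_data = rng.normal(size=(100, 8))                 # reference embedding
    rotation = np.linalg.qr(rng.normal(size=(8, 8)))[0]  # ground-truth rotation
    data = ref_data @ rotation                           # embedding to align

    # Step 3: fit the orthogonal map on the (sub)sampled, paired datasets.
    transform, _ = orthogonal_procrustes(data, ref_data)

    # Step 4: apply `transform` to map the original `data` onto `ref_data`.
    aligned = data @ transform
    assert np.allclose(aligned, ref_data, atol=1e-6)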

cebra/data/load.py

Lines changed: 2 additions & 1 deletion
@@ -663,7 +663,8 @@ def load(
             - if no key is provided, the first data structure found upon iteration of the collection will be loaded;
             - if a key is provided, it needs to correspond to an existing item of the collection;
             - if a key is provided, the data value accessed needs to be a data structure;
-            - the function loads data for only one data structure, even if the file contains more. The function can be called again with the corresponding key to get the other ones.
+            - the function loads data for only one data structure, even if the file contains more. The function can be
+              called again with the corresponding key to get the other ones.

     Args:
         file: The path to the given file to load, in a supported format.
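
The rewrapped bullet describes one-structure-per-call semantics. A hedged usage sketch, assuming an HDF5 collection with two toy datasets and that this loader is the one exposed as `cebra.load_data`:

    import h5py
    import numpy as np
    import cebra

    # Build a toy collection file with two data structures (hypothetical names).
    with h5py.File("toy.h5", "w") as f:
        f.create_dataset("neural", data=np.zeros((100, 3)))
        f.create_dataset("behavior", data=np.zeros((100, 2)))

    neural = cebra.load_data("toy.h5", key="neural")      # loads one structure
    behavior = cebra.load_data("toy.h5", key="behavior")  # call again for the next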

cebra/data/multi_session.py

Lines changed: 0 additions & 2 deletions
@@ -22,11 +22,9 @@
 """Datasets and loaders for multi-session training."""

 import abc
-import collections
 from typing import List

 import literate_dataclasses as dataclasses
-import numpy as np
 import torch

 import cebra.data as cebra_data

cebra/data/single_session.py

Lines changed: 3 additions & 8 deletions
@@ -26,12 +26,9 @@
 """

 import abc
-import collections
 import warnings
-from typing import List

 import literate_dataclasses as dataclasses
-import numpy as np
 import torch

 import cebra.data as cebra_data
@@ -353,18 +350,16 @@ def __post_init__(self):
         # here might be sub-optimal. The final behavior should be determined after
         # e.g. integrating the FAISS dataloader back in.
         super().__post_init__()
-        index = self.index.to(self.device)

         if self.conditional != "time_delta":
             raise NotImplementedError(
-                f"Hybrid training is currently only implemented using the ``time_delta`` "
-                f"continual distribution.")
+                "Hybrid training is currently only implemented using the ``time_delta`` "
+                "continual distribution.")

         self.time_distribution = cebra.distributions.TimeContrastive(
             time_offset=self.time_offset,
             num_samples=len(self.dataset.neural),
-            device=self.device,
-        )
+            device=self.device)
         self.behavior_distribution = cebra.distributions.TimedeltaDistribution(
             self.dataset.continuous_index, self.time_offset, device=self.device)
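
Besides removing the unused `index` local (F841), the dropped `f` prefixes fix F541: an f-string that contains no placeholders is just a plain string literal. In short:

    # flake8 F541: no placeholders, so the `f` prefix is dead weight.
    message = f"Hybrid training is not implemented."   # flagged
    message = "Hybrid training is not implemented."    # fixed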
