
Commit 0eac868

Merge remote-tracking branch 'origin/main' into batched-inference-and-padding
2 parents: e1b7cc7 + e652b9a

28 files changed (+354, -153 lines)

cebra/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -92,11 +92,11 @@ def __getattr__(key):
 
         return CEBRA
     elif key == "KNNDecoder":
-        from cebra.integrations.sklearn.decoder import KNNDecoder
+        from cebra.integrations.sklearn.decoder import KNNDecoder  # noqa: F811
 
         return KNNDecoder
     elif key == "L1LinearRegressor":
-        from cebra.integrations.sklearn.decoder import L1LinearRegressor
+        from cebra.integrations.sklearn.decoder import L1LinearRegressor  # noqa: F811
 
         return L1LinearRegressor
     elif not key.startswith("_"):
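
The `# noqa: F811` markers added above suppress flake8's F811 check ("redefinition of unused name") on the lazy imports inside the module-level `__getattr__`. For readers unfamiliar with that hook, here is a minimal, self-contained sketch of PEP 562 lazy attribute access; `mypackage` and `HeavyClass` are hypothetical names used only for illustration and are not part of CEBRA.

# lazy_import_sketch.py (illustrative only, hypothetical names)
def __getattr__(key):
    if key == "HeavyClass":
        # The submodule is imported only on first attribute access,
        # so `import mypackage` itself stays cheap.
        from mypackage.heavy import HeavyClass
        return HeavyClass
    raise AttributeError(f"module {__name__!r} has no attribute {key!r}")

With this pattern in place, `mypackage.HeavyClass` triggers the actual import the first time the attribute is accessed.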

cebra/data/datasets.py

Lines changed: 36 additions & 8 deletions
@@ -22,13 +22,15 @@
 """Pre-defined datasets."""
 
 import types
-from typing import List, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import numpy.typing as npt
 import torch
 
 import cebra.data as cebra_data
+import cebra.helper as cebra_helper
+from cebra.data.datatypes import Offset
 
 
 class TensorDataset(cebra_data.SingleSessionDataset):
@@ -64,26 +66,52 @@ def __init__(self,
                  neural: Union[torch.Tensor, npt.NDArray],
                  continuous: Union[torch.Tensor, npt.NDArray] = None,
                  discrete: Union[torch.Tensor, npt.NDArray] = None,
-                 offset: int = 1,
+                 offset: Offset = Offset(0, 1),
                  device: str = "cpu"):
         super().__init__(device=device)
-        self.neural = self._to_tensor(neural, torch.FloatTensor).float()
-        self.continuous = self._to_tensor(continuous, torch.FloatTensor)
-        self.discrete = self._to_tensor(discrete, torch.LongTensor)
+        self.neural = self._to_tensor(neural, check_dtype="float").float()
+        self.continuous = self._to_tensor(continuous, check_dtype="float")
+        self.discrete = self._to_tensor(discrete, check_dtype="int")
         if self.continuous is None and self.discrete is None:
             raise ValueError(
                 "You have to pass at least one of the arguments 'continuous' or 'discrete'."
             )
         self.offset = offset
 
-    def _to_tensor(self, array, check_dtype=None):
+    def _to_tensor(
+            self,
+            array: Union[torch.Tensor, npt.NDArray],
+            check_dtype: Optional[Literal["int",
+                                          "float"]] = None) -> torch.Tensor:
+        """Convert :py:func:`numpy.array` to :py:class:`torch.Tensor` if necessary and check the dtype.
+
+        Args:
+            array: Array to check.
+            check_dtype: If not `None`, list of dtypes to which the values in `array`
+                must belong to. Defaults to None.
+
+        Returns:
+            The `array` as a :py:class:`torch.Tensor`.
+        """
         if array is None:
             return None
         if isinstance(array, np.ndarray):
             array = torch.from_numpy(array)
         if check_dtype is not None:
-            if not isinstance(array, check_dtype):
-                raise TypeError(f"{type(array)} instead of {check_dtype}.")
+            if check_dtype not in ["int", "float"]:
+                raise ValueError(
+                    f"check_dtype must be 'int' or 'float', got {check_dtype}")
+            if (check_dtype == "int" and not cebra_helper._is_integer(array)
+               ) or (check_dtype == "float" and
+                     not cebra_helper._is_floating(array)):
+                raise TypeError(
+                    f"Array has type {array.dtype} instead of {check_dtype}.")
+        if cebra_helper._is_floating(array):
+            array = array.float()
+        if cebra_helper._is_integer(array):
+            # NOTE(stes): Required for standardizing number format on
+            # windows machines.
+            array = array.long()
         return array
 
     @property
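
To illustrate the new `TensorDataset` behavior shown above: `continuous` data must now be floating point and `discrete` data must be integer typed, and the default `offset` is an `Offset(0, 1)` object instead of the integer `1`. The sketch below follows the import paths implied by the file names in this diff; the random toy arrays are for illustration only.

import numpy as np

from cebra.data.datasets import TensorDataset   # module shown in this diff
from cebra.data.datatypes import Offset

neural = np.random.randn(1000, 30).astype(np.float32)    # float, passes check_dtype="float"
position = np.random.randn(1000, 2).astype(np.float32)   # float, passes check_dtype="float"
trial_id = np.random.randint(0, 5, size=(1000,))          # int, passes check_dtype="int"

dataset = TensorDataset(neural,
                        continuous=position,
                        discrete=trial_id,
                        offset=Offset(0, 1))  # same as the new default

# An integer array passed as `continuous` (or a float array as `discrete`)
# now raises TypeError instead of being silently accepted.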

cebra/data/helper.py

Lines changed: 9 additions & 4 deletions
@@ -94,10 +94,15 @@ class OrthogonalProcrustesAlignment:
 
     For each dataset, the data and labels to align the data on is provided.
 
-    1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to the labels of the reference dataset (``ref_label``) are selected and used to sample from the dataset to align (``data``).
-    2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number of samples ``subsample``.
-    3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`, on those subsampled datasets.
-    4. The resulting orthongonal matrix ``_transform`` can be used to map the original ``data`` to the ``ref_data``.
+    1. The ``top_k`` indexes of the labels to align (``label``) that are the closest to
+       the labels of the reference dataset (``ref_label``) are selected and used to sample
+       from the dataset to align (``data``).
+    2. ``data`` and ``ref_data`` (the reference dataset) are subsampled to the same number
+       of samples ``subsample``.
+    3. The orthogonal mapping is computed, using :py:func:`scipy.linalg.orthogonal_procrustes`,
+       on those subsampled datasets.
+    4. The resulting orthongonal matrix ``_transform`` can be used to map the original ``data``
+       to the ``ref_data``.
 
     Note:
         ``data`` and ``ref_data`` can be of different sample size (axis 0) but **must** have the same number
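
The numbered steps in this docstring can be summarized with a short sketch of step 3 alone, using :py:func:`scipy.linalg.orthogonal_procrustes` on two already-subsampled arrays of equal shape. This is not the class's implementation, just the underlying operation it delegates to, on synthetic data.

import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
ref_data = rng.normal(size=(500, 8))                  # reference embedding (already subsampled)
rotation = scipy.linalg.qr(rng.normal(size=(8, 8)))[0]
data = ref_data @ rotation                            # embedding to align (already subsampled)

# Orthogonal matrix R minimizing ||data @ R - ref_data||_F.
R, _ = scipy.linalg.orthogonal_procrustes(data, ref_data)

aligned = data @ R                                    # maps `data` onto `ref_data`
print(np.allclose(aligned, ref_data))                 # True up to numerical precision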

cebra/data/load.py

Lines changed: 2 additions & 1 deletion
@@ -663,7 +663,8 @@ def load(
         - if no key is provided, the first data structure found upon iteration of the collection will be loaded;
         - if a key is provided, it needs to correspond to an existing item of the collection;
         - if a key is provided, the data value accessed needs to be a data structure;
-        - the function loads data for only one data structure, even if the file contains more. The function can be called again with the corresponding key to get the other ones.
+        - the function loads data for only one data structure, even if the file contains more. The function can be
+          called again with the corresponding key to get the other ones.
 
     Args:
         file: The path to the given file to load, in a supported format.
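
As a concrete reading of the reworded bullet: when the file is a collection, each call to `load` returns exactly one data structure. A hedged sketch, assuming an .npz archive and that the key is passed via a `key` argument as described in this docstring (the exact keyword spelling is an assumption based on the docstring wording):

import numpy as np

from cebra.data.load import load   # module path taken from this diff's file name

# A collection file containing two data structures.
np.savez("example.npz",
         neural=np.random.randn(100, 5),
         behavior=np.random.randn(100, 2))

neural = load("example.npz", key="neural")       # loads only the "neural" array
behavior = load("example.npz", key="behavior")   # a second call retrieves the other structure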

cebra/data/single_session.py

Lines changed: 1 addition & 3 deletions
@@ -358,7 +358,6 @@ def __post_init__(self):
         # here might be sub-optimal. The final behavior should be determined after
         # e.g. integrating the FAISS dataloader back in.
         super().__post_init__()
-        index = self.index.to(self.device)
 
         if self.conditional != "time_delta":
             raise NotImplementedError(
@@ -368,8 +367,7 @@ def __post_init__(self):
         self.time_distribution = cebra.distributions.TimeContrastive(
             time_offset=self.time_offset,
             num_samples=len(self.dataset.neural),
-            device=self.device,
-        )
+            device=self.device)
         self.behavior_distribution = cebra.distributions.TimedeltaDistribution(
             self.dataset.continuous_index, self.time_offset, device=self.device)
 

cebra/datasets/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -98,8 +98,6 @@ def get_datapath(path: str = None) -> str:
     from cebra.datasets.monkey_reaching import *
     from cebra.datasets.synthetic_data import *
 except ModuleNotFoundError as e:
-    import warnings
-
     warnings.warn(f"Could not initialize one or more datasets: {e}. "
                   f"For using the datasets, consider installing the "
                   f"[datasets] extension via pip.")

cebra/datasets/allen/ca_movie.py

Lines changed: 9 additions & 5 deletions
@@ -22,11 +22,15 @@
 """Allen pseudomouse Ca dataset.
 
 References:
-    *Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339.
-    *de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151.
-    *https://github.com/zivlab/visual_drift
-    *http://observatory.brain-map.org/visualcoding
-
+    * Deitch, Daniel, Alon Rubin, and Yaniv Ziv.
+      "Representational drift in the mouse visual cortex."
+      Current biology 31.19 (2021): 4327-4339.
+    * de Vries, Saskia EJ, et al.
+      "A large-scale standardized physiological survey reveals functional
+      organization of the mouse visual cortex."
+      Nature neuroscience 23.1 (2020): 138-151.
+    * https://github.com/zivlab/visual_drift
+    * http://observatory.brain-map.org/visualcoding
 """
 
 import pathlib

cebra/datasets/allen/ca_movie_decoding.py

Lines changed: 9 additions & 10 deletions
@@ -22,11 +22,15 @@
 """Allen pseudomouse Ca decoding dataset with train/test split.
 
 References:
-    *Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339.
-    *de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151.
-    *https://github.com/zivlab/visual_drift
-    *http://observatory.brain-map.org/visualcoding
-
+    * Deitch, Daniel, Alon Rubin, and Yaniv Ziv.
+      "Representational drift in the mouse visual cortex."
+      Current biology 31.19 (2021): 4327-4339.
+    * de Vries, Saskia EJ, et al.
+      "A large-scale standardized physiological survey reveals functional
+      organization of the mouse visual cortex."
+      Nature neuroscience 23.1 (2020): 138-151.
+    * https://github.com/zivlab/visual_drift
+    * http://observatory.brain-map.org/visualcoding
 """
 
 import pathlib
@@ -243,11 +247,6 @@ def _convert_to_nums(string):
 
         return pseudo_mouse
 
-        pseudo_mouse = np.vstack(
-            [get_neural_data(num_movie, mice) for mice in list_mice])
-
-        return pseudo_mouse
-
     def __len__(self):
         return self.neural.size(0)
 

cebra/datasets/allen/combined.py

Lines changed: 13 additions & 7 deletions
@@ -22,13 +22,19 @@
 """Joint Allen pseudomouse Ca/Neuropixel datasets.
 
 References:
-    *Deitch, Daniel, Alon Rubin, and Yaniv Ziv. "Representational drift in the mouse visual cortex." Current biology 31.19 (2021): 4327-4339.
-    *de Vries, Saskia EJ, et al. "A large-scale standardized physiological survey reveals functional organization of the mouse visual cortex." Nature neuroscience 23.1 (2020): 138-151.
-    *https://github.com/zivlab/visual_drift
-    *http://observatory.brain-map.org/visualcoding
-    *https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html
-    *Siegle, Joshua H., et al. "Survey of spiking in the mouse visual system reveals functional hierarchy." Nature 592.7852 (2021): 86-92.
-
+    * Deitch, Daniel, Alon Rubin, and Yaniv Ziv.
+      "Representational drift in the mouse visual cortex."
+      Current Biology 31.19 (2021): 4327-4339.
+    * de Vries, Saskia EJ, et al.
+      "A large-scale standardized physiological survey reveals functional
+      organization of the mouse visual cortex."
+      Nature Neuroscience 23.1 (2020): 138-151.
+    * https://github.com/zivlab/visual_drift
+    * http://observatory.brain-map.org/visualcoding
+    * https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html
+    * Siegle, Joshua H., et al.
+      "Survey of spiking in the mouse visual system reveals functional hierarchy."
+      Nature 592.7852 (2021): 86-92.
 """
 
 import cebra.data

cebra/datasets/allen/make_neuropixel.py

Lines changed: 8 additions & 7 deletions
@@ -192,11 +192,12 @@ def read_neuropixel(
             "intervals/natural_movie_one_presentations/start_time"][...]
         end_time = d[
             "intervals/natural_movie_one_presentations/stop_time"][...]
-        timeseries = d[
-            "intervals/natural_movie_one_presentations/timeseries"][...]
-        timeseries_index = d[
-            "intervals/natural_movie_one_presentations/timeseries_index"][
-                ...]
+        # NOTE(stes): never used. leaving here for future reference
+        #timeseries = d[
+        #    "intervals/natural_movie_one_presentations/timeseries"][...]
+        #timeseries_index = d[
+        #    "intervals/natural_movie_one_presentations/timeseries_index"][
+        #        ...]
         session_no = d["identifier"][...].item()
         spike_time_index = d["units/spike_times_index"][...]
         spike_times = d["units/spike_times"][...]
@@ -266,14 +267,14 @@ def read_neuropixel(
             "neural": sessions_dic,
             "frames": session_frames
         },
-        Path(args.save_path) /
+        pathlib.Path(args.save_path) /
         f"neuropixel_sessions_{int(args.sampling_rate)}_filtered.jl",
     )
     jl.dump(
         {
             "neural": pseudo_mice,
            "frames": pseudo_mice_frames
         },
-        Path(args.save_path) /
+        pathlib.Path(args.save_path) /
         f"neuropixel_pseudomouse_{int(args.sampling_rate)}_filtered.jl",
     )
