Skip to content

Commit ae056b2

Browse files
authored
Merge branch 'main' into batched-inference-and-padding
2 parents 0c693dd + 9e14790 commit ae056b2

35 files changed

+475
-176
lines changed

.github/workflows/build.yml

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,15 @@ jobs:
1414
fail-fast: true
1515
matrix:
1616
os: [ubuntu-latest]
17-
python-version: ["3.8", "3.10"]
17+
python-version: ["3.9", "3.10", "3.12"]
1818
# We aim to support the versions on pytorch.org
1919
# as well as selected previous versions on
2020
# https://pytorch.org/get-started/previous-versions/
21-
torch-version: ["1.12.1", "2.0.0"]
21+
torch-version: ["2.2.2", "2.4.0"]
2222
include:
23-
- os: ubuntu-latest
24-
python-version: 3.8
25-
torch-version: 1.9.0
2623
- os: windows-latest
27-
torch-version: 2.0.0
24+
torch-version: 2.4.0
2825
python-version: "3.10"
29-
- os: ubuntu-latest
30-
torch-version: 2.1.1
31-
python-version: "3.11"
32-
#- os: macos-latest
33-
# torch-version: 2.0.0
34-
# python-version: "3.10"
3526

3627
runs-on: ${{ matrix.os }}
3728

CODE_OF_CONDUCT.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ In the interest of fostering an open and welcoming environment, we as
66
contributors and maintainers pledge to making participation in our project and
77
our community a harassment-free experience for everyone, regardless of age, body
88
size, disability, ethnicity, sex characteristics, gender identity and expression,
9-
level of experience, education, socio-economic status, nationality, personal
9+
level of experience, education, socioeconomic status, nationality, personal
1010
appearance, race, religion, or sexual identity and orientation.
1111

1212
## Our Standards

Dockerfile

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
## EXPERIMENT BASE CONTAINER
2-
FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04 AS cebra-base
2+
FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04 AS cebra-base
33

44
ENV DEBIAN_FRONTEND=noninteractive
55
RUN apt-get update -y \
66
&& apt-get install --no-install-recommends -yy git python3 python3-pip python-is-python3 \
77
&& rm -rf /var/lib/apt/lists/*
88

9-
RUN pip install --no-cache-dir torch==2.0.0+cu117 \
10-
--index-url https://download.pytorch.org/whl/cu117
11-
RUN pip install --no-cache-dir --pre 'cebra[dev,datasets,integrations]' \
12-
&& pip uninstall -y cebra
9+
RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cu124
10+
RUN pip install --upgrade pip
1311

1412

1513
## GIT repository

Makefile

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ test: clean_test
2424
doctest: clean_test
2525
python -m pytest --ff --doctest-modules -vvv ./docs/source/usage.rst
2626

27+
docker:
28+
./tools/build_docker.sh
29+
2730
test_parallel: clean_test
2831
python -m pytest -n auto --ff -m "not requires_dataset" tests
2932

@@ -98,4 +101,7 @@ report: check_docker format .coverage .pylint
98101
cat .pylint
99102
coverage report
100103

101-
.PHONY: dist build archlinux clean_test test doctest test_parallel test_parallel_debug test_all test_fast test_debug test_benchmark interrogate docs docs-touch docs-strict serve_docs serve_page format codespell check_for_binary
104+
.PHONY: dist build docker archlinux clean_test test doctest test_parallel \
105+
test_parallel_debug test_all test_fast test_debug test_benchmark \
106+
interrogate docs docs-touch docs-strict serve_docs serve_page \
107+
format codespell check_for_binary

cebra/data/assets.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
#
2222

2323
import hashlib
24-
import os
2524
import re
2625
import warnings
26+
from pathlib import Path
2727
from typing import Optional
2828

2929
import requests
@@ -57,8 +57,10 @@ def download_file_with_progress_bar(url: str,
5757
"""
5858

5959
# Check if the file already exists in the location
60-
file_path = os.path.join(location, file_name)
61-
if os.path.exists(file_path):
60+
location_path = Path(location)
61+
file_path = location_path / file_name
62+
63+
if file_path.exists():
6264
existing_checksum = calculate_checksum(file_path)
6365
if existing_checksum == expected_checksum:
6466
return file_path
@@ -91,10 +93,10 @@ def download_file_with_progress_bar(url: str,
9193
)
9294

9395
# Create the directory and any necessary parent directories
94-
os.makedirs(location, exist_ok=True)
96+
location_path.mkdir(exist_ok=True)
9597

9698
filename = filename_match.group(1)
97-
file_path = os.path.join(location, filename)
99+
file_path = location_path / filename
98100

99101
total_size = int(response.headers.get("Content-Length", 0))
100102
checksum = hashlib.md5() # create checksum
@@ -111,7 +113,7 @@ def download_file_with_progress_bar(url: str,
111113
downloaded_checksum = checksum.hexdigest() # Get the checksum value
112114
if downloaded_checksum != expected_checksum:
113115
warnings.warn(f"Checksum verification failed. Deleting '{file_path}'.")
114-
os.remove(file_path)
116+
file_path.unlink()
115117
warnings.warn("File deleted. Retrying download...")
116118

117119
# Retry download using a for loop

cebra/data/datasets.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,13 @@ class TensorDataset(cebra_data.SingleSessionDataset):
6767
6868
"""
6969

70-
def __init__(
71-
self,
72-
neural: Union[torch.Tensor, npt.NDArray],
73-
continuous: Union[torch.Tensor, npt.NDArray] = None,
74-
discrete: Union[torch.Tensor, npt.NDArray] = None,
75-
offset: int = 1,
76-
):
77-
super().__init__()
70+
def __init__(self,
71+
neural: Union[torch.Tensor, npt.NDArray],
72+
continuous: Union[torch.Tensor, npt.NDArray] = None,
73+
discrete: Union[torch.Tensor, npt.NDArray] = None,
74+
offset: int = 1,
75+
device: str = "cpu"):
76+
super().__init__(device=device)
7877
self.neural = self._to_tensor(neural, torch.FloatTensor).float()
7978
self.continuous = self._to_tensor(continuous, torch.FloatTensor)
8079
self.discrete = self._to_tensor(discrete, torch.LongTensor)
@@ -222,9 +221,9 @@ def __init__(
222221
else:
223222
self._cindex = None
224223
if discrete:
225-
raise NotImplementedError(
226-
"Multisession implementation does not support discrete index yet."
227-
)
224+
self._dindex = torch.cat(list(
225+
self._iter_property("discrete_index")),
226+
dim=0)
228227
else:
229228
self._dindex = None
230229

cebra/data/multi_session.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,16 @@ def index(self):
164164

165165
@dataclasses.dataclass
166166
class DiscreteMultiSessionDataLoader(MultiSessionLoader):
167-
pass
167+
"""Contrastive learning conditioned on a discrete behavior variable."""
168+
169+
# Overwrite sampler with the discrete implementation
170+
# Generalize MultisessionSampler to avoid doing this?
171+
def __post_init__(self):
172+
self.sampler = cebra_distr.DiscreteMultisessionSampler(self.dataset)
173+
174+
@property
175+
def index(self):
176+
return self.dataset.discrete_index
168177

169178

170179
@dataclasses.dataclass

cebra/datasets/allen/ca_movie_decoding.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
import glob
3333
import hashlib
34-
import os
3534
import pathlib
3635

3736
import h5py

cebra/datasets/allen/neuropixel_movie.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"""
2929
import glob
3030
import hashlib
31-
import os
3231
import pathlib
3332

3433
import h5py

cebra/datasets/allen/neuropixel_movie_decoding.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"""
2929
import glob
3030
import hashlib
31-
import os
3231
import pathlib
3332

3433
import h5py

0 commit comments

Comments
 (0)