Skip to content

Commit 9e14790

Browse files
authored
Add OS independent paths for monkey, hippocampus and synthetic data (#169)
1 parent 65d10a2 commit 9e14790

13 files changed

+40
-33
lines changed

cebra/data/assets.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
#
2222

2323
import hashlib
24-
import os
2524
import re
2625
import warnings
26+
from pathlib import Path
2727
from typing import Optional
2828

2929
import requests
@@ -57,8 +57,10 @@ def download_file_with_progress_bar(url: str,
5757
"""
5858

5959
# Check if the file already exists in the location
60-
file_path = os.path.join(location, file_name)
61-
if os.path.exists(file_path):
60+
location_path = Path(location)
61+
file_path = location_path / file_name
62+
63+
if file_path.exists():
6264
existing_checksum = calculate_checksum(file_path)
6365
if existing_checksum == expected_checksum:
6466
return file_path
@@ -91,10 +93,10 @@ def download_file_with_progress_bar(url: str,
9193
)
9294

9395
# Create the directory and any necessary parent directories
94-
os.makedirs(location, exist_ok=True)
96+
location_path.mkdir(exist_ok=True)
9597

9698
filename = filename_match.group(1)
97-
file_path = os.path.join(location, filename)
99+
file_path = location_path / filename
98100

99101
total_size = int(response.headers.get("Content-Length", 0))
100102
checksum = hashlib.md5() # create checksum
@@ -111,7 +113,7 @@ def download_file_with_progress_bar(url: str,
111113
downloaded_checksum = checksum.hexdigest() # Get the checksum value
112114
if downloaded_checksum != expected_checksum:
113115
warnings.warn(f"Checksum verification failed. Deleting '{file_path}'.")
114-
os.remove(file_path)
116+
file_path.unlink()
115117
warnings.warn("File deleted. Retrying download...")
116118

117119
# Retry download using a for loop

cebra/data/datasets.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,12 @@ class TensorDataset(cebra_data.SingleSessionDataset):
6767
6868
"""
6969

70-
def __init__(
71-
self,
72-
neural: Union[torch.Tensor, npt.NDArray],
73-
continuous: Union[torch.Tensor, npt.NDArray] = None,
74-
discrete: Union[torch.Tensor, npt.NDArray] = None,
75-
offset: int = 1,
76-
device: str = "cpu"
77-
):
70+
def __init__(self,
71+
neural: Union[torch.Tensor, npt.NDArray],
72+
continuous: Union[torch.Tensor, npt.NDArray] = None,
73+
discrete: Union[torch.Tensor, npt.NDArray] = None,
74+
offset: int = 1,
75+
device: str = "cpu"):
7876
super().__init__(device=device)
7977
self.neural = self._to_tensor(neural, torch.FloatTensor).float()
8078
self.continuous = self._to_tensor(continuous, torch.FloatTensor)

cebra/datasets/allen/ca_movie_decoding.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
import glob
3333
import hashlib
34-
import os
3534
import pathlib
3635

3736
import h5py

cebra/datasets/allen/neuropixel_movie.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"""
2929
import glob
3030
import hashlib
31-
import os
3231
import pathlib
3332

3433
import h5py

cebra/datasets/allen/neuropixel_movie_decoding.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"""
2929
import glob
3030
import hashlib
31-
import os
3231
import pathlib
3332

3433
import h5py

cebra/datasets/gaussian_mixture.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
# See the License for the specific language governing permissions and
2020
# limitations under the License.
2121
#
22+
import pathlib
2223
from typing import Tuple
2324

2425
import joblib as jl
@@ -51,7 +52,8 @@ def __init__(self, noise: str = "poisson"):
5152
super().__init__()
5253
self.noise = noise
5354
data = jl.load(
54-
get_datapath(f"synthetic/continuous_label_{self.noise}.jl"))
55+
pathlib.Path(_DEFAULT_DATADIR) / "synthetic" /
56+
f"continuous_label_{self.noise}.jl")
5557
self.latent = data["z"]
5658
self.index = torch.from_numpy(data["u"]).float()
5759
self.neural = torch.from_numpy(data["x"]).float()

cebra/datasets/generate_synthetic_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
Adapted from pi-VAE: https://github.com/zhd96/pi-vae/blob/main/code/pi_vae.py
2626
"""
2727
import argparse
28-
import os
28+
import pathlib
2929
import sys
3030

3131
import joblib as jl
@@ -245,5 +245,5 @@ def refractory_poisson(x):
245245
"lam": lam_true,
246246
"x": x
247247
},
248-
os.path.join(args.save_path, f"continuous_label_{args.noise}.jl"),
248+
pathlib.Path(args.save_path) / f"continuous_label_{args.noise}.jl",
249249
)

cebra/datasets/hippocampus.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"""
3333

3434
import hashlib
35-
import os
35+
import pathlib
3636

3737
import joblib
3838
import numpy as np
@@ -94,8 +94,8 @@ class SingleRatDataset(cebra.data.SingleSessionDataset):
9494
"""
9595

9696
def __init__(self, name="achilles", root=_DEFAULT_DATADIR, download=True):
97-
location = os.path.join(root, "rat_hippocampus")
98-
file_path = os.path.join(location, f"{name}.jl")
97+
location = pathlib.Path(root) / "rat_hippocampus"
98+
file_path = location / f"{name}.jl"
9999

100100
super().__init__(download=download,
101101
data_url=rat_dataset_urls[name]["url"],

cebra/datasets/monkey_reaching.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@
2929
"""
3030

3131
import hashlib
32-
import os
32+
import pathlib
3333
import pickle as pk
34+
from typing import Union
3435

3536
import joblib as jl
3637
import numpy as np
@@ -41,10 +42,11 @@
4142
from cebra.datasets import get_datapath
4243
from cebra.datasets import register
4344

45+
_DEFAULT_DATADIR = get_datapath()
46+
4447

4548
def _load_data(
46-
path: str = get_datapath(
47-
"s1_reaching/sub-Han_desc-train_behavior+ecephys.nwb"),
49+
path: Union[str, pathlib.Path] = None,
4850
session: str = "active",
4951
split: str = "train",
5052
):
@@ -61,6 +63,13 @@ def _load_data(
6163
6264
"""
6365

66+
if path is None:
67+
path = pathlib.Path(
68+
_DEFAULT_DATADIR
69+
) / "s1_reaching" / "sub-Han_desc-train_behavior+ecephys.nwb"
70+
else:
71+
path = pathlib.Path(path)
72+
6473
try:
6574
from nlb_tools.nwb_interface import NWBDataset
6675
except ImportError as e:
@@ -259,7 +268,7 @@ def __init__(self,
259268
)
260269

261270
self.data = jl.load(
262-
os.path.join(self.path, f"{self.load_session}_all.jl"))
271+
pathlib.Path(self.path) / f"{self.load_session}_all.jl")
263272
self._post_load()
264273

265274
def split(self, split):
@@ -285,7 +294,7 @@ def split(self, split):
285294
file_name=f"{self.load_session}_{split}.jl",
286295
)
287296
self.data = jl.load(
288-
os.path.join(self.path, f"{self.load_session}_{split}.jl"))
297+
pathlib.Path(self.path) / f"{self.load_session}_{split}.jl")
289298
self._post_load()
290299

291300
def _post_load(self):
@@ -407,7 +416,7 @@ def _create_area2_dataset():
407416
408417
"""
409418

410-
PATH = get_datapath("monkey_reaching_preload_smth_40")
419+
PATH = pathlib.Path(_DEFAULT_DATADIR) / "monkey_reaching_preload_smth_40"
411420
for session_type in ["active", "passive", "active-passive", "all"]:
412421

413422
@register(f"area2-bump-pos-{session_type}")
@@ -506,7 +515,7 @@ def _create_area2_shuffled_dataset():
506515
507516
"""
508517

509-
PATH = get_datapath("monkey_reaching_preload_smth_40/")
518+
PATH = pathlib.Path(_DEFAULT_DATADIR) / "monkey_reaching_preload_smth_40"
510519
for session_type in ["active", "active-passive"]:
511520

512521
@register(f"area2-bump-pos-{session_type}-shuffled-trial")

cebra/solver/multi_session.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
"""Solver implementations for multi-session datasetes."""
2323

2424
import abc
25-
import os
2625
from collections.abc import Iterable
2726
from typing import List, Optional
2827

0 commit comments

Comments
 (0)