Skip to content

Commit bb9d55e

Browse files
authored
Autodownload synthetic data (#68)
1 parent b9a3de3 commit bb9d55e

File tree

2 files changed

+123
-0
lines changed

2 files changed

+123
-0
lines changed

cebra/datasets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def get_datapath(path: str = None) -> str:
8686
from cebra.datasets.gaussian_mixture import *
8787
from cebra.datasets.hippocampus import *
8888
from cebra.datasets.monkey_reaching import *
89+
from cebra.datasets.synthetic_data import *
8990
except ModuleNotFoundError as e:
9091
import warnings
9192

cebra/datasets/synthetic_data.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#
2+
# (c) All rights reserved. ECOLE POLYTECHNIQUE FÉDÉRALE DE LAUSANNE,
3+
# Switzerland, Laboratory of Prof. Mackenzie W. Mathis (UPMWMATHIS) and
4+
# original authors: Steffen Schneider, Jin H Lee, Mackenzie W Mathis. 2023.
5+
#
6+
# Source code:
7+
# https://github.com/AdaptiveMotorControlLab/CEBRA
8+
#
9+
# Please see LICENSE.md for the full license document:
10+
# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md
11+
#
12+
13+
import os
14+
15+
import joblib
16+
17+
import cebra.data
18+
from cebra.datasets import get_datapath
19+
from cebra.datasets import parametrize
20+
21+
# Default dataset root, resolved once at import time by calling get_datapath()
# with no arguments; used below as the default for SyntheticData's ``root``.
_DEFAULT_DATADIR = get_datapath()
22+
23+
# Download metadata for the synthetic datasets.
#
# Each key is the base file name of a dataset (the loader appends ".jl") and
# maps to the remote "url" of the file and the "checksum" handed to the
# dataset base class alongside it.
# NOTE(review): the checksums are 32 hex digits, presumably MD5 -- confirm
# against the download helper that consumes them.
synthetic_data_urls = {
    "continuous_label_refractory_poisson": {
        "url": "https://figshare.com/ndownloader/files/41668815?private_link=7439c5302e99db36eebb",
        "checksum": "fcd92bd283c528d5294093190f55ceba",
    },
    "continuous_label_t": {
        "url": "https://figshare.com/ndownloader/files/41668818?private_link=7439c5302e99db36eebb",
        "checksum": "a6e76f274da571568fd2a4bf4cf48b66",
    },
    "continuous_label_uniform": {
        "url": "https://figshare.com/ndownloader/files/41668821?private_link=7439c5302e99db36eebb",
        "checksum": "e67400e77ac009e8c9bc958aa5151973",
    },
    "continuous_label_laplace": {
        "url": "https://figshare.com/ndownloader/files/41668824?private_link=7439c5302e99db36eebb",
        "checksum": "41d7ce4ce8901ae7a5136605ac3f5ffb",
    },
    "continuous_label_poisson": {
        "url": "https://figshare.com/ndownloader/files/41668827?private_link=7439c5302e99db36eebb",
        "checksum": "a789828f9cca5f3faf36d62ebc4cc8a1",
    },
    "continuous_label_gaussian": {
        "url": "https://figshare.com/ndownloader/files/41668830?private_link=7439c5302e99db36eebb",
        "checksum": "18d66a2020923e2cd67d2264d20890aa",
    },
    "continuous_poisson_gaussian_noise": {
        "url": "https://figshare.com/ndownloader/files/41668833?private_link=7439c5302e99db36eebb",
        "checksum": "1a51461820c24a5bcaddaff3991f0ebe",
    },
    "sim_100d_poisson_cont_label": {
        "url": "https://figshare.com/ndownloader/files/41668836?private_link=7439c5302e99db36eebb",
        "checksum": "306b9c646e7b76a52cfd828612d700cb",
    },
}
73+
74+
75+
@parametrize(
    "continuous-label-{name}",
    name=["t", "uniform", "laplace", "poisson", "gaussian"],
)
class SyntheticData(cebra.data.SingleSessionDataset):
    """Synthetic dataset read from a ``continuous_label_{name}.jl`` file.

    Variants are parametrized over the noise used during the generative
    process: t, uniform, laplace, poisson and gaussian.

    Args:
        name: Noise-type suffix. The file ``continuous_label_{name}.jl``
            inside ``<root>/synthetic`` is loaded.
        root: Directory containing the ``synthetic`` folder.
        download: Forwarded to the parent constructor together with the
            URL and checksum registered in ``synthetic_data_urls``.
    """

    def __init__(self, name, root=_DEFAULT_DATADIR, download=True):
        dataset_key = f"continuous_label_{name}"
        file_name = f"{dataset_key}.jl"
        location = os.path.join(root, "synthetic")
        file_path = os.path.join(location, file_name)

        # An unknown dataset key surfaces here as a KeyError.
        remote = synthetic_data_urls[dataset_key]

        super().__init__(
            download=download,
            data_url=remote["url"],
            data_checksum=remote["checksum"],
            location=location,
            file_name=file_name,
        )

        # NOTE(review): joblib.load unpickles the file, so it should only
        # ever be pointed at trusted data.
        loaded = joblib.load(file_path)
        # NOTE: the raw mapping stays available as ``self.data`` for
        # backwards compatibility with the synth notebook.
        self.data = loaded
        self.name = dataset_key
        self.neural = loaded['z']
        self.latents = loaded['x']
        self.u = loaded['u']
        self.lam = loaded['lam']

    @property
    def input_dimension(self):
        """Size of axis 1 of the ``z`` array."""
        # Assumes ``self.neural`` exposes a callable ``size`` (tensor-like
        # object) -- TODO confirm the type stored in the .jl file.
        n_inputs = self.neural.size(1)
        return n_inputs

    @property
    def continuous_index(self):
        """Continuous index of the dataset, read from ``self.index``."""
        # NOTE(review): ``index`` is not assigned anywhere in this class;
        # presumably it is provided elsewhere -- confirm.
        return self.index

    def __getitem__(self, index):
        """Return the entries of ``z`` selected by the expanded ``index``.

        Axes 2 and 1 of the selection are swapped before returning. The
        original note gives the layout as [ No.Samples x Neurons x 10 ]
        -- TODO confirm.
        """
        expanded = self.expand_index(index)
        selected = self.neural[expanded]
        return selected.transpose(2, 1)

    def __len__(self):
        """Length of the ``z`` array along its first axis."""
        return len(self.neural)

    def __repr__(self):
        return f"SyntheticData(name: {self.name}, shape: {self.neural.shape})"

0 commit comments

Comments
 (0)