Skip to content

Commit 977ca7f

Browse files
gonlairoMMathisLab
andauthored
Autodownload monkey reaching data (#59)
* add two datasets to monkey-reaching * add passive_all.jl * add rest of urls + train/test/valid split * fix typo * add urls to test --------- Co-authored-by: Mackenzie Mathis <[email protected]>
1 parent 16653e2 commit 977ca7f

File tree

2 files changed

+112
-7
lines changed

2 files changed

+112
-7
lines changed

cebra/datasets/monkey_reaching.py

Lines changed: 102 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,82 @@ def _get_info(trial_info, data):
132132
return data_dic
133133

134134

135+
monkey_reaching_urls = {
136+
"all_all.jl": {
137+
"url":
138+
"https://figshare.com/ndownloader/files/41668764?private_link=6fa4ee74a8f465ec7914",
139+
"checksum":
140+
"dea556301fa4fafa86e28cf8621cab5a"
141+
},
142+
"all_train.jl": {
143+
"url":
144+
"https://figshare.com/ndownloader/files/41668752?private_link=6fa4ee74a8f465ec7914",
145+
"checksum":
146+
"e280e4cd86969e6fd8bfd3a8f402b2fe"
147+
},
148+
"all_test.jl": {
149+
"url":
150+
"https://figshare.com/ndownloader/files/41668761?private_link=6fa4ee74a8f465ec7914",
151+
"checksum":
152+
"25d3ff2c15014db8b8bf2543482ae881"
153+
},
154+
"all_valid.jl": {
155+
"url":
156+
"https://figshare.com/ndownloader/files/41668755?private_link=6fa4ee74a8f465ec7914",
157+
"checksum":
158+
"8cd25169d31f83ae01b03f7b1b939723"
159+
},
160+
"active_all.jl": {
161+
"url":
162+
"https://figshare.com/ndownloader/files/41668776?private_link=6fa4ee74a8f465ec7914",
163+
"checksum":
164+
"c626acea5062122f5a68ef18d3e45e51"
165+
},
166+
"active_train.jl": {
167+
"url":
168+
"https://figshare.com/ndownloader/files/41668770?private_link=6fa4ee74a8f465ec7914",
169+
"checksum":
170+
"72a48056691078eee22c36c1992b1d37"
171+
},
172+
"active_test.jl": {
173+
"url":
174+
"https://figshare.com/ndownloader/files/41668773?private_link=6fa4ee74a8f465ec7914",
175+
"checksum":
176+
"35b7e060008a8722c536584c4748f2ea"
177+
},
178+
"active_valid.jl": {
179+
"url":
180+
"https://figshare.com/ndownloader/files/41668767?private_link=6fa4ee74a8f465ec7914",
181+
"checksum":
182+
"dd58eb1e589361b4132f34b22af56b79"
183+
},
184+
"passive_all.jl": {
185+
"url":
186+
"https://figshare.com/ndownloader/files/41668758?private_link=6fa4ee74a8f465ec7914",
187+
"checksum":
188+
"bbb1bc9d8eec583a46f6673470fc98ad"
189+
},
190+
"passive_train.jl": {
191+
"url":
192+
"https://figshare.com/ndownloader/files/41668743?private_link=6fa4ee74a8f465ec7914",
193+
"checksum":
194+
"f22e05a69f70e18ba823a0a89162a45c"
195+
},
196+
"passive_test.jl": {
197+
"url":
198+
"https://figshare.com/ndownloader/files/41668746?private_link=6fa4ee74a8f465ec7914",
199+
"checksum":
200+
"42453ae3e4fd27d82d297f78c13cd6b7"
201+
},
202+
"passive_valid.jl": {
203+
"url":
204+
"https://figshare.com/ndownloader/files/41668749?private_link=6fa4ee74a8f465ec7914",
205+
"checksum":
206+
"2dcc10c27631b95a075eaa2d2297bb4a"
207+
}
208+
}
209+
210+
135211
@register("area2-bump")
136212
class Area2BumpDataset(cebra.data.SingleSessionDataset):
137213
"""Base dataclass to generate monkey reaching datasets.
@@ -151,19 +227,30 @@ class Area2BumpDataset(cebra.data.SingleSessionDataset):
151227
152228
"""
153229

154-
def __init__(
155-
self,
156-
path: str = get_datapath("monkey_reaching_preload_smth_40/"),
157-
session: str = "active",
158-
):
230+
def __init__(self,
231+
path: str = get_datapath("monkey_reaching_preload_smth_40/"),
232+
session: str = "active",
233+
download=True):
159234
super().__init__()
160235
self.path = path
236+
self.download = download
161237
self.session = session
162238
if session == "active-passive":
163239
self.load_session = "all"
164240
else:
165241
self.load_session = session
166-
self.data = jl.load(os.path.join(path, f"{self.load_session}_all.jl"))
242+
243+
super().__init__(
244+
download=self.download,
245+
data_url=monkey_reaching_urls[f"{self.load_session}_all.jl"]["url"],
246+
data_checksum=monkey_reaching_urls[f"{self.load_session}_all.jl"]
247+
["checksum"],
248+
location=self.path,
249+
file_name=f"{self.load_session}_all.jl",
250+
)
251+
252+
self.data = jl.load(
253+
os.path.join(self.path, f"{self.load_session}_all.jl"))
167254
self._post_load()
168255

169256
def split(self, split):
@@ -179,6 +266,15 @@ def split(self, split):
179266
180267
"""
181268

269+
super().__init__(
270+
download=self.download,
271+
data_url=monkey_reaching_urls[f"{self.load_session}_{split}.jl"]
272+
["url"],
273+
data_checksum=monkey_reaching_urls[
274+
f"{self.load_session}_{split}.jl"]["checksum"],
275+
location=self.path,
276+
file_name=f"{self.load_session}_{split}.jl",
277+
)
182278
self.data = jl.load(
183279
os.path.join(self.path, f"{self.load_session}_{split}.jl"))
184280
self._post_load()

tests/test_datasets.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,16 @@ def parametrize_data(function):
299299
"a83b02dbdc884fdd7e53df362499d42f"),
300300
("gatsby.jl",
301301
"https://figshare.com/ndownloader/files/40849454?private_link=9f91576cbbcc8b0d8828",
302-
"2b889da48178b3155011c12555342813")
302+
"2b889da48178b3155011c12555342813"),
303+
("all_all.jl",
304+
"https://figshare.com/ndownloader/files/41668764?private_link=6fa4ee74a8f465ec7914",
305+
"dea556301fa4fafa86e28cf8621cab5a"),
306+
("active_all.jl",
307+
"https://figshare.com/ndownloader/files/41668776?private_link=6fa4ee74a8f465ec7914",
308+
"c626acea5062122f5a68ef18d3e45e51"),
309+
("passive_all.jl",
310+
"https://figshare.com/ndownloader/files/41668758?private_link=6fa4ee74a8f465ec7914",
311+
"bbb1bc9d8eec583a46f6673470fc98ad"),
303312
])(function)
304313

305314

0 commit comments

Comments
 (0)