@@ -132,6 +132,82 @@ def _get_info(trial_info, data):
132132 return data_dic
133133
134134
135+ monkey_reaching_urls = {
136+ "all_all.jl" : {
137+ "url" :
138+ "https://figshare.com/ndownloader/files/41668764?private_link=6fa4ee74a8f465ec7914" ,
139+ "checksum" :
140+ "dea556301fa4fafa86e28cf8621cab5a"
141+ },
142+ "all_train.jl" : {
143+ "url" :
144+ "https://figshare.com/ndownloader/files/41668752?private_link=6fa4ee74a8f465ec7914" ,
145+ "checksum" :
146+ "e280e4cd86969e6fd8bfd3a8f402b2fe"
147+ },
148+ "all_test.jl" : {
149+ "url" :
150+ "https://figshare.com/ndownloader/files/41668761?private_link=6fa4ee74a8f465ec7914" ,
151+ "checksum" :
152+ "25d3ff2c15014db8b8bf2543482ae881"
153+ },
154+ "all_valid.jl" : {
155+ "url" :
156+ "https://figshare.com/ndownloader/files/41668755?private_link=6fa4ee74a8f465ec7914" ,
157+ "checksum" :
158+ "8cd25169d31f83ae01b03f7b1b939723"
159+ },
160+ "active_all.jl" : {
161+ "url" :
162+ "https://figshare.com/ndownloader/files/41668776?private_link=6fa4ee74a8f465ec7914" ,
163+ "checksum" :
164+ "c626acea5062122f5a68ef18d3e45e51"
165+ },
166+ "active_train.jl" : {
167+ "url" :
168+ "https://figshare.com/ndownloader/files/41668770?private_link=6fa4ee74a8f465ec7914" ,
169+ "checksum" :
170+ "72a48056691078eee22c36c1992b1d37"
171+ },
172+ "active_test.jl" : {
173+ "url" :
174+ "https://figshare.com/ndownloader/files/41668773?private_link=6fa4ee74a8f465ec7914" ,
175+ "checksum" :
176+ "35b7e060008a8722c536584c4748f2ea"
177+ },
178+ "active_valid.jl" : {
179+ "url" :
180+ "https://figshare.com/ndownloader/files/41668767?private_link=6fa4ee74a8f465ec7914" ,
181+ "checksum" :
182+ "dd58eb1e589361b4132f34b22af56b79"
183+ },
184+ "passive_all.jl" : {
185+ "url" :
186+ "https://figshare.com/ndownloader/files/41668758?private_link=6fa4ee74a8f465ec7914" ,
187+ "checksum" :
188+ "bbb1bc9d8eec583a46f6673470fc98ad"
189+ },
190+ "passive_train.jl" : {
191+ "url" :
192+ "https://figshare.com/ndownloader/files/41668743?private_link=6fa4ee74a8f465ec7914" ,
193+ "checksum" :
194+ "f22e05a69f70e18ba823a0a89162a45c"
195+ },
196+ "passive_test.jl" : {
197+ "url" :
198+ "https://figshare.com/ndownloader/files/41668746?private_link=6fa4ee74a8f465ec7914" ,
199+ "checksum" :
200+ "42453ae3e4fd27d82d297f78c13cd6b7"
201+ },
202+ "passive_valid.jl" : {
203+ "url" :
204+ "https://figshare.com/ndownloader/files/41668749?private_link=6fa4ee74a8f465ec7914" ,
205+ "checksum" :
206+ "2dcc10c27631b95a075eaa2d2297bb4a"
207+ }
208+ }
209+
210+
135211@register ("area2-bump" )
136212class Area2BumpDataset (cebra .data .SingleSessionDataset ):
137213 """Base dataclass to generate monkey reaching datasets.
@@ -151,19 +227,30 @@ class Area2BumpDataset(cebra.data.SingleSessionDataset):
151227
152228 """
153229
154- def __init__ (
155- self ,
156- path : str = get_datapath ("monkey_reaching_preload_smth_40/" ),
157- session : str = "active" ,
158- ):
230+ def __init__ (self ,
231+ path : str = get_datapath ("monkey_reaching_preload_smth_40/" ),
232+ session : str = "active" ,
233+ download = True ):
159234 super ().__init__ ()
160235 self .path = path
236+ self .download = download
161237 self .session = session
162238 if session == "active-passive" :
163239 self .load_session = "all"
164240 else :
165241 self .load_session = session
166- self .data = jl .load (os .path .join (path , f"{ self .load_session } _all.jl" ))
242+
243+ super ().__init__ (
244+ download = self .download ,
245+ data_url = monkey_reaching_urls [f"{ self .load_session } _all.jl" ]["url" ],
246+ data_checksum = monkey_reaching_urls [f"{ self .load_session } _all.jl" ]
247+ ["checksum" ],
248+ location = self .path ,
249+ file_name = f"{ self .load_session } _all.jl" ,
250+ )
251+
252+ self .data = jl .load (
253+ os .path .join (self .path , f"{ self .load_session } _all.jl" ))
167254 self ._post_load ()
168255
169256 def split (self , split ):
@@ -179,6 +266,15 @@ def split(self, split):
179266
180267 """
181268
269+ super ().__init__ (
270+ download = self .download ,
271+ data_url = monkey_reaching_urls [f"{ self .load_session } _{ split } .jl" ]
272+ ["url" ],
273+ data_checksum = monkey_reaching_urls [
274+ f"{ self .load_session } _{ split } .jl" ]["checksum" ],
275+ location = self .path ,
276+ file_name = f"{ self .load_session } _{ split } .jl" ,
277+ )
182278 self .data = jl .load (
183279 os .path .join (self .path , f"{ self .load_session } _{ split } .jl" ))
184280 self ._post_load ()
0 commit comments