2929"""
3030
3131import hashlib
32- import os
32+ import pathlib
3333import pickle as pk
34+ from typing import Union
3435
3536import joblib as jl
3637import numpy as np
4142from cebra .datasets import get_datapath
4243from cebra .datasets import register
4344
45+ _DEFAULT_DATADIR = get_datapath ()
46+
4447
4548def _load_data (
46- path : str = get_datapath (
47- "s1_reaching/sub-Han_desc-train_behavior+ecephys.nwb" ),
49+ path : Union [str , pathlib .Path ] = None ,
4850 session : str = "active" ,
4951 split : str = "train" ,
5052):
@@ -61,6 +63,13 @@ def _load_data(
6163
6264 """
6365
66+ if path is None :
67+ path = pathlib .Path (
68+ _DEFAULT_DATADIR
69+ ) / "s1_reaching" / "sub-Han_desc-train_behavior+ecephys.nwb"
70+ else :
71+ path = pathlib .Path (path )
72+
6473 try :
6574 from nlb_tools .nwb_interface import NWBDataset
6675 except ImportError as e :
@@ -259,7 +268,7 @@ def __init__(self,
259268 )
260269
261270 self .data = jl .load (
262- os . path . join (self .path , f"{ self .load_session } _all.jl" ) )
271+ pathlib . Path (self .path ) / f"{ self .load_session } _all.jl" )
263272 self ._post_load ()
264273
265274 def split (self , split ):
@@ -285,7 +294,7 @@ def split(self, split):
285294 file_name = f"{ self .load_session } _{ split } .jl" ,
286295 )
287296 self .data = jl .load (
288- os . path . join (self .path , f"{ self .load_session } _{ split } .jl" ) )
297+ pathlib . Path (self .path ) / f"{ self .load_session } _{ split } .jl" )
289298 self ._post_load ()
290299
291300 def _post_load (self ):
@@ -407,7 +416,7 @@ def _create_area2_dataset():
407416
408417 """
409418
410- PATH = get_datapath ( "monkey_reaching_preload_smth_40" )
419+ PATH = pathlib . Path ( _DEFAULT_DATADIR ) / "monkey_reaching_preload_smth_40"
411420 for session_type in ["active" , "passive" , "active-passive" , "all" ]:
412421
413422 @register (f"area2-bump-pos-{ session_type } " )
@@ -506,7 +515,7 @@ def _create_area2_shuffled_dataset():
506515
507516 """
508517
509- PATH = get_datapath ( "monkey_reaching_preload_smth_40/" )
518+ PATH = pathlib . Path ( _DEFAULT_DATADIR ) / "monkey_reaching_preload_smth_40"
510519 for session_type in ["active" , "active-passive" ]:
511520
512521 @register (f"area2-bump-pos-{ session_type } -shuffled-trial" )
0 commit comments