@@ -1062,9 +1062,64 @@ def samples2times(self, values, direction='forward'):
10621062
10631063@dataclass
10641064class SessionLoader :
1065+ """
1066+ Object to load session data for a give session in the recommended way.
1067+
1068+ Parameters
1069+ ----------
1070+ one: one.api.ONE instance
1071+ Can be in remote or local mode (required)
1072+ session_path: string or pathlib.Path
1073+ The absolute path to the session (one of session_path or eid is required)
1074+ eid: string
1075+ database UUID of the session (one of session_path or eid is required)
1076+
1077+ If both are provided, session_path takes precedence over eid.
1078+
1079+ Examples
1080+ --------
1081+ 1) Load all available session data for one session:
1082+ >>> from one.api import ONE
1083+ >>> from brainbox.io.one import SessionLoader
1084+ >>> one = ONE()
1085+ >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/')
1086+ # Object is initiated, but no data is loaded as you can see in the data_info attribute
1087+ >>> sess_loader.data_info
1088+ name is_loaded
1089+ 0 trials False
1090+ 1 wheel False
1091+ 2 poses False
1092+ 3 motion_energy False
1093+ 4 pupil False
1094+
1095+ # Loading all available session data, the data_info attribute now shows which data has been loaded
1096+ >>> sess_loader.load_session_data()
1097+ >>> sess_loader.data_info
1098+ name is_loaded
1099+ 0 trials True
1100+ 1 wheel True
1101+ 2 poses True
1102+ 3 motion_energy True
1103+ 4 pupil False
1104+
1105+ # You can access the data via the respective attributes, e.g.
1106+ >>> sess_loader.trials.shape
1107+ (626, 18)
1108+ # Each data comes with its own timestamps in a column called 'times'
1109+ >>> sess_loader.pose['bodyCamera']['times']
1110+ 0 6.201239
1111+ 1 6.234569
1112+ 2 6.267899
1113+ 3 6.301229
1114+ 4 6.334592
1115+ ...
1116+ # In order to control the loading of specific data by e.g. specifying parameters, use the individual loading
1117+ functions:
1118+ >>> sess_loader.load_wheel(sampling_rate=100)
1119+ """
10651120 one : One = None
1066- eid : str = ''
10671121 session_path : Path = ''
1122+ eid : str = ''
10681123 data_info : pd .DataFrame = field (default_factory = pd .DataFrame , repr = False )
10691124 trials : pd .DataFrame = field (default_factory = pd .DataFrame , repr = False )
10701125 wheel : pd .DataFrame = field (default_factory = pd .DataFrame , repr = False )
@@ -1073,12 +1128,17 @@ class SessionLoader:
10731128 pupil : pd .DataFrame = field (default_factory = pd .DataFrame , repr = False )
10741129
10751130 def __post_init__ (self ):
1131+ """
1132+ Function that runs automatically after initiation of the dataclass attributes.
1133+ Checks for required inputs, sets session_path and eid, creates data_info table.
1134+ """
10761135 if self .one is None :
10771136 raise ValueError ("An input to one is required. If not connection to a database is desired, it can be "
10781137 "a fully local instance of One." )
10791138 # If session path is given, takes precedence over eid
10801139 if self .session_path is not None and self .session_path != '' :
10811140 self .eid = self .one .to_eid (self .session_path )
1141+ self .session_path = Path (self .session_path )
10821142 # Providing no session path, try to infer from eid
10831143 else :
10841144 if self .eid is not None and self .eid != '' :
@@ -1096,7 +1156,30 @@ def __post_init__(self):
10961156 self .data_info = pd .DataFrame (columns = ['name' , 'is_loaded' ], data = zip (data_names , [False ]* len (data_names )))
10971157
10981158 def load_session_data (self , trials = True , wheel = True , poses = True , motion_energy = True , pupil = True , reload = False ):
1099-
1159+ """
1160+ Function to load available session data into the SessionLoader object. Input parameters allow to control which
1161+ data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input
1162+ parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored
1163+ in SessionLoader.data_info
1164+
1165+ Parameters
1166+ ----------
1167+ trials: boolean
1168+ Whether to load all trials data into SessionLoader.trials, default is True
1169+ wheel: boolean
1170+ Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True
1171+ poses: boolean
1172+ Whether to load pose tracking results (DLC) for each available camera into SessionLoader.poses,
1173+ default is True
1174+ motion_energy: boolean
1175+ Whether to load motion energy data (whisker pad for left/right camera, body for body camera)
1176+ into SessionLoader.motion_energy, default is True
1177+ pupil: boolean
1178+ Whether to load pupil diameter (raw and smooth) for the left/right camera into SessionLoader.pupil,
1179+ default is True
1180+ reload: boolean
1181+ Whether to reload data that has already been loaded into this SessionLoader object, default is False
1182+ """
11001183 load_df = self .data_info .copy ()
11011184 load_df ['to_load' ] = [
11021185 trials ,
@@ -1128,10 +1211,25 @@ def load_session_data(self, trials=True, wheel=True, poses=True, motion_energy=T
11281211 _logger .debug (e )
11291212
11301213 def load_trials (self ):
1214+ """
1215+ Function to load trials data into SessionLoader.trials
1216+ """
11311217 self .trials = self .one .load_object (self .eid , 'trials' ).to_df ()
11321218 self .data_info .loc [self .data_info ['name' ] == 'trials' , 'is_loaded' ] = True
11331219
11341220 def load_wheel (self , sampling_rate = 1000 , smooth_size = 0.03 ):
1221+ """
1222+ Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position
1223+ is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which
1224+ smoothing is applied.
1225+
1226+ Parameters
1227+ ----------
1228+ sampling_rate: float
1229+ Rate at which to sample the wheel position
1230+ smooth_size: float
1231+ Kernel for smoothing the wheel data to compute velocity and acceleration
1232+ """
11351233 wheel_raw = self .one .load_object (self .eid , 'wheel' )
11361234 # TODO: Fix this instead of raising error?
11371235 if wheel_raw ['position' ].shape [0 ] != wheel_raw ['timestamps' ].shape [0 ]:
@@ -1145,6 +1243,17 @@ def load_wheel(self, sampling_rate=1000, smooth_size=0.03):
11451243 self .data_info .loc [self .data_info ['name' ] == 'wheel' , 'is_loaded' ] = True
11461244
11471245 def load_pose (self , likelihood_thr = 0.9 , views = ['left' , 'right' , 'body' ]):
1246+ """
1247+ Function to load the pose estimation results (DLC) into SessionLoader.poses
1248+ Parameters
1249+ ----------
1250+ likelihood_thr
1251+ views
1252+
1253+ Returns
1254+ -------
1255+
1256+ """
11481257 for view in views :
11491258 try :
11501259 pose_raw = self .one .load_object (self .eid , f'{ view } Camera' , attribute = ['dlc' , 'times' ])
@@ -1215,23 +1324,6 @@ def load_pupil(self, snr_thresh=5):
12151324 self .pupil = pd .DataFrame
12161325 raise ValueError (f'Pupil diameter SNR ({ snr :.2f} ) below threshold SNR ({ snr_thresh } ), removing data.' )
12171326
1218- def align_trials_to_event (self , align_event = 'stimOn_times' , pre_event = 0.5 , post_event = 0.5 ):
1219- possible_events = ['stimOn_times' , 'goCue_times' , 'goCueTrigger_times' ,
1220- 'response_times' , 'feedback_times' , 'firstMovement_times' ]
1221- if align_event not in possible_events :
1222- raise ValueError (f"Argument align_event must be on of { possible_events } " )
1223-
1224- if self .trials .shape == (0 , 0 ):
1225- _logger .info ("No trials data loaded. Trying to load trials data." )
1226- self .load_trials ()
1227-
1228- align_str = f"align_{ align_event .split ('_' )[0 ]} "
1229- self .trials [f'{ align_str } _start' ] = self .trials [align_event ] - pre_event
1230- self .trials [f'{ align_str } _end' ] = self .trials [align_event ] + post_event
1231- diffs = self .trials [f'{ align_str } _end' ] - np .roll (self .trials [f'{ align_str } _start' ], - 1 )
1232- if np .any (diffs [:- 1 ] > 0 ):
1233- _logger .warning (f'{ sum (diffs [:- 1 ] > 0 )} trials overlapping, try reducing pre_event, post_event or both!' )
1234-
12351327 def _check_video_timestamps (self , view , video_timestamps , video_data ):
12361328 # If camera times are shorter than video data, or empty, no current fix
12371329 if video_timestamps .shape [0 ] < video_data .shape [0 ]:
0 commit comments