|
1 | 1 | """Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment.""" |
2 | | -from dataclasses import dataclass |
| 2 | +from dataclasses import dataclass, field |
3 | 3 | import logging |
4 | 4 | import os |
5 | 5 | from pathlib import Path |
|
24 | 24 | from brainbox.core import TimeSeries |
25 | 25 | from brainbox.processing import sync |
26 | 26 | from brainbox.metrics.single_units import quick_unit_metrics |
| 27 | +from brainbox.behavior.wheel import interpolate_position, velocity_smoothed |
| 28 | +from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter |
27 | 29 |
|
28 | 30 | _logger = logging.getLogger('ibllib') |
29 | 31 |
|
@@ -1056,3 +1058,334 @@ def samples2times(self, values, direction='forward'): |
1056 | 1058 | 'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'), |
1057 | 1059 | } |
1058 | 1060 | return self._sync[direction](values) |
| 1061 | + |
| 1062 | + |
@dataclass
class SessionLoader:
    """
    Object to load session data for a given session in the recommended way.

    Parameters
    ----------
    one: one.api.ONE instance
        Can be in remote or local mode (required)
    session_path: string or pathlib.Path
        The absolute path to the session (one of session_path or eid is required)
    eid: string
        database UUID of the session (one of session_path or eid is required)

    If both are provided, session_path takes precedence over eid.

    Examples
    --------
    1) Load all available session data for one session:
        >>> from one.api import ONE
        >>> from brainbox.io.one import SessionLoader
        >>> one = ONE()
        >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/')
        # Object is initiated, but no data is loaded as you can see in the data_info attribute
        >>> sess_loader.data_info
                    name  is_loaded
        0         trials      False
        1          wheel      False
        2           pose      False
        3  motion_energy      False
        4          pupil      False

        # Loading all available session data, the data_info attribute now shows which data has been loaded
        >>> sess_loader.load_session_data()
        >>> sess_loader.data_info
                    name  is_loaded
        0         trials       True
        1          wheel       True
        2           pose       True
        3  motion_energy       True
        4          pupil      False

        # The data is loaded in pandas dataframes that you can access via the respective attributes, e.g.
        >>> type(sess_loader.trials)
        pandas.core.frame.DataFrame
        >>> sess_loader.trials.shape
        (626, 18)
        # Each data comes with its own timestamps in a column called 'times'
        >>> sess_loader.wheel['times']
        0      0.134286
        1      0.135286
        2      0.136286
        3      0.137286
        4      0.138286
        ...
        # For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera.
        # The dataframes of all cameras are collected in a dictionary
        >>> type(sess_loader.pose)
        dict
        >>> sess_loader.pose.keys()
        dict_keys(['leftCamera', 'rightCamera', 'bodyCamera'])
        >>> sess_loader.pose['bodyCamera'].columns
        Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object')
        # In order to control the loading of specific data by e.g. specifying parameters, use the individual loading
        functions:
        >>> sess_loader.load_wheel(sampling_rate=100)
    """
    # String annotation: the name does not have to be resolved when the class object is created
    one: 'One' = None
    session_path: Path = ''
    eid: str = ''
    # Table with one row per data type ('name') and a flag telling whether it is currently loaded ('is_loaded')
    data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
    trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
    wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
    # Camera data: dictionaries mapping '<view>Camera' to one dataframe per camera
    pose: dict = field(default_factory=dict, repr=False)
    motion_energy: dict = field(default_factory=dict, repr=False)
    pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)

    def __post_init__(self):
        """
        Function that runs automatically after initiation of the dataclass attributes.
        Checks for required inputs, sets session_path and eid, creates data_info table.

        Raises
        ------
        ValueError
            If `one` is missing, or if neither session_path nor eid is given.
        """
        if self.one is None:
            raise ValueError("An input to one is required. If not connection to a database is desired, it can be "
                             "a fully local instance of One.")
        has_session_path = self.session_path is not None and self.session_path != ''
        has_eid = self.eid is not None and self.eid != ''
        if has_session_path:
            # If session path is given, it takes precedence over eid
            self.eid = self.one.to_eid(self.session_path)
            self.session_path = Path(self.session_path)
        elif has_eid:
            # Providing no session path, infer it from the eid
            self.session_path = self.one.eid2path(self.eid)
        else:
            raise ValueError("If no session path is given, eid is required.")

        data_names = ['trials', 'wheel', 'pose', 'motion_energy', 'pupil']
        self.data_info = pd.DataFrame({'name': data_names, 'is_loaded': [False] * len(data_names)})

    def _is_loaded(self, name):
        """Return True if the data_info table flags the data type `name` as loaded."""
        return bool(self.data_info.loc[self.data_info['name'] == name, 'is_loaded'].any())

    def _set_loaded(self, name, is_loaded=True):
        """Set the is_loaded flag of the data type `name` in the data_info table."""
        self.data_info.loc[self.data_info['name'] == name, 'is_loaded'] = bool(is_loaded)

    def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False):
        """
        Function to load available session data into the SessionLoader object. Input parameters allow to control which
        data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input
        parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored
        in SessionLoader.data_info

        A failure to load one data type is logged and does not prevent the remaining data types from being loaded.
        The is_loaded flags in data_info are set by the individual loading functions, so that a data type for which
        nothing could be loaded (e.g. no camera available) is not reported as loaded.

        Parameters
        ----------
        trials: boolean
            Whether to load all trials data into SessionLoader.trials, default is True
        wheel: boolean
            Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True
        pose: boolean
            Whether to load pose tracking results (DLC) for each available camera into SessionLoader.pose,
            default is True
        motion_energy: boolean
            Whether to load motion energy data (whisker pad for left/right camera, body for body camera)
            into SessionLoader.motion_energy, default is True
        pupil: boolean
            Whether to load pupil diameter (raw and smooth) for the left/right camera into SessionLoader.pupil,
            default is True
        reload: boolean
            Whether to reload data that has already been loaded into this SessionLoader object, default is False
        """
        requests = (
            ('trials', trials, self.load_trials),
            ('wheel', wheel, self.load_wheel),
            ('pose', pose, self.load_pose),
            ('motion_energy', motion_energy, self.load_motion_energy),
            ('pupil', pupil, self.load_pupil),
        )
        for name, to_load, load_func in requests:
            if not to_load:
                _logger.debug(f"Not loading {name} data, set to False.")
                continue
            if self._is_loaded(name) and not reload:
                _logger.debug(f"Not loading {name} data, is already loaded and reload=False.")
                continue
            try:
                _logger.info(f"Loading {name} data")
                load_func()
            # Exception (not BaseException), so that KeyboardInterrupt / SystemExit are not swallowed
            except Exception as e:
                _logger.warning(f"Could not load {name} data.")
                _logger.debug(e)

    def load_trials(self):
        """
        Function to load trials data into SessionLoader.trials
        """
        self.trials = self.one.load_object(self.eid, 'trials', collection='alf').to_df()
        self._set_loaded('trials')

    def load_wheel(self, sampling_rate=1000, smooth_size=0.03):
        """
        Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position
        is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which
        Gaussian smoothing is applied. All columns are stored as float32.

        Parameters
        ----------
        sampling_rate: float
            Rate at which to sample the wheel position, default is 1000 Hz
        smooth_size: float
            Size of Gaussian smoothing window in seconds, default is 0.03

        Raises
        ------
        ValueError
            If wheel position and wheel timestamps do not have the same length.
        """
        wheel_raw = self.one.load_object(self.eid, 'wheel')
        # TODO: Fix this instead of raising error?
        if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]:
            raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps")
        # resample the wheel position and compute velocity, acceleration
        position, times = interpolate_position(wheel_raw['timestamps'], wheel_raw['position'], freq=sampling_rate)
        velocity, acceleration = velocity_smoothed(position, freq=sampling_rate, smooth_size=smooth_size)
        # The attribute is only replaced once everything has been computed, so a failure keeps earlier data intact
        self.wheel = pd.DataFrame({
            'times': times,
            'position': position,
            'velocity': velocity,
            'acceleration': acceleration,
        }).astype(np.float32)
        self._set_loaded('wheel')

    def load_pose(self, likelihood_thr=0.9, views=('left', 'right', 'body')):
        """
        Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a
        dictionary where keys are the names of the cameras for which pose data is loaded, and values are pandas
        Dataframes with a 'times' column followed by the pose data of the body parts tracked for that camera.
        Cameras for which loading fails are skipped with a warning. The 'pose' entry in data_info is flagged as
        loaded only if data for at least one camera was loaded.

        Parameters
        ----------
        likelihood_thr: float
            The position of each tracked body part come with a likelihood of that estimate for each time point.
            Estimates for time points with likelihood < likelihood_thr are set to NaN. To skip thresholding set
            likelihood_thr=1. Default is 0.9
        views: list
            List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
        """
        # start from an empty dictionary so that if one loads only one view, after having loaded several,
        # the others don't linger
        loaded = {}
        for view in views:
            try:
                pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'])
                # Double check if video timestamps are correct length or can be fixed
                times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc'])
                camera_pose = likelihood_threshold(dlc, likelihood_thr)
                camera_pose.insert(0, 'times', times_fixed)
            except Exception as e:
                _logger.warning(f'Could not load pose data for {view}Camera. Skipping camera.')
                _logger.debug(e)
            else:
                # only complete dataframes (including timestamps) are kept
                loaded[f'{view}Camera'] = camera_pose
        self.pose = loaded
        self._set_loaded('pose', bool(loaded))

    def load_motion_energy(self, views=('left', 'right', 'body')):
        """
        Function to load the motion energy data into SessionLoader.motion_energy. SessionLoader.motion_energy is a
        dictionary where keys are the names of the cameras for which motion energy data is loaded, and values are
        pandas Dataframes with the timestamps and motion energy data.
        The motion energy for the left and right camera is calculated for a square roughly covering the whisker pad
        (whiskerMotionEnergy). The motion energy for the body camera is calculated for a square covering much of the
        body (bodyMotionEnergy).
        Cameras for which loading fails are skipped with a warning. The 'motion_energy' entry in data_info is
        flagged as loaded only if data for at least one camera was loaded.

        Parameters
        ----------
        views: list
            List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
        """
        names = {'left': 'whiskerMotionEnergy',
                 'right': 'whiskerMotionEnergy',
                 'body': 'bodyMotionEnergy'}
        # start from an empty dictionary so that if one loads only one view, after having loaded several,
        # the others don't linger
        loaded = {}
        for view in views:
            try:
                me_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times'])
                # Double check if video timestamps are correct length or can be fixed
                times_fixed, motion_energy = self._check_video_timestamps(
                    view, me_raw['times'], me_raw['ROIMotionEnergy'])
                camera_me = pd.DataFrame(columns=[names[view]], data=motion_energy)
                camera_me.insert(0, 'times', times_fixed)
            except Exception as e:
                _logger.warning(f'Could not load motion energy data for {view}Camera. Skipping camera.')
                _logger.debug(e)
            else:
                loaded[f'{view}Camera'] = camera_me
        self.motion_energy = loaded
        self._set_loaded('motion_energy', bool(loaded))

    def load_licks(self):
        """
        Not yet implemented
        """
        pass

    def _load_left_pose_for_pupil(self):
        """
        Helper function returning the left camera pose data thresholded at likelihood 0.9, as needed to compute
        the pupil diameter. Pose data that was loaded before (all cameras, possibly with a different threshold)
        and its is_loaded flag are put back in place afterwards, also if loading fails. If no pose data was loaded
        before, the freshly loaded left camera pose stays in SessionLoader.pose.

        Raises
        ------
        KeyError
            If the left camera pose data could not be loaded.
        """
        previous_pose = self.pose
        previous_flag = self._is_loaded('pose')
        try:
            # load_pose replaces self.pose with a new dictionary, previous_pose itself is not modified
            self.load_pose(views=['left'], likelihood_thr=0.9)
            return self.pose['leftCamera'].copy()
        finally:
            if previous_pose:
                self.pose = previous_pose
                self._set_loaded('pose', previous_flag)

    def load_pupil(self, snr_thresh=5.):
        """
        Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil.
        If the precomputed camera features are not available, the pupil diameter is computed on the fly from the
        left camera pose data. The 'pupil' entry in data_info is flagged as loaded unless the data was discarded.

        Parameters
        ----------
        snr_thresh: float
            An SNR is calculated from the raw and smoothed pupil diameter. If this snr < snr_thresh the data
            will be considered unusable and will be discarded.
        """
        # Try to load from features
        feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features'])
        if 'features' in feat_raw.keys():
            times_fixed, feats = self._check_video_timestamps('left', feat_raw['times'], feat_raw['features'])
            pupil = feats.copy()
            pupil.insert(0, 'times', times_fixed)
        # If unavailable compute on the fly
        else:
            _logger.info('Pupil diameter not available, trying to compute on the fly.')
            dlc_thr = self._load_left_pose_for_pupil()
            # Build a fresh dataframe so that no previously loaded pupil data lingers
            # NOTE(review): assumes get_pupil_diameter returns one value per row of dlc_thr - confirm
            pupil = pd.DataFrame({'times': dlc_thr['times']})
            pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr)
            try:
                pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(pupil['pupilDiameter_raw'], 'left')
            except Exception as e:
                _logger.error("Computing smooth pupil diameter failed, saving all NaNs.")
                _logger.debug(e)
                pupil['pupilDiameter_smooth'] = np.nan

        smooth = pupil['pupilDiameter_smooth'].to_numpy(dtype=float)
        raw = pupil['pupilDiameter_raw'].to_numpy(dtype=float)
        if not np.all(np.isnan(smooth)):
            # SNR: variance of the smoothed trace over variance of the residual, on samples valid in both traces
            good = ~np.isnan(smooth) & ~np.isnan(raw)
            snr = np.var(smooth[good]) / np.var(smooth[good] - raw[good])
            if snr < snr_thresh:
                pupil = pd.DataFrame()
                _logger.error(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.')

        self.pupil = pupil
        self._set_loaded('pupil', not pupil.empty)

    def _check_video_timestamps(self, view, video_timestamps, video_data):
        """
        Helper function to check for the length of the video frames vs video timestamps and fix in case
        timestamps are longer than video frames.

        Parameters
        ----------
        view: string
            Name of the camera view, only used in the messages
        video_timestamps: array-like with a shape attribute
            Timestamps of the camera
        video_data: array-like with a shape attribute
            Data with one entry per video frame along the first dimension

        Returns
        -------
        The timestamps, shortened to the length of video_data if necessary, and the unchanged video_data.

        Raises
        ------
        ValueError
            If there are fewer timestamps than video frames.
        """
        n_times = video_timestamps.shape[0]
        n_frames = video_data.shape[0]
        # If camera times are shorter than video data, or empty, no current fix
        if n_times < n_frames:
            if n_times == 0:
                msg = f'Camera times empty for {view}Camera.'
            else:
                msg = f'Camera times are shorter than video data for {view}Camera.'
            _logger.warning(msg)
            raise ValueError(msg)
        # For pre-GPIO sessions, it is possible that the camera times are longer than the actual video.
        # This is because the first few frames are sometimes not recorded. We can remove the first few
        # timestamps in this case
        if n_times > n_frames:
            # Explicit start index: slicing with [-n_frames:] would return everything when n_frames is 0
            return video_timestamps[n_times - n_frames:], video_data
        return video_timestamps, video_data
0 commit comments