Skip to content

Commit d05a60f

Browse files
authored
Merge pull request #503 from int-brain-lab/session_loader
Session loader
2 parents affbaf8 + d4e2c2d commit d05a60f

File tree

1 file changed

+334
-1
lines changed

1 file changed

+334
-1
lines changed

brainbox/io/one.py

Lines changed: 334 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment."""
2-
from dataclasses import dataclass
2+
from dataclasses import dataclass, field
33
import logging
44
import os
55
from pathlib import Path
@@ -24,6 +24,8 @@
2424
from brainbox.core import TimeSeries
2525
from brainbox.processing import sync
2626
from brainbox.metrics.single_units import quick_unit_metrics
27+
from brainbox.behavior.wheel import interpolate_position, velocity_smoothed
28+
from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter
2729

2830
_logger = logging.getLogger('ibllib')
2931

@@ -1056,3 +1058,334 @@ def samples2times(self, values, direction='forward'):
10561058
'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'),
10571059
}
10581060
return self._sync[direction](values)
1061+
1062+
1063+
@dataclass
class SessionLoader:
    """
    Object to load session data for a given session in the recommended way.

    Parameters
    ----------
    one: one.api.ONE instance
        Can be in remote or local mode (required)
    session_path: string or pathlib.Path
        The absolute path to the session (one of session_path or eid is required)
    eid: string
        database UUID of the session (one of session_path or eid is required)

    If both are provided, session_path takes precedence over eid.

    Examples
    --------
    1) Load all available session data for one session:
        >>> from one.api import ONE
        >>> from brainbox.io.one import SessionLoader
        >>> one = ONE()
        >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/')
        # Object is initiated, but no data is loaded as you can see in the data_info attribute
        >>> sess_loader.data_info
                    name  is_loaded
        0         trials      False
        1          wheel      False
        2           pose      False
        3  motion_energy      False
        4          pupil      False

        # Loading all available session data, the data_info attribute now shows which data has been loaded
        # (in this example no usable pupil data could be loaded)
        >>> sess_loader.load_session_data()
        >>> sess_loader.data_info
                    name  is_loaded
        0         trials       True
        1          wheel       True
        2           pose       True
        3  motion_energy       True
        4          pupil      False

        # The data is loaded in pandas dataframes that you can access via the respective attributes, e.g.
        >>> type(sess_loader.trials)
        pandas.core.frame.DataFrame
        >>> sess_loader.trials.shape
        (626, 18)
        # Each data comes with its own timestamps in a column called 'times'
        >>> sess_loader.wheel['times']
        0    0.134286
        1    0.135286
        2    0.136286
        3    0.137286
        4    0.138286
        ...
        # For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera.
        # The dataframes of all cameras are collected in a dictionary
        >>> type(sess_loader.pose)
        dict
        >>> sess_loader.pose.keys()
        dict_keys(['leftCamera', 'rightCamera', 'bodyCamera'])
        >>> sess_loader.pose['bodyCamera'].columns
        Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object')
        # In order to control the loading of specific data by e.g. specifying parameters, use the individual loading
        functions:
        >>> sess_loader.load_wheel(sampling_rate=100)
    """
    one: One = None
    session_path: Path = ''
    eid: str = ''
    data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
    trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
    wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
    pose: dict = field(default_factory=dict, repr=False)
    motion_energy: dict = field(default_factory=dict, repr=False)
    pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)

    def __post_init__(self):
        """
        Function that runs automatically after initiation of the dataclass attributes.
        Checks for required inputs, sets session_path and eid, creates data_info table.

        Raises
        ------
        ValueError
            If `one` is missing, or if neither session_path nor eid is given.
        """
        if self.one is None:
            raise ValueError("An input to one is required. If not connection to a database is desired, it can be "
                             "a fully local instance of One.")
        # If session path is given, takes precedence over eid
        if self.session_path:
            self.eid = self.one.to_eid(self.session_path)
            self.session_path = Path(self.session_path)
        # Providing no session path, try to infer from eid
        elif self.eid:
            self.session_path = self.one.eid2path(self.eid)
        else:
            raise ValueError("If no session path is given, eid is required.")

        data_names = [
            'trials',
            'wheel',
            'pose',
            'motion_energy',
            'pupil',
        ]
        self.data_info = pd.DataFrame({'name': data_names, 'is_loaded': [False] * len(data_names)})

    def _set_loaded(self, name, state=True):
        """Record in data_info whether the data type `name` currently holds loaded data."""
        self.data_info.loc[self.data_info['name'] == name, 'is_loaded'] = state

    def _is_loaded(self, name):
        """Return the data_info 'is_loaded' flag of the data type `name` as a plain bool."""
        return bool(self.data_info.loc[self.data_info['name'] == name, 'is_loaded'].values[0])

    def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False):
        """
        Function to load available session data into the SessionLoader object. Input parameters allow to control which
        data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input
        parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored
        in SessionLoader.data_info

        A failure to load one data type is logged and does not prevent the remaining data types from being loaded.

        Parameters
        ----------
        trials: boolean
            Whether to load all trials data into SessionLoader.trials, default is True
        wheel: boolean
            Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True
        pose: boolean
            Whether to load pose tracking results (DLC) for each available camera into SessionLoader.pose,
            default is True
        motion_energy: boolean
            Whether to load motion energy data (whisker pad for left/right camera, body for body camera)
            into SessionLoader.motion_energy, default is True
        pupil: boolean
            Whether to load pupil diameter (raw and smooth) for the left camera into SessionLoader.pupil,
            default is True
        reload: boolean
            Whether to reload data that has already been loaded into this SessionLoader object, default is False
        """
        # Insertion order matters: it is the order in which the data types are loaded
        requests = {
            'trials': (trials, self.load_trials),
            'wheel': (wheel, self.load_wheel),
            'pose': (pose, self.load_pose),
            'motion_energy': (motion_energy, self.load_motion_energy),
            'pupil': (pupil, self.load_pupil),
        }

        for name, (to_load, load_func) in requests.items():
            if not to_load:
                _logger.debug(f"Not loading {name} data, set to False.")
                continue
            if self._is_loaded(name) and not reload:
                _logger.debug(f"Not loading {name} data, is already loaded and reload=False.")
                continue
            try:
                _logger.info(f"Loading {name} data")
                # Each loader maintains its own data_info flag, so that the flag reflects whether data
                # is really present (e.g. it stays False if no camera could be loaded)
                load_func()
            except Exception as e:
                # Best effort: keep going with the other data types. KeyboardInterrupt / SystemExit propagate.
                _logger.warning(f"Could not load {name} data.")
                _logger.debug(e)

    def load_trials(self):
        """
        Function to load trials data into SessionLoader.trials
        """
        self.trials = self.one.load_object(self.eid, 'trials', collection='alf').to_df()
        self._set_loaded('trials')

    def load_wheel(self, sampling_rate=1000, smooth_size=0.03):
        """
        Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position
        is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which
        Gaussian smoothing is applied.

        Parameters
        ----------
        sampling_rate: float
            Rate at which to sample the wheel position, default is 1000 Hz
        smooth_size: float
            Size of Gaussian smoothing window in seconds, default is 0.03

        Raises
        ------
        ValueError
            If the raw wheel position and timestamps do not have the same length.
        """
        wheel_raw = self.one.load_object(self.eid, 'wheel')
        # TODO: Fix this instead of raising error?
        if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]:
            raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps")
        # resample the wheel position and compute velocity, acceleration
        self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration'])
        self.wheel['position'], self.wheel['times'] = interpolate_position(
            wheel_raw['timestamps'], wheel_raw['position'], freq=sampling_rate)
        self.wheel['velocity'], self.wheel['acceleration'] = velocity_smoothed(
            self.wheel['position'], freq=sampling_rate, smooth_size=smooth_size)
        # All columns, including 'times', are stored as float32
        self.wheel = self.wheel.apply(np.float32)
        self._set_loaded('wheel')

    def _load_pose_view(self, view, likelihood_thr):
        """
        Helper function to load the pose estimation results (DLC) of a single camera without modifying
        SessionLoader.pose.

        Parameters
        ----------
        view: str
            Camera view to load, the object loaded is f'{view}Camera'
        likelihood_thr: float
            Threshold passed on to likelihood_threshold

        Returns
        -------
        pandas.DataFrame
            Thresholded pose data with the (length-checked) timestamps inserted as first column 'times'
        """
        pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'])
        # Double check if video timestamps are correct length or can be fixed
        times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc'])
        pose = likelihood_threshold(dlc, likelihood_thr)
        pose.insert(0, 'times', times_fixed)
        return pose

    def load_pose(self, likelihood_thr=0.9, views=('left', 'right', 'body')):
        """
        Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a
        dictionary where keys are the names of the cameras for which pose data is loaded, and values are pandas
        Dataframes with the timestamps ('times') and the pose data of the body parts tracked for that camera.

        Cameras that cannot be loaded are skipped with a warning. The 'pose' entry of data_info is True only if
        data for at least one camera was loaded.

        Parameters
        ----------
        likelihood_thr: float
            The position of each tracked body part come with a likelihood of that estimate for each time point.
            Estimates for time points with likelihood < likelihood_thr are set to NaN. Default is 0.9
            NOTE(review): with a strict '<' comparison a threshold of 0 leaves all estimates untouched; earlier
            documentation suggested likelihood_thr=1 to skip thresholding - confirm against likelihood_threshold.
        views: list
            List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
        """
        # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
        self.pose = {}
        for view in views:
            try:
                self.pose[f'{view}Camera'] = self._load_pose_view(view, likelihood_thr)
            except Exception as e:
                _logger.warning(f'Could not load pose data for {view}Camera. Skipping camera.')
                _logger.debug(e)
        self._set_loaded('pose', bool(self.pose))

    def load_motion_energy(self, views=('left', 'right', 'body')):
        """
        Function to load the motion energy data into SessionLoader.motion_energy. SessionLoader.motion_energy is a
        dictionary where keys are the names of the cameras for which motion energy data is loaded, and values are
        pandas Dataframes with the timestamps and motion energy data.
        The motion energy for the left and right camera is calculated for a square roughly covering the whisker pad
        (whiskerMotionEnergy). The motion energy for the body camera is calculated for a square covering much of the
        body (bodyMotionEnergy).

        Cameras that cannot be loaded are skipped with a warning. The 'motion_energy' entry of data_info is True
        only if data for at least one camera was loaded.

        Parameters
        ----------
        views: list
            List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
        """
        names = {'left': 'whiskerMotionEnergy',
                 'right': 'whiskerMotionEnergy',
                 'body': 'bodyMotionEnergy'}
        # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
        self.motion_energy = {}
        for view in views:
            try:
                me_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times'])
                # Double check if video timestamps are correct length or can be fixed
                times_fixed, motion_energy = self._check_video_timestamps(
                    view, me_raw['times'], me_raw['ROIMotionEnergy'])
                me_df = pd.DataFrame(columns=[names[view]], data=motion_energy)
                me_df.insert(0, 'times', times_fixed)
                self.motion_energy[f'{view}Camera'] = me_df
            except Exception as e:
                _logger.warning(f'Could not load motion energy data for {view}Camera. Skipping camera.')
                _logger.debug(e)
        self._set_loaded('motion_energy', bool(self.motion_energy))

    def load_licks(self):
        """
        Not yet implemented
        """
        pass

    def load_pupil(self, snr_thresh=5.):
        """
        Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil.

        The data is taken from the left camera 'features' if available. Otherwise it is computed on the fly from
        the left camera pose data thresholded at a likelihood of 0.9; this does not modify SessionLoader.pose.

        Parameters
        ----------
        snr_thresh: float
            An SNR is calculated from the raw and smoothed pupil diameter. If this snr < snr_thresh the data
            will be considered unusable and will be discarded.
        """
        # Try to load from features
        feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features'])
        if 'features' in feat_raw.keys():
            times_fixed, feats = self._check_video_timestamps('left', feat_raw['times'], feat_raw['features'])
            pupil = feats.copy()
            pupil.insert(0, 'times', times_fixed)
        # If unavailable compute on the fly
        else:
            _logger.info('Pupil diameter not available, trying to compute on the fly.')
            # Already loaded pose data may have been thresholded differently, so always load a fresh copy
            # thresholded at 0.9, without touching what is stored in self.pose
            dlc_thr = self._load_pose_view('left', likelihood_thr=0.9)
            # Start from a fresh dataframe so that a reload never writes into stale pupil data
            pupil = pd.DataFrame({'times': dlc_thr['times']})
            pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr)
            try:
                pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(pupil['pupilDiameter_raw'], 'left')
            except Exception as e:
                _logger.error("Computing smooth pupil diameter failed, saving all NaNs.")
                _logger.debug(e)
                pupil['pupilDiameter_smooth'] = np.nan

        smooth = np.asarray(pupil['pupilDiameter_smooth'], dtype=float)
        raw = np.asarray(pupil['pupilDiameter_raw'], dtype=float)
        if not np.all(np.isnan(smooth)):
            # SNR: variance of the smoothed trace over variance of the residual (smooth - raw),
            # using only samples where both traces are defined
            good = ~np.isnan(smooth) & ~np.isnan(raw)
            snr = np.var(smooth[good]) / np.var(smooth[good] - raw[good])
            if snr < snr_thresh:
                self.pupil = pd.DataFrame()
                self._set_loaded('pupil', False)
                _logger.error(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.')
                return

        self.pupil = pupil
        self._set_loaded('pupil')

    def _check_video_timestamps(self, view, video_timestamps, video_data):
        """
        Helper function to check for the length of the video frames vs video timestamps and fix in case
        timestamps are longer than video frames.

        Parameters
        ----------
        view: str
            Camera view the data belongs to, only used in the error message
        video_timestamps: array-like with a `shape` attribute
            Timestamps of the video frames
        video_data: array-like or dataframe with a `shape` attribute
            Per-frame video data, first dimension is frames

        Returns
        -------
        tuple
            (timestamps, video_data), where the timestamps have the same length as video_data

        Raises
        ------
        ValueError
            If there are fewer timestamps than video frames (including no timestamps at all).
        """
        n_times, n_frames = video_timestamps.shape[0], video_data.shape[0]
        # If camera times are shorter than video data, or empty, no current fix
        if n_times < n_frames:
            if n_times == 0:
                msg = f'Camera times empty for {view}Camera.'
            else:
                msg = f'Camera times are shorter than video data for {view}Camera.'
            _logger.warning(msg)
            raise ValueError(msg)
        # For pre-GPIO sessions, it is possible that the camera times are longer than the actual video.
        # This is because the first few frames are sometimes not recorded. We can remove the first few
        # timestamps in this case
        if n_times > n_frames:
            return video_timestamps[-n_frames:], video_data
        return video_timestamps, video_data

0 commit comments

Comments
 (0)