Skip to content

Commit ef3f4ab

Browse files
committed
pupil WIP
1 parent 9e0cdbc commit ef3f4ab

File tree

1 file changed

+86
-18
lines changed

1 file changed

+86
-18
lines changed

brainbox/io/one.py

Lines changed: 86 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from brainbox.processing import sync
2626
from brainbox.metrics.single_units import quick_unit_metrics
2727
from brainbox.behavior.wheel import interpolate_position, velocity_smoothed
28-
from brainbox.behavior.dlc import likelihood_threshold
28+
from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter
2929

3030
_logger = logging.getLogger('ibllib')
3131

@@ -1065,10 +1065,12 @@ class SessionLoader:
10651065
one: One = None
10661066
eid: str = ''
10671067
session_path: Path = ''
1068+
data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
10681069
trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
10691070
wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
1070-
poses: dict = field(default_factory=dict, repr=False)
1071+
pose: dict = field(default_factory=dict, repr=False)
10711072
motion_energy: dict = field(default_factory=dict, repr=False)
1073+
pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
10721074

10731075
def __post_init__(self):
10741076
# Providing no session path, eid and one are required
@@ -1089,23 +1091,51 @@ def __post_init__(self):
10891091
self.one._cache['sessions'] = df_sessions.set_index('id')
10901092
self.one._cache['datasets'] = cache._make_datasets_df(self.session_path, hash_files=False)
10911093
self.eid = str(self.session_path.relative_to(self.session_path.parents[2]))
1092-
1093-
def load_session_data(self, trials=True, wheel=True, poses=True, motion_energy=True):
1094-
# TODO: Dont reload when data already loaded?
1095-
names = ['trials', 'wheel', 'poses', 'motion_energy']
1096-
args = [trials, wheel, poses, motion_energy]
1097-
loading_funcs = [self.load_trials, self.load_wheel, self.load_poses, self.load_motion_energy]
1098-
1099-
for name, arg, loading_func in zip(names, args, loading_funcs):
1100-
if arg is True:
1094+
# Information of data that is and can be loaded
1095+
data_names = [
1096+
'trials',
1097+
'wheel',
1098+
'poses',
1099+
'motion_energy',
1100+
'pupil'
1101+
]
1102+
self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False]*len(data_names)))
1103+
1104+
def load_session_data(self, trials=True, wheel=True, poses=True, motion_energy=True, pupil=True, reload=False):
1105+
1106+
load_df = self.data_info.copy()
1107+
load_df['to_load'] = [
1108+
trials,
1109+
wheel,
1110+
poses,
1111+
motion_energy,
1112+
pupil
1113+
]
1114+
load_df['load_func'] = [
1115+
self.load_trials,
1116+
self.load_wheel,
1117+
self.load_pose,
1118+
self.load_motion_energy,
1119+
self.load_pupil
1120+
]
1121+
1122+
for idx, row in load_df.iterrows():
1123+
if row['to_load'] is False:
1124+
_logger.debug(f"Not loading {row['name']} data, set to False.")
1125+
elif row['is_loaded'] is True and reload is False:
1126+
_logger.debug(f"Not loading {row['name']} data, is already loaded and reload=False.")
1127+
else:
11011128
try:
1102-
loading_func()
1129+
_logger.info(f"Loading {row['name']} data")
1130+
row['load_func']()
1131+
self.data_info.loc[idx, 'is_loaded'] = True
11031132
except BaseException as e:
1104-
_logger.warning(f"Could not load {name} data.")
1133+
_logger.warning(f"Could not load {row['name']} data.")
11051134
_logger.debug(e)
11061135

11071136
def load_trials(self):
11081137
self.trials = self.one.load_object(self.eid, 'trials').to_df()
1138+
self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True
11091139

11101140
def load_wheel(self, sampling_rate=1000, smooth_size=0.03):
11111141
wheel_raw = self.one.load_object(self.eid, 'wheel')
@@ -1118,15 +1148,17 @@ def load_wheel(self, sampling_rate=1000, smooth_size=0.03):
11181148
wheel_raw['timestamps'], wheel_raw['position'], freq=sampling_rate)
11191149
self.wheel['velocity'], self.wheel['acceleration'] = velocity_smoothed(
11201150
self.wheel['position'], freq=sampling_rate, smooth_size=smooth_size)
1151+
self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True
11211152

1122-
def load_poses(self, likelihood_thr=0.9, views=['left', 'right', 'body']):
1153+
def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']):
11231154
for view in views:
11241155
try:
11251156
pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'])
11261157
# Double check if video timestamps are correct length or can be fixed
11271158
times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc'])
1128-
self.poses[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr)
1129-
self.poses[f'{view}Camera'].insert(0, 'times', times_fixed)
1159+
self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr)
1160+
self.pose[f'{view}Camera'].insert(0, 'times', times_fixed)
1161+
self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True
11301162
except BaseException as e:
11311163
_logger.error(f'Could not load pose data for {view}Camera. Skipping camera.')
11321164
_logger.debug(e)
@@ -1143,15 +1175,51 @@ def load_motion_energy(self, views=['left', 'right', 'body']):
11431175
view, me_raw['times'], me_raw['ROIMotionEnergy'])
11441176
self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy)
11451177
self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed)
1178+
self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True
11461179
except BaseException as e:
11471180
_logger.error(f'Could not load motion energy data for {view}Camera. Skipping camera.')
11481181
_logger.debug(e)
11491182

11501183
def load_licks(self):
11511184
pass
11521185

1153-
def load_pupil_diameter(self):
1154-
pass
1186+
def load_pupil(self, snr_thresh=5):
1187+
# Try to load from features
1188+
feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features'])
1189+
if 'features' in feat_raw.keys():
1190+
times_fixed, feats = self._check_video_timestamps(feat_raw['times'], feat_raw['features'])
1191+
self.pupil = feats.copy()
1192+
self.pupil.insert(0, 'times', times_fixed)
1193+
1194+
# If unavailable compute on the fly
1195+
else:
1196+
_logger.info('Pupil diameter not available, trying to compute on the fly.')
1197+
if self.data_info[self.data_info['name'] == 'pose', 'is_loaded'] and 'leftCamera' in self.pose.keys():
1198+
# If pose data is already loaded, we don't know if it was threshold at 0.9, so we need a little stunt
1199+
copy_pose = self.pose['leftCamera'].copy() # Save the previously loaded pose data
1200+
self.load_pose(views=['left'], likelihood_thr=0.9) # Load new with threshold 0.9
1201+
dlc_thr = self.pose['leftCamera'].copy() # Save the threshold pose data in new variable
1202+
self.pose['leftCamera'] = copy_pose.copy() # Get previously loaded pose data back in place
1203+
else:
1204+
self.load_pose(views=['left'], likelihood_thr=0.9)
1205+
dlc_thr = self.pose['leftCamera'].copy()
1206+
1207+
self.pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr)
1208+
try:
1209+
self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left')
1210+
except BaseException as e:
1211+
_logger.error(f"Computing smooth pupil diameter failed, saving all NaNs.")
1212+
_logger.debug(e)
1213+
self.pupil['pupilDiameter_smooth'] = np.nan
1214+
1215+
if not np.all(np.isnan(self.pupil['pupilDiameter_smooth'])):
1216+
good_idxs = np.where(
1217+
~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0]
1218+
snr = (np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) /
1219+
(np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs])))
1220+
if snr < snr_thresh:
1221+
self.pupil = pd.DataFrame
1222+
raise ValueError(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.')
11551223

11561224
def align_trials_to_event(self, align_event='stimOn_times', pre_event=0.5, post_event=0.5):
11571225
possible_events = ['stimOn_times', 'goCue_times', 'goCueTrigger_times',

0 commit comments

Comments
 (0)